Repository: bytedance/monolith Branch: master Commit: 135c491a52b1 Files: 1324 Total size: 26.8 MB Directory structure: gitextract_ddr9m3f9/ ├── .bazelrc ├── .gitignore ├── LICENSE ├── README.md ├── WORKSPACE ├── conf/ │ └── BUILD ├── deploy/ │ ├── .dockerignore │ ├── .gitignore │ ├── .golangci.yaml │ ├── Dockerfile │ ├── Makefile │ ├── PROJECT │ ├── README.md │ ├── api/ │ │ └── v1/ │ │ ├── groupversion_info.go │ │ ├── mlservice_types.go │ │ └── zz_generated.deepcopy.go │ ├── build.sh │ ├── config/ │ │ ├── crd/ │ │ │ ├── bases/ │ │ │ │ └── mlplatform.volcengine.com_mlservices.yaml │ │ │ ├── kustomization.yaml │ │ │ ├── kustomizeconfig.yaml │ │ │ └── patches/ │ │ │ ├── cainjection_in_mlservices.yaml │ │ │ └── webhook_in_mlservices.yaml │ │ ├── default/ │ │ │ ├── kustomization.yaml │ │ │ ├── manager_auth_proxy_patch.yaml │ │ │ └── manager_config_patch.yaml │ │ ├── manager/ │ │ │ ├── controller_manager_config.yaml │ │ │ ├── kustomization.yaml │ │ │ └── manager.yaml │ │ ├── prometheus/ │ │ │ ├── kustomization.yaml │ │ │ └── monitor.yaml │ │ ├── rbac/ │ │ │ ├── auth_proxy_client_clusterrole.yaml │ │ │ ├── auth_proxy_role.yaml │ │ │ ├── auth_proxy_role_binding.yaml │ │ │ ├── auth_proxy_service.yaml │ │ │ ├── kustomization.yaml │ │ │ ├── leader_election_role.yaml │ │ │ ├── leader_election_role_binding.yaml │ │ │ ├── mlservice_editor_role.yaml │ │ │ ├── mlservice_viewer_role.yaml │ │ │ ├── role.yaml │ │ │ ├── role_binding.yaml │ │ │ └── service_account.yaml │ │ └── samples/ │ │ └── mlplatform_v1_mlservice.yaml │ ├── controllers/ │ │ ├── constants.go │ │ ├── deployment_handler.go │ │ ├── mlservice_controller.go │ │ ├── service_handler.go │ │ └── status.go │ ├── go.mod │ ├── go.sum │ ├── hack/ │ │ └── boilerplate.go.txt │ ├── main.go │ └── serving/ │ ├── agent.conf │ ├── docker/ │ │ ├── Dockerfile │ │ ├── assets/ │ │ │ ├── Python-3.8.6.tar.xz │ │ │ ├── bashrc │ │ │ ├── build.sh │ │ │ ├── configurator_dumpenv.service │ │ │ ├── configurator_dumpenv.sh │ │ │ ├── pip.conf │ │ │ 
├── rdma/ │ │ │ │ ├── ibverbs-providers_50mlnx1-1.50100.0_amd64.deb │ │ │ │ ├── ibverbs-utils_50mlnx1-1.50100.0_amd64.deb │ │ │ │ ├── libibumad3_50mlnx1-1.50100.0_amd64.deb │ │ │ │ ├── libibverbs-dev_50mlnx1-1.50100.0_amd64.deb │ │ │ │ ├── libibverbs1_50mlnx1-1.50100.0_amd64.deb │ │ │ │ ├── librdmacm1_50mlnx1-1.50100.0_amd64.deb │ │ │ │ └── rdmacm-utils_50mlnx1-1.50100.0_amd64.deb │ │ │ └── requirements.txt │ │ └── run │ ├── open_source_serving.sh │ └── scripts/ │ ├── build_serving.sh │ └── run_server ├── idl/ │ ├── BUILD │ └── matrix/ │ ├── compression/ │ │ ├── compression.cc │ │ ├── compression.h │ │ ├── compression_qtz8mm.cc │ │ ├── compression_qtz8mm.h │ │ └── float16.h │ └── proto/ │ ├── example.proto │ ├── feature.proto │ ├── line_id.proto │ └── proto_parser.proto ├── markdown/ │ ├── demo/ │ │ ├── AWS-EKS.md │ │ ├── BUILD │ │ ├── Batch.md │ │ ├── README.md │ │ ├── Stream.md │ │ ├── demo_local_runner.py │ │ ├── demo_model.py │ │ ├── kafka_producer.py │ │ ├── kafka_receiver.py │ │ ├── kafka_utils/ │ │ │ ├── add_data_topics.sh │ │ │ ├── delete_topics.sh │ │ │ ├── kafka_base.sh │ │ │ └── start_broker.sh │ │ └── ml_dataset.py │ ├── input_and_model_fn.md │ ├── primus_demo/ │ │ ├── README.md │ │ ├── main.sh │ │ ├── monolith.Dockerfile │ │ └── primus_monolith.json │ └── serving.md ├── monolith/ │ ├── BUILD │ ├── __init__.py │ ├── agent_service/ │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── agent.conf │ │ ├── agent.py │ │ ├── agent_base.py │ │ ├── agent_client.py │ │ ├── agent_controller.py │ │ ├── agent_controller_test.py │ │ ├── agent_service.proto │ │ ├── agent_service.py │ │ ├── agent_service_test.py │ │ ├── agent_v1.py │ │ ├── agent_v3.py │ │ ├── agent_v3_test.py │ │ ├── backends.py │ │ ├── backends_test.py │ │ ├── client.py │ │ ├── constants.py │ │ ├── data_def.py │ │ ├── data_def_test.py │ │ ├── example_batch.pbtxt │ │ ├── mocked_tfserving.py │ │ ├── mocked_tfserving_test.py │ │ ├── mocked_zkclient.py │ │ ├── mocked_zkclient_test.py │ │ ├── model_manager.py │ │ ├── 
model_manager_test.py │ │ ├── profile.sh │ │ ├── replica_manager.py │ │ ├── replica_manager_test.py │ │ ├── resource_utils.py │ │ ├── resource_utils_test.py │ │ ├── run.py │ │ ├── svr_client.py │ │ ├── test_data/ │ │ │ ├── BUILD │ │ │ ├── inst.dump │ │ │ ├── inst.json │ │ │ └── inst.pbtext │ │ ├── tfs_client.py │ │ ├── tfs_client_test.py │ │ ├── tfs_monitor.py │ │ ├── tfs_monitor_test.py │ │ ├── tfs_wrapper.py │ │ ├── utils.py │ │ ├── utils_test.py │ │ ├── zk_mirror.py │ │ └── zk_mirror_test.py │ ├── base_runner.py │ ├── common/ │ │ └── python/ │ │ ├── BUILD │ │ └── mem_profiling.py │ ├── core/ │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── auto_checkpoint_feed_hook.py │ │ ├── base_embedding_host_call.py │ │ ├── base_embedding_host_call_test.py │ │ ├── base_embedding_task.py │ │ ├── base_host_call.py │ │ ├── base_layer.py │ │ ├── base_layer_test.py │ │ ├── base_model_params.py │ │ ├── base_task.py │ │ ├── base_tpu_test.py │ │ ├── core_test_suite.py │ │ ├── dense.py │ │ ├── dense_test.py │ │ ├── feature.py │ │ ├── feature_test.py │ │ ├── host_call.py │ │ ├── hyperparams.py │ │ ├── hyperparams_test.py │ │ ├── mixed_emb_op_comb_nws.py │ │ ├── model.py │ │ ├── model_imports.py │ │ ├── model_registry.py │ │ ├── optimizers.py │ │ ├── py_utils.py │ │ ├── testing_utils.py │ │ ├── tpu_variable.py │ │ ├── util.py │ │ ├── util_test.py │ │ └── variance_scaling.py │ ├── gpu_runner.py │ ├── monolith_workspace.bzl │ ├── native_training/ │ │ ├── BUILD │ │ ├── alert/ │ │ │ ├── BUILD │ │ │ ├── alert.proto │ │ │ ├── alert_manager.py │ │ │ └── alert_manager_test.py │ │ ├── barrier_ops.py │ │ ├── barrier_ops_test.py │ │ ├── basic_restore_hook.py │ │ ├── basic_restore_hook_test.py │ │ ├── clip_ops.py │ │ ├── clip_ops_test.py │ │ ├── cluster_manager.py │ │ ├── cluster_manager_test.py │ │ ├── consul.py │ │ ├── consul_test.py │ │ ├── cpu_sync_training_test.py │ │ ├── cpu_training.py │ │ ├── cpu_training_distributed_test_binary.py │ │ ├── cpu_training_test.py │ │ ├── data/ │ │ │ ├── BUILD │ │ │ 
├── __init__.py │ │ │ ├── data_op_config.proto │ │ │ ├── data_ops_test.py │ │ │ ├── data_service_parquet_test.py │ │ │ ├── data_service_test.py │ │ │ ├── datasets.py │ │ │ ├── docker-compose.yaml │ │ │ ├── eager_mode_test.py │ │ │ ├── extract_fid_test.py │ │ │ ├── feature_list.py │ │ │ ├── feature_list_test.py │ │ │ ├── feature_utils.py │ │ │ ├── feature_utils_test.py │ │ │ ├── item_pool_hook.py │ │ │ ├── item_pool_test.py │ │ │ ├── kafka_dataset_test.py │ │ │ ├── kernels/ │ │ │ │ ├── add_action_kernel.cc │ │ │ │ ├── add_label_kernel.cc │ │ │ │ ├── cache_one_dataset_kernel.cc │ │ │ │ ├── cache_one_dataset_kernel.h │ │ │ │ ├── df_resource_kernel.cc │ │ │ │ ├── df_resource_kernel.h │ │ │ │ ├── dynamic_match_file_dataset_kernel.cc │ │ │ │ ├── extract_fid_kernel.cc │ │ │ │ ├── feature_hash.cc │ │ │ │ ├── feature_name_mapper_tf_bridge.cc │ │ │ │ ├── feature_name_mapper_tf_bridge.h │ │ │ │ ├── fill_multi_rank_output_kernel.cc │ │ │ │ ├── filter_by_label_kernel.cc │ │ │ │ ├── gen_fid_mask.cc │ │ │ │ ├── instance_reweight_dataset_kernel.cc │ │ │ │ ├── instance_reweight_dataset_kernel.h │ │ │ │ ├── internal/ │ │ │ │ │ ├── BUILD │ │ │ │ │ ├── arrow_random_access_file.h │ │ │ │ │ ├── cache_mgr.cc │ │ │ │ │ ├── cache_mgr.h │ │ │ │ │ ├── cache_mgr_test.cc │ │ │ │ │ ├── datasource_utils.cc │ │ │ │ │ ├── datasource_utils.h │ │ │ │ │ ├── datasource_utils_test.cc │ │ │ │ │ ├── file_match_split_provider.cc │ │ │ │ │ ├── file_match_split_provider.h │ │ │ │ │ ├── file_match_split_provider_test.cc │ │ │ │ │ ├── label_utils.cc │ │ │ │ │ ├── label_utils.h │ │ │ │ │ ├── label_utils_test.cc │ │ │ │ │ ├── parquet_column_buffer.h │ │ │ │ │ ├── parquet_example_reader.h │ │ │ │ │ ├── relational_utils.h │ │ │ │ │ ├── relational_utils_test.cc │ │ │ │ │ ├── sized_random_access_file.h │ │ │ │ │ ├── uniq_hashtable.h │ │ │ │ │ ├── uniq_hashtable_test.cc │ │ │ │ │ ├── value_filter_by_feature.cc │ │ │ │ │ ├── value_filter_by_feature.h │ │ │ │ │ ├── value_filter_by_line_id.cc │ │ │ │ │ ├── 
value_filter_by_line_id.h │ │ │ │ │ └── value_filter_test.cc │ │ │ │ ├── item_pool_kernels.cc │ │ │ │ ├── item_pool_kernels.h │ │ │ │ ├── kafka_kernels.cc │ │ │ │ ├── label_normalization_kernel.cc │ │ │ │ ├── label_upper_bound_kernel.cc │ │ │ │ ├── map_id_kernels.cc │ │ │ │ ├── merge_flow_dataset_kernel.cc │ │ │ │ ├── multi_label_gen_kernel.cc │ │ │ │ ├── negative_gen_dataset_kernel.cc │ │ │ │ ├── negative_gen_dataset_kernel.h │ │ │ │ ├── parquet_dataset_kernel.cc │ │ │ │ ├── parse_example_lib.cc │ │ │ │ ├── parse_example_lib.h │ │ │ │ ├── parse_input_data_kernel.cc │ │ │ │ ├── parse_sparse_feature.cc │ │ │ │ ├── parse_sparse_feature.h │ │ │ │ ├── pb_dataset_kernel.cc │ │ │ │ ├── ragged_feature_kernel.cc │ │ │ │ ├── scatter_label_kernel.cc │ │ │ │ ├── split_flow_dataset_kernel.cc │ │ │ │ ├── string_to_variant.cc │ │ │ │ ├── tf_example_to_example_kernel.cc │ │ │ │ ├── transform_dataset_kernel.cc │ │ │ │ ├── transform_dataset_kernel.h │ │ │ │ └── variant_filter_kernel.cc │ │ │ ├── multi_flow_test.py │ │ │ ├── negative_gen_test.py │ │ │ ├── ops/ │ │ │ │ ├── feature_utils_ops.cc │ │ │ │ ├── parse_input_data_ops.cc │ │ │ │ └── pb_dataset_ops.cc │ │ │ ├── parse_sparse_feature_test.py │ │ │ ├── parsers.py │ │ │ ├── test_data/ │ │ │ │ ├── BUILD │ │ │ │ └── mhy.conf │ │ │ ├── tf_example_to_example_test.py │ │ │ ├── training_instance/ │ │ │ │ ├── BUILD │ │ │ │ ├── cc/ │ │ │ │ │ ├── cached_mem_pool.cc │ │ │ │ │ ├── cached_mem_pool.h │ │ │ │ │ ├── cached_mem_pool_test.cc │ │ │ │ │ ├── data_format_options.h │ │ │ │ │ ├── data_read_write_test.cc │ │ │ │ │ ├── data_reader.cc │ │ │ │ │ ├── data_reader.h │ │ │ │ │ ├── data_writer.cc │ │ │ │ │ ├── data_writer.h │ │ │ │ │ ├── fid.h │ │ │ │ │ ├── fid_test.cc │ │ │ │ │ ├── instance_dataset_kernel.cc │ │ │ │ │ ├── instance_dataset_ops.cc │ │ │ │ │ ├── instance_processor.cc │ │ │ │ │ ├── instance_reader.cc │ │ │ │ │ ├── instance_utils.cc │ │ │ │ │ ├── instance_utils.h │ │ │ │ │ ├── instance_utils_test.cc │ │ │ │ │ ├── 
parse_instance_kernel.cc │ │ │ │ │ ├── parse_instance_lib.cc │ │ │ │ │ ├── parse_instance_lib.h │ │ │ │ │ ├── parse_instance_ops.cc │ │ │ │ │ ├── pb_variant.cc │ │ │ │ │ ├── pb_variant.h │ │ │ │ │ ├── reader_util.cc │ │ │ │ │ ├── reader_util.h │ │ │ │ │ ├── reader_util_test.cc │ │ │ │ │ ├── snappy_inputbuffer.cc │ │ │ │ │ ├── snappy_inputbuffer.h │ │ │ │ │ ├── ue_compress.cc │ │ │ │ │ ├── ue_compress.h │ │ │ │ │ ├── ue_compress_test.cc │ │ │ │ │ ├── zstd_inputbuffer.cc │ │ │ │ │ └── zstd_inputbuffer.h │ │ │ │ └── python/ │ │ │ │ ├── instance_dataset_op.py │ │ │ │ ├── instance_dataset_op_test_stdin.py │ │ │ │ ├── instance_negative_gen_dataset_op_test.py │ │ │ │ ├── parse_instance_ops.py │ │ │ │ ├── parse_instance_ops_test.py │ │ │ │ ├── parser_utils.py │ │ │ │ ├── pb_datasource_ops.py │ │ │ │ └── test_data_utils.py │ │ │ ├── transform/ │ │ │ │ ├── BUILD │ │ │ │ ├── cc/ │ │ │ │ │ ├── transforms.cc │ │ │ │ │ └── transforms.h │ │ │ │ ├── transform_config.proto │ │ │ │ ├── transforms.py │ │ │ │ └── transforms_test.py │ │ │ ├── transform_dataset_test.py │ │ │ └── utils.py │ │ ├── debugging/ │ │ │ ├── BUILD │ │ │ ├── README.md │ │ │ ├── debugging_client.py │ │ │ └── debugging_server.py │ │ ├── demo.py │ │ ├── dense_reload_utils.py │ │ ├── dense_reload_utils_test.py │ │ ├── device_utils.py │ │ ├── device_utils_test.py │ │ ├── distribute/ │ │ │ ├── BUILD │ │ │ ├── distributed_dataset.py │ │ │ ├── distributed_dataset_test.py │ │ │ ├── str_queue.py │ │ │ └── str_queue_test.py │ │ ├── distributed_ps.py │ │ ├── distributed_ps_benchmark.py │ │ ├── distributed_ps_factory.py │ │ ├── distributed_ps_factory_test.py │ │ ├── distributed_ps_sync.py │ │ ├── distributed_ps_sync_test.py │ │ ├── distributed_ps_test.py │ │ ├── distributed_serving_ops.py │ │ ├── distributed_serving_ops_test.py │ │ ├── distribution_ops.py │ │ ├── distribution_ops_benchmark.py │ │ ├── distribution_ops_fused_benchmark.py │ │ ├── distribution_ops_fused_test.py │ │ ├── distribution_ops_test.py │ │ ├── 
distribution_utils.py │ │ ├── embedding_combiners.py │ │ ├── embedding_combiners_test.py │ │ ├── entry.py │ │ ├── entry_test.py │ │ ├── env_utils.py │ │ ├── env_utils_test.py │ │ ├── estimator.py │ │ ├── estimator_dist_test.py │ │ ├── estimator_mode_test.py │ │ ├── estimator_test.py │ │ ├── feature.py │ │ ├── feature_test.py │ │ ├── feature_utils.py │ │ ├── feature_utils_test.py │ │ ├── file_ops.py │ │ ├── file_ops_test.py │ │ ├── fountain/ │ │ │ ├── BUILD │ │ │ └── README.md │ │ ├── fused_embedding_to_layout_test.py │ │ ├── gen_seq_mask.py │ │ ├── gen_seq_mask_test.py │ │ ├── gflags_utils.py │ │ ├── gflags_utils_test.py │ │ ├── graph_meta.py │ │ ├── graph_utils.py │ │ ├── hash_filter_ops.py │ │ ├── hash_filter_ops_test.py │ │ ├── hash_table_ops.proto │ │ ├── hash_table_ops.py │ │ ├── hash_table_ops_benchmark.py │ │ ├── hash_table_ops_test.py │ │ ├── hash_table_utils.py │ │ ├── hash_table_utils_test.py │ │ ├── hooks/ │ │ │ ├── BUILD │ │ │ ├── ckpt_hooks.proto │ │ │ ├── ckpt_hooks.py │ │ │ ├── ckpt_hooks_test.py │ │ │ ├── ckpt_info.py │ │ │ ├── ckpt_info_test.py │ │ │ ├── controller_hooks.proto │ │ │ ├── controller_hooks.py │ │ │ ├── controller_hooks_test.py │ │ │ ├── feature_engineering_hooks.py │ │ │ ├── hook_utils.py │ │ │ ├── hook_utils_test.py │ │ │ ├── ps_check_hooks.py │ │ │ ├── ps_check_hooks_test.py │ │ │ ├── server/ │ │ │ │ ├── BUILD │ │ │ │ ├── client_lib.py │ │ │ │ ├── constants.py │ │ │ │ ├── server_lib.py │ │ │ │ ├── server_lib_test.py │ │ │ │ └── service.proto │ │ │ ├── session_hooks.py │ │ │ └── session_hooks_test.py │ │ ├── hvd_lib.py │ │ ├── input.py │ │ ├── layers/ │ │ │ ├── BUILD │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── add_bias.py │ │ │ ├── add_bias_test.py │ │ │ ├── advanced_activations.py │ │ │ ├── advanced_activations_test.py │ │ │ ├── agru.py │ │ │ ├── agru_test.py │ │ │ ├── dense.py │ │ │ ├── dense_test.py │ │ │ ├── feature_cross.py │ │ │ ├── feature_cross_test.py │ │ │ ├── feature_seq.py │ │ │ ├── feature_seq_test.py │ │ │ ├── 
feature_trans.py │ │ │ ├── feature_trans_test.py │ │ │ ├── kernels/ │ │ │ │ ├── feature_insight_kernels.cc │ │ │ │ ├── ffm_kernels.cc │ │ │ │ ├── ffm_kernels.cu.cc │ │ │ │ ├── ffm_kernels.h │ │ │ │ └── fid_counter_kernel.cc │ │ │ ├── layer_ops.py │ │ │ ├── layer_ops_test.py │ │ │ ├── lhuc.py │ │ │ ├── lhuc_test.py │ │ │ ├── logit_correction.py │ │ │ ├── logit_correction_test.py │ │ │ ├── mlp.py │ │ │ ├── mlp_test.py │ │ │ ├── multi_task.py │ │ │ ├── multi_task_test.py │ │ │ ├── norms.py │ │ │ ├── norms_test.py │ │ │ ├── ops/ │ │ │ │ ├── feature_insight_ops.cc │ │ │ │ ├── ffm_ops.cc │ │ │ │ ├── fid_counter_op.cc │ │ │ │ └── nas_ops.cc │ │ │ ├── pooling.py │ │ │ ├── pooling_test.py │ │ │ ├── sparse_nas.py │ │ │ ├── sparse_nas_test.py │ │ │ └── utils.py │ │ ├── learning_rate_functions.py │ │ ├── learning_rate_functions_test.py │ │ ├── logging_ops.py │ │ ├── logging_ops_test.py │ │ ├── losses/ │ │ │ ├── BUILD │ │ │ ├── batch_softmax_loss.py │ │ │ ├── batch_softmax_loss_test.py │ │ │ ├── inbatch_auc_loss.py │ │ │ ├── inbatch_auc_loss_test.py │ │ │ └── ltr_losses.py │ │ ├── metric/ │ │ │ ├── BUILD │ │ │ ├── cli.py │ │ │ ├── deep_insight_ops.py │ │ │ ├── deep_insight_ops_test.py │ │ │ ├── exit_hook.py │ │ │ ├── kafka_utils.py │ │ │ ├── metric_hook.py │ │ │ ├── metric_hook_test.py │ │ │ ├── utils.py │ │ │ └── utils_test.py │ │ ├── mlp_utils.py │ │ ├── model.py │ │ ├── model_comp_test.py │ │ ├── model_dump/ │ │ │ ├── BUILD │ │ │ ├── dump_utils.py │ │ │ ├── graph_utils.py │ │ │ ├── graph_utils_test.py │ │ │ └── monolith_model.proto │ │ ├── model_export/ │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── data_gen_utils.py │ │ │ ├── data_gen_utils_test.py │ │ │ ├── demo_export.py │ │ │ ├── demo_export_test.py │ │ │ ├── demo_predictor.py │ │ │ ├── demo_predictor_client.py │ │ │ ├── export.proto │ │ │ ├── export_context.py │ │ │ ├── export_hooks.py │ │ │ ├── export_hooks_test.py │ │ │ ├── export_state_utils.py │ │ │ ├── export_state_utils_test.py │ │ │ ├── export_utils.py │ │ │ ├── 
export_utils_test.py │ │ │ ├── saved_model_exporters.py │ │ │ ├── saved_model_exporters_test.py │ │ │ ├── saved_model_visulizer.py │ │ │ ├── testdata/ │ │ │ │ ├── BUILD │ │ │ │ └── saved_model/ │ │ │ │ ├── entry/ │ │ │ │ │ └── 1622716114/ │ │ │ │ │ └── variables/ │ │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ │ └── variables.index │ │ │ │ ├── ps_0/ │ │ │ │ │ └── 1622716114/ │ │ │ │ │ ├── assets/ │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00002-of-00004 
│ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00001-of-00004 │ │ │ │ │ │ ├── 
MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 
│ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00002-of-00004 │ │ │ │ │ │ └── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 │ │ │ │ │ └── variables/ │ │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ │ └── variables.index │ │ │ │ ├── ps_1/ │ │ │ │ │ └── 1622716114/ │ │ │ │ │ ├── assets/ │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00002-of-00004 │ │ │ │ │ │ ├── 
MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 
│ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 │ │ │ │ │ │ ├── 
MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 
│ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00002-of-00004 │ │ │ │ │ │ └── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 │ │ │ │ │ └── variables/ │ │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ │ └── variables.index │ │ │ │ ├── ps_2/ │ │ │ │ │ └── 1622716114/ │ │ │ │ │ ├── assets/ │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00002-of-00004 │ │ │ │ │ │ ├── 
MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 
│ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 │ │ │ │ │ │ ├── 
MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00002-of-00004 │ │ │ │ │ │ └── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 │ │ │ │ │ └── variables/ │ │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ │ └── variables.index │ │ │ │ ├── ps_3/ │ │ │ │ │ └── 1622716114/ │ │ │ │ │ ├── assets/ │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 │ 
│ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 │ │ │ │ │ │ ├── 
MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 
│ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00001-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 │ │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 │ │ │ │ │ │ ├── 
MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00002-of-00004 │ │ │ │ │ │ └── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 │ │ │ │ │ └── variables/ │ │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ │ └── variables.index │ │ │ │ └── ps_4/ │ │ │ │ └── 1622716114/ │ │ │ │ ├── assets/ │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00003-of-00004 │ │ │ │ │ ├── 
MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 │ │ │ │ │ ├── 
MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00003-of-00004 │ │ │ │ │ ├── 
MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 │ │ │ │ │ ├── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00002-of-00004 │ │ │ │ │ └── MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 │ │ │ │ └── variables/ │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ └── variables.index │ │ │ ├── warmup_data_decoder.py │ │ │ ├── warmup_data_gen.py │ │ │ └── warmup_example_batch.py │ │ ├── monolith_checkpoint_state.proto │ │ ├── monolith_export.py │ │ ├── multi_hash_table_ops.proto │ │ ├── multi_hash_table_ops.py │ │ ├── multi_hash_table_ops_test.py │ │ ├── multi_type_hash_table.py │ │ ├── multi_type_hash_table_test.py │ │ ├── native_model.py │ │ ├── native_task.py │ │ ├── native_task_context.py │ │ ├── nested_tensors.py │ │ ├── nested_tensors_test.py │ │ ├── net_utils.py │ │ ├── net_utils_test.py │ │ ├── optimizers/ │ │ │ ├── BUILD │ │ │ ├── adamom.py │ │ │ ├── adamom_test.py │ │ │ ├── cc/ │ │ │ │ ├── kernels/ │ │ │ │ │ ├── training_op_helpers.h │ │ │ │ │ ├── training_ops.cc │ │ │ │ │ ├── training_ops.h │ │ │ │ │ └── training_ops_gpu.cu.cc │ │ │ │ └── training_ops.cc │ │ │ ├── rmsprop.py │ │ │ ├── rmsprop_test.py │ │ │ ├── rmspropv2_test.py │ │ │ └── shampoo.py │ │ ├── prefetch_queue.py │ │ ├── 
prefetch_queue_test.py │ │ ├── proto/ │ │ │ ├── BUILD │ │ │ ├── ckpt_info.proto │ │ │ ├── debugging_info.proto │ │ │ └── primus_am_service.proto │ │ ├── ps_benchmark.py │ │ ├── ps_benchmark_test.py │ │ ├── ragged_utils.py │ │ ├── ragged_utils_test.py │ │ ├── remote_predict_ops.py │ │ ├── restore_test.py │ │ ├── runner_utils.py │ │ ├── runner_utils_test.py │ │ ├── runtime/ │ │ │ ├── allocator/ │ │ │ │ ├── BUILD │ │ │ │ ├── block_allocator.cc │ │ │ │ ├── block_allocator.h │ │ │ │ └── block_allocator_test.cc │ │ │ ├── common/ │ │ │ │ ├── BUILD │ │ │ │ ├── cpu_info.cc │ │ │ │ ├── cpu_info.h │ │ │ │ ├── linalg_utils.h │ │ │ │ ├── linalg_utils_test.cc │ │ │ │ ├── metrics.cc │ │ │ │ ├── metrics.h │ │ │ │ └── metrics_test.cc │ │ │ ├── concurrency/ │ │ │ │ ├── BUILD │ │ │ │ ├── micro_one_bit_spin_lock.h │ │ │ │ ├── queue.h │ │ │ │ ├── queue_test.cc │ │ │ │ ├── random_number_generator_benchmark.cc │ │ │ │ ├── sleeper.h │ │ │ │ ├── thread_pool.cc │ │ │ │ ├── thread_pool.h │ │ │ │ ├── xorshift.cc │ │ │ │ ├── xorshift.h │ │ │ │ └── xorshift_test.cc │ │ │ ├── deep_insight/ │ │ │ │ ├── BUILD │ │ │ │ ├── deep_insight.cc │ │ │ │ ├── deep_insight.h │ │ │ │ └── deep_insight_test.cc │ │ │ ├── hash_filter/ │ │ │ │ ├── BUILD │ │ │ │ ├── dummy_hash_filter.h │ │ │ │ ├── filter.h │ │ │ │ ├── hash_filter.cc │ │ │ │ ├── hash_filter.h │ │ │ │ ├── hash_filter_test.cc │ │ │ │ ├── probabilistic_filter.cc │ │ │ │ ├── probabilistic_filter.h │ │ │ │ ├── probabilistic_filter_test.cc │ │ │ │ ├── sliding_hash_filter.cc │ │ │ │ ├── sliding_hash_filter.h │ │ │ │ ├── sliding_hash_filter_test.cc │ │ │ │ └── types.h │ │ │ ├── hash_table/ │ │ │ │ ├── BUILD │ │ │ │ ├── compressor/ │ │ │ │ │ ├── BUILD │ │ │ │ │ ├── fake_quantizer.h │ │ │ │ │ ├── fake_quantizer_test.cc │ │ │ │ │ ├── float_compressor.cc │ │ │ │ │ ├── float_compressor.h │ │ │ │ │ ├── float_compressor.proto │ │ │ │ │ ├── float_compressor_test.cc │ │ │ │ │ ├── hash_net_quantizer.h │ │ │ │ │ └── hash_net_quantizer_test.cc │ │ │ │ ├── cuckoohash/ │ 
│ │ │ │ ├── BUILD │ │ │ │ │ ├── CUCKOO_ORIGINAL_LICENSE │ │ │ │ │ ├── bucket_container.hpp │ │ │ │ │ ├── cuckoo_embedding_hash_table.cc │ │ │ │ │ ├── cuckoo_embedding_hash_table.h │ │ │ │ │ ├── cuckoo_embedding_hash_table_benchmark.cc │ │ │ │ │ ├── cuckoo_embedding_hash_table_test.cc │ │ │ │ │ ├── cuckoohash_config.hpp │ │ │ │ │ ├── cuckoohash_map.hpp │ │ │ │ │ └── cuckoohash_util.hpp │ │ │ │ ├── embedding_hash_table.proto │ │ │ │ ├── embedding_hash_table_factory.cc │ │ │ │ ├── embedding_hash_table_factory.h │ │ │ │ ├── embedding_hash_table_interface.h │ │ │ │ ├── embedding_hash_table_test.h │ │ │ │ ├── entry_accessor.cc │ │ │ │ ├── entry_accessor.h │ │ │ │ ├── entry_accessor_decorator.h │ │ │ │ ├── entry_accessor_test.cc │ │ │ │ ├── entry_defs.h │ │ │ │ ├── entry_defs_test.cc │ │ │ │ ├── hash_table_benchmark.cc │ │ │ │ ├── initializer/ │ │ │ │ │ ├── BUILD │ │ │ │ │ ├── constants_initializer.cc │ │ │ │ │ ├── constants_initializer.h │ │ │ │ │ ├── initializer_combination.cc │ │ │ │ │ ├── initializer_combination.h │ │ │ │ │ ├── initializer_combination_test.cc │ │ │ │ │ ├── initializer_config.proto │ │ │ │ │ ├── initializer_factory.cc │ │ │ │ │ ├── initializer_factory.h │ │ │ │ │ ├── initializer_interface.h │ │ │ │ │ ├── random_uniform_initializer.cc │ │ │ │ │ ├── random_uniform_initializer.h │ │ │ │ │ └── random_uniform_initializer_test.cc │ │ │ │ ├── optimizer/ │ │ │ │ │ ├── BUILD │ │ │ │ │ ├── adadelta_optimizer.cc │ │ │ │ │ ├── adadelta_optimizer.h │ │ │ │ │ ├── adadelta_optimizer_test.cc │ │ │ │ │ ├── adagrad_optimizer.cc │ │ │ │ │ ├── adagrad_optimizer.h │ │ │ │ │ ├── adagrad_optimizer_test.cc │ │ │ │ │ ├── adam_optimizer.cc │ │ │ │ │ ├── adam_optimizer.h │ │ │ │ │ ├── adam_optimizer_test.cc │ │ │ │ │ ├── amsgrad_optimizer.cc │ │ │ │ │ ├── amsgrad_optimizer.h │ │ │ │ │ ├── amsgrad_optimizer_test.cc │ │ │ │ │ ├── avx_benchmark.cc │ │ │ │ │ ├── avx_test.cc │ │ │ │ │ ├── avx_utils.h │ │ │ │ │ ├── batch_softmax_optimizer.cc │ │ │ │ │ ├── batch_softmax_optimizer.h │ │ 
│ │ │ ├── batch_softmax_optimizer_test.cc │ │ │ │ │ ├── dc_optimizer.cc │ │ │ │ │ ├── dc_optimizer.h │ │ │ │ │ ├── dc_optimizer_test.cc │ │ │ │ │ ├── dynamic_wd_adagrad_optimizer.cc │ │ │ │ │ ├── dynamic_wd_adagrad_optimizer.h │ │ │ │ │ ├── dynamic_wd_adagrad_optimizer_test.cc │ │ │ │ │ ├── dynamic_wd_avx_test.cc │ │ │ │ │ ├── dynamic_wd_avx_utils.h │ │ │ │ │ ├── ftrl_optimizer.cc │ │ │ │ │ ├── ftrl_optimizer.h │ │ │ │ │ ├── ftrl_optimizer_test.cc │ │ │ │ │ ├── group_adagrad_optimizer.cc │ │ │ │ │ ├── group_adagrad_optimizer.h │ │ │ │ │ ├── group_adagrad_optimizer_test.cc │ │ │ │ │ ├── group_ftrl_optimizer.cc │ │ │ │ │ ├── group_ftrl_optimizer.h │ │ │ │ │ ├── group_ftrl_optimizer_test.cc │ │ │ │ │ ├── momentum_optimizer.cc │ │ │ │ │ ├── momentum_optimizer.h │ │ │ │ │ ├── momentum_optimizer_test.cc │ │ │ │ │ ├── moving_average_optimizer.cc │ │ │ │ │ ├── moving_average_optimizer.h │ │ │ │ │ ├── moving_average_optimizer_test.cc │ │ │ │ │ ├── optimizer.proto │ │ │ │ │ ├── optimizer_combination.cc │ │ │ │ │ ├── optimizer_combination.h │ │ │ │ │ ├── optimizer_combination_test.cc │ │ │ │ │ ├── optimizer_decorator.h │ │ │ │ │ ├── optimizer_factory.cc │ │ │ │ │ ├── optimizer_factory.h │ │ │ │ │ ├── optimizer_interface.h │ │ │ │ │ ├── rmsprop_optimizer.cc │ │ │ │ │ ├── rmsprop_optimizer.h │ │ │ │ │ ├── rmsprop_optimizer_test.cc │ │ │ │ │ ├── sgd_optimizer.cc │ │ │ │ │ ├── sgd_optimizer.h │ │ │ │ │ ├── sgd_optimizer_test.cc │ │ │ │ │ ├── stochastic_rounding.cc │ │ │ │ │ ├── stochastic_rounding.h │ │ │ │ │ ├── stochastic_rounding_test.cc │ │ │ │ │ └── test_utils.h │ │ │ │ ├── quantized_entry_accessor.h │ │ │ │ ├── quantized_entry_accessor_test.cc │ │ │ │ ├── retriever/ │ │ │ │ │ ├── BUILD │ │ │ │ │ ├── fake_quant_retriever.cc │ │ │ │ │ ├── fake_quant_retriever.h │ │ │ │ │ ├── fake_quant_retriever_test.cc │ │ │ │ │ ├── hash_net_retriever.cc │ │ │ │ │ ├── hash_net_retriever.h │ │ │ │ │ ├── hash_net_retriever_test.cc │ │ │ │ │ ├── raw_retriever.cc │ │ │ │ │ ├── raw_retriever.h │ 
│ │ │ │ ├── raw_retriever_test.cc │ │ │ │ │ ├── retriever_base.h │ │ │ │ │ ├── retriever_combination.cc │ │ │ │ │ ├── retriever_combination.h │ │ │ │ │ ├── retriever_combination_test.cc │ │ │ │ │ └── retriever_interface.h │ │ │ │ ├── utils.h │ │ │ │ └── workspace.bzl │ │ │ ├── hopscotch/ │ │ │ │ ├── BUILD │ │ │ │ ├── hopscotch_hash_set.cc │ │ │ │ ├── hopscotch_hash_set.h │ │ │ │ └── hopscotch_hash_set_test.cc │ │ │ ├── ops/ │ │ │ │ ├── BUILD │ │ │ │ ├── agent_heartbeat.cc │ │ │ │ ├── agent_heartbeat.h │ │ │ │ ├── agent_heartbeat_test.cc │ │ │ │ ├── aligned_concat_split.cu.cc │ │ │ │ ├── alloc_utils.h │ │ │ │ ├── clip_by_global_norm.cu.cc │ │ │ │ ├── clip_by_global_norm.h │ │ │ │ ├── clip_by_global_norm_fused.cu.cc │ │ │ │ ├── clip_by_global_norm_op.cc │ │ │ │ ├── deep_insight_client_tf_bridge.h │ │ │ │ ├── deep_insight_ops.cc │ │ │ │ ├── embedding_hash_table_tf_bridge.cc │ │ │ │ ├── embedding_hash_table_tf_bridge.h │ │ │ │ ├── file_metric_writer.cc │ │ │ │ ├── file_metric_writer.h │ │ │ │ ├── file_metric_writer_test.cc │ │ │ │ ├── file_ops.cc │ │ │ │ ├── file_utils.cc │ │ │ │ ├── file_utils.h │ │ │ │ ├── file_utils_test.cc │ │ │ │ ├── fused_embedding_to_layout.cc │ │ │ │ ├── fused_embedding_to_layout.cu.cc │ │ │ │ ├── fused_embedding_to_layout.h │ │ │ │ ├── fused_reorder_by_indices.cc │ │ │ │ ├── gen_monolith_ops.py │ │ │ │ ├── gen_seq_mask.cc │ │ │ │ ├── global_norm.cu.cc │ │ │ │ ├── gpu_multi_hash_table.h │ │ │ │ ├── hash_filter_intercept_gradient_op.cc │ │ │ │ ├── hash_filter_op.cc │ │ │ │ ├── hash_filter_restore_op.cc │ │ │ │ ├── hash_filter_save_op.cc │ │ │ │ ├── hash_filter_tf_bridge.cc │ │ │ │ ├── hash_filter_tf_bridge.h │ │ │ │ ├── hash_table/ │ │ │ │ │ └── misc_ops.cc │ │ │ │ ├── hash_table_lookup_op.cc │ │ │ │ ├── hash_table_op.cc │ │ │ │ ├── hash_table_restore_op.cc │ │ │ │ ├── hash_table_save_op.cc │ │ │ │ ├── hash_table_update_op.cc │ │ │ │ ├── inbatch_auc_loss.cc │ │ │ │ ├── logging_ops.cc │ │ │ │ ├── logging_ops.proto │ │ │ │ ├── 
map_id_to_embedding.cu.cc │ │ │ │ ├── map_id_to_embedding_op.cc │ │ │ │ ├── multi_hash_table.h │ │ │ │ ├── multi_hash_table_lookup_op.cc │ │ │ │ ├── multi_hash_table_op.cc │ │ │ │ ├── multi_hash_table_save_restore_ops.cc │ │ │ │ ├── multi_hash_table_update_op.cc │ │ │ │ ├── net_utils.cc │ │ │ │ ├── net_utils.h │ │ │ │ ├── net_utils_test.cc │ │ │ │ ├── normalize_merged_split_op.cc │ │ │ │ ├── parameter_sync_ops.cc │ │ │ │ ├── parameter_sync_tf_bridge.cc │ │ │ │ ├── parameter_sync_tf_bridge.h │ │ │ │ ├── prediction_service_grpc.cc │ │ │ │ ├── prediction_service_grpc.h │ │ │ │ ├── reduce_op.cc │ │ │ │ ├── reduce_op.cu.cc │ │ │ │ ├── remote_predict_op.h │ │ │ │ ├── remote_predict_op_grpc.cc │ │ │ │ ├── split_by_indices_op.cc │ │ │ │ ├── static_reshape_op.cc │ │ │ │ ├── touched_key_set_insert_op.cc │ │ │ │ ├── touched_key_set_op.cc │ │ │ │ ├── touched_key_set_steal_op.cc │ │ │ │ ├── touched_key_set_tf_bridge.h │ │ │ │ └── unique_mapping_ops.cc │ │ │ └── parameter_sync/ │ │ │ ├── BUILD │ │ │ ├── dummy_sync_client.h │ │ │ ├── dummy_sync_server.cc │ │ │ ├── dummy_sync_server.h │ │ │ ├── parameter_sync.proto │ │ │ ├── parameter_sync_client.cc │ │ │ ├── parameter_sync_client.h │ │ │ ├── parameter_sync_client_test.cc │ │ │ ├── request_splitter.cc │ │ │ ├── request_splitter.h │ │ │ ├── request_splitter_test.cc │ │ │ ├── sync_client_interface.h │ │ │ ├── sync_client_manager.cc │ │ │ └── sync_client_manager.h │ │ ├── save_utils.py │ │ ├── save_utils_test.py │ │ ├── service_discovery.py │ │ ├── service_discovery_test.py │ │ ├── serving_ps_test.py │ │ ├── session_run_hooks.py │ │ ├── session_run_hooks_test.py │ │ ├── signal_utils.py │ │ ├── signal_utils_test.py │ │ ├── static_reshape_op.py │ │ ├── static_reshape_op_test.py │ │ ├── summary/ │ │ │ ├── BUILD │ │ │ ├── summary_ops.py │ │ │ ├── summary_ops_test.py │ │ │ ├── utils.py │ │ │ └── utils_test.py │ │ ├── sync_hooks.py │ │ ├── sync_hooks_test.py │ │ ├── sync_training_hooks.py │ │ ├── sync_training_hooks_test.py │ │ ├── 
tensor_utils.py │ │ ├── tensor_utils_test.py │ │ ├── test_utils.py │ │ ├── touched_key_set_ops.py │ │ ├── touched_key_set_ops_test.py │ │ ├── utils.py │ │ ├── utils_test.py │ │ ├── variables.py │ │ ├── variables_test.py │ │ ├── yarn_runtime.py │ │ ├── yarn_runtime_test.py │ │ └── zk_utils.py │ ├── path_utils.py │ ├── tf_serving_workspace.bzl │ ├── tpu_runner.py │ ├── utils.py │ └── utils_test.py └── third_party/ ├── BUILD ├── arrow.BUILD ├── brotli.BUILD ├── bzip2.BUILD ├── cli11/ │ ├── BUILD │ ├── CLI11.hpp │ └── current_version ├── cuCollections.patch ├── cuco.BUILD ├── dpssdk.BUILD ├── eigen3/ │ ├── README.txt │ └── eigen_gcc6.patch ├── gperftools/ │ ├── gperftools.BUILD │ └── gperftools.patch ├── half_sourceforge_net/ │ ├── BUILD │ └── half.hpp ├── jemalloc/ │ ├── VERSION │ └── jeprof ├── kafka.BUILD ├── libdata_java_model_training.BUILD ├── lz4.BUILD ├── msgpack/ │ └── msgpack.BUILD ├── nlohmann/ │ ├── BUILD │ └── json.hpp ├── org_apache_zookeeper/ │ ├── BUILD │ ├── Makefile.in │ ├── config.guess │ ├── config.h.in │ ├── config.sub │ ├── configure │ ├── generated/ │ │ ├── zookeeper.jute.c │ │ └── zookeeper.jute.h │ ├── install-sh │ ├── ltmain.sh │ ├── missing │ ├── zookeeper-client-c.BUILD │ └── zookeeper.bzl ├── org_tensorflow/ │ ├── README.md │ └── tf.patch ├── org_tensorflow_serving/ │ ├── public_tf_serving.patch │ └── support_diff_dim_size_inputs.patch ├── pip_deps/ │ ├── BUILD │ └── requirements.txt ├── rapidjson.BUILD ├── repo.bzl ├── tcmalloc/ │ └── libtcmalloc_minimal.so.4.3.0 ├── thrift.BUILD ├── tlearner_arch.BUILD ├── upb.patch ├── xsimd.BUILD └── zstd.BUILD ================================================ FILE CONTENTS ================================================ ================================================ FILE: .bazelrc ================================================ # Copied from https://github.com/tensorflow/serving/blob/master/.bazelrc # Some entries are commented to fit to ByteDance environment # Options used to build with CUDA. 
build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true # build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_75,compute_80" # Just compile for V100 and T4 and A100 for development: build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_70,sm_75,compute_80" # Explicitly specify "local" here to avoid sandboxing for local builds. # Use ./configure to create .monolith_configure.bazelrc to enable build from remote buildfarm. build --spawn_strategy=local build --define=grpc_no_ares=true # Sets the default Apple platform to macOS. build --apple_platform_type=macos build -c opt # LLVM, MLIR and TF require C++14. build --cxxopt=-std=c++14 build --host_cxxopt=-std=c++14 # preventing relocation overflow error build:dbg --copt=-gsplit-dwarf # dbg config, copied from tensorflow v2.4.0 build:dbg -c dbg # for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360 build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON # AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 build:dbg --copt -DDEBUG_BUILD # Adding "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" creates parity with TF # compilation options. It also addresses memory use due to # copy-on-write semantics of std::strings of the older ABI. build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 build --experimental_repo_remote_exec ### end ### fetch --experimental_repo_remote_exec query --experimental_repo_remote_exec build --genrule_strategy=local # Make it default to TF2 build --define=tf_api_version=2 build --action_env=TF2_BEHAVIOR=1 # Horovod requires dynamic load on shared object. build --define=framework_shared_object=true # Some optimization config. 
build --define=open_source_build=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true build --define=with_xla_support=true # Some native optimizations build --copt=-O3 build --copt=-mavx build --copt=-mavx2 build --copt=-mfma build --copt=-msse4.1 build --copt=-msse4.2 # TF currently relies on some bazel deprecated behavior, remove these options # once TF fixes bugs. build --noincompatible_remove_legacy_whole_archive --noincompatible_prohibit_aapt1 # Import user configured options (e.g., like building cluster) try-import %workspace%/.monolith_configure.bazelrc # Import bazel 4 compatible options try-import %workspace%/.bazel4-compatible.bazelrc ================================================ FILE: .gitignore ================================================ # Ignore .DS_Store .DS_Store fid_mapping/.DS_Store fid_analysis/.DS_Store # Ignore some test data files fid_mapping/test_data.gz fid_mapping/downloaded/* fid_analysis/downloaded/* # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] # C extensions *.so # Distribution / packaging .Python .env venv/ env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *,cover # Translations *.mo *.pot # Django stuff: *.log # Sphinx documentation docs/_build/ # PyBuilder target/ # Bazel binary folders /bazel-* /.vscode # Jupyter notebook .ipynb_checkpoints # go vendor/ ================================================ FILE: LICENSE ================================================ Copyright 2022 ByteDance and/or its affiliates Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ====== Monolith project incorporates components or files from the following third project: === alibaba/x-deeplearning (https://github.com/alibaba/x-deeplearning) SPDX-License-Identifier: Apache-2.0 === Apache ZooKeeper Copyright: Copyright 2009-2022 The Apache Software Foundation SPDX-License-Identifier: Apache-2.0 === apache/singa Copyright: Copyright 2017 The Apache Software Foundation SPDX-License-Identifier: Apache-2.0 === Folly (https://github.com/facebook/folly) SPDX-License-Identifier: Apache-2.0 === Libcuckoo Copyright: SPDX-License-Identifier: Apache-2.0 === Lingvo Copyright: SPDX-License-Identifier: Apache-2.0 === TensorFlow Copyright: SPDX-License-Identifier: Apache-2.0 === tiny-dnn Copyright: All contributions by Taiga Nomi Copyright (c) 2013, Taiga Nomi All rights reserved. All other contributions: Copyright (c) 2013-2016, the respective contributors. All rights reserved. Each contributor holds copyright over their respective contributions. The project versioning (Git) records all such contribution source information. SPDX-License-Identifier: BSD 3-Clause === brotli Copyright (c) 2009, 2010, 2013-2016 by the Brotli Authors. SPDX-License-Identifier: MIT === ================================================ FILE: README.md ================================================ Monolith ## What is it? 
[Monolith](https://arxiv.org/abs/2209.07663) is a deep learning framework for large scale recommendation modeling. It introduces two important features which are crucial for advanced recommendation systems: * collisionless embedding tables guarantee a unique representation for different id features * real time training captures the latest hotspots and helps users to discover new interests rapidly Monolith is built on top of TensorFlow and supports batch/real-time training and serving. ## Discussion Group ### Join us at Discord https://discord.gg/QYTDeKxGMX ## Quick start ### Build from source Currently, we only support compilation on Linux. First, download bazel 3.1.0 ```bash wget https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh && \ chmod +x bazel-3.1.0-installer-linux-x86_64.sh && \ ./bazel-3.1.0-installer-linux-x86_64.sh && \ rm bazel-3.1.0-installer-linux-x86_64.sh ``` Then, prepare a python environment ```bash pip install -U --user pip numpy wheel packaging requests opt_einsum pip install -U --user keras_preprocessing --no-deps ``` Finally, you can build any target in the monolith. For example, ```bash bazel run //monolith/native_training:demo --output_filter=IGNORE_LOGS ``` ### Demo and tutorials There is a tutorial in [markdown/demo](markdown/demo) on how to run distributed async training, and a few guides on how to use the `MonolithModel` API [here](markdown).
================================================ FILE: WORKSPACE ================================================ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( name = "rules_python", sha256 = "b6d46438523a3ec0f3cead544190ee13223a52f6a6765a29eae7b7cc24cc83a0", url = "https://github.com/bazelbuild/rules_python/releases/download/0.1.0/rules_python-0.1.0.tar.gz", ) http_archive( name = "rules_foreign_cc", sha256 = "c2cdcf55ffaf49366725639e45dedd449b8c3fe22b54e31625eb80ce3a240f1e", strip_prefix = "rules_foreign_cc-0.1.0", url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip", ) load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies") # This sets up some common toolchains for building targets. For more details, please see # https://bazelbuild.github.io/rules_foreign_cc/0.1.0/#rules_foreign_cc_dependencies rules_foreign_cc_dependencies() load("//monolith:monolith_workspace.bzl", "monolith_workspace") monolith_workspace() load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # This is an unofficial boost build but it is useful. git_repository( name = "com_github_nelhage_rules_boost", commit = "1e3a69bf2d5cd10c34b74f066054cd335d033d71", remote = "https://github.com/nelhage/rules_boost", shallow_since = "1591047380 -0700", ) load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") boost_deps() http_archive( name = "org_tensorflow_serving", patch_args = ["-p1"], patches = [ "//third_party:org_tensorflow_serving/public_tf_serving.patch", "//third_party:org_tensorflow_serving/support_diff_dim_size_inputs.patch", ], sha256 = "8c1a4d31ec7ab041b9302348a01422e21349507c7a6f0974639386c8901b721b", strip_prefix = "serving-2.4.0", url = "https://github.com/tensorflow/serving/archive/2.4.0.tar.gz", ) # To update TensorFlow to a new revision. # 1. Update the 'git_commit' args below to include the new git hash. # 2. 
Get the sha256 hash of the archive with a command such as... # curl -L https://github.com/tensorflow/tensorflow/archive/.tar.gz | sha256sum # and update the 'sha256' arg with the result. # 3. Request the new archive to be mirrored on mirror.bazel.build for more # reliable downloads. load("@org_tensorflow_serving//tensorflow_serving:repo.bzl", "tensorflow_http_archive") # Tensorflow 2.4.0 tensorflow_http_archive( name = "org_tensorflow", git_commit = "582c8d236cb079023657287c318ff26adb239002", patch = "//third_party:org_tensorflow/tf.patch", sha256 = "9c94bfec7214853750c7cacebd079348046f246ec0174d01cd36eda375117628", ) http_archive( name = "rules_pkg", sha256 = "352c090cc3d3f9a6b4e676cf42a6047c16824959b438895a76c2989c6d7c246a", url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.5/rules_pkg-0.2.5.tar.gz", ) load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") rules_pkg_dependencies() load( "@org_tensorflow//third_party/toolchains/preconfig/generate:archives.bzl", "bazel_toolchains_archive", ) bazel_toolchains_archive() load( "@bazel_toolchains//repositories:repositories.bzl", bazel_toolchains_repositories = "repositories", ) bazel_toolchains_repositories() # START: Upstream TensorFlow dependencies # TensorFlow build depends on these dependencies. # Needs to be in-sync with TensorFlow sources. 
http_archive( name = "io_bazel_rules_closure", sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9", strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13 ], ) http_archive( name = "bazel_skylib", sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz", "https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz", ], ) # https://github.com/bazelbuild/bazel-skylib/releases # END: Upstream TensorFlow dependencies # Please add all new TensorFlow Serving and Archon dependencies in workspace.bzl. load("//monolith:tf_serving_workspace.bzl", "tf_serving_workspace") tf_serving_workspace() load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") protobuf_deps() # Specify the minimum required bazel version. load("@org_tensorflow//tensorflow:version_check.bzl", "check_bazel_version_at_least") check_bazel_version_at_least("3.0.0") # GPRC deps, required to match TF's. 
Only after calling tf_serving_workspace() load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") grpc_deps() http_archive( name = "upb", patch_args = ["-p1"], patches = ["//third_party:upb.patch"], sha256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454", strip_prefix = "upb-9effcbcb27f0a665f9f345030188c0b291e32482", url = "https://github.com/protocolbuffers/upb/archive/9effcbcb27f0a665f9f345030188c0b291e32482.tar.gz", ) load("@upb//bazel:repository_defs.bzl", "bazel_version_repository") bazel_version_repository(name = "bazel_version") # Hedron's Compile Commands Extractor for Bazel # https://github.com/hedronvision/bazel-compile-commands-extractor http_archive( name = "hedron_compile_commands", strip_prefix = "bazel-compile-commands-extractor-79f8dcae6b451abb97fe76853c867792ac9ac703", # Replace the commit hash in both places (below) with the latest, rather than using the stale one here. # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/79f8dcae6b451abb97fe76853c867792ac9ac703.tar.gz", # When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..." 
) load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") hedron_compile_commands_setup() http_archive( name = "zstd", build_file = "//third_party:zstd.BUILD", sha256 = "a364f5162c7d1a455cc915e8e3cf5f4bd8b75d09bc0f53965b0c9ca1383c52c8", strip_prefix = "zstd-1.4.4", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/facebook/zstd/archive/v1.4.4.tar.gz", "https://github.com/facebook/zstd/archive/v1.4.4.tar.gz", ], ) http_archive( name = "lz4", build_file = "//third_party:lz4.BUILD", sha256 = "658ba6191fa44c92280d4aa2c271b0f4fbc0e34d249578dd05e50e76d0e5efcc", strip_prefix = "lz4-1.9.2", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/lz4/lz4/archive/v1.9.2.tar.gz", "https://github.com/lz4/lz4/archive/v1.9.2.tar.gz", ], ) http_archive( name = "kafka", build_file = "//third_party:kafka.BUILD", patch_cmds = [ "rm -f src/win32_config.h", # TODO: Remove the fowllowing once librdkafka issue is resolved. """sed -i.bak '\\|rd_kafka_log(rk,|,/ exceeded);/ s/^/\\/\\//' src/rdkafka_cgrp.c""", ], sha256 = "f7fee59fdbf1286ec23ef0b35b2dfb41031c8727c90ced6435b8cf576f23a656", strip_prefix = "librdkafka-1.5.0", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/edenhill/librdkafka/archive/v1.5.0.tar.gz", "https://github.com/edenhill/librdkafka/archive/v1.5.0.tar.gz", ], ) load("//third_party:repo.bzl", "tf_http_archive") tf_http_archive( name = "cuCollections", build_file = "//third_party:cuco.BUILD", patch_file = "//third_party:cuCollections.patch", sha256 = "2e059ea1ae18173c5cc3f00989b114c431af78c674f92e35bed56367a9b8b186", strip_prefix = "cuCollections-1e3c5842c6e212e0bd7de9802af583e53009f4a6", urls = [ "https://github.com/NVIDIA/cuCollections/archive/1e3c5842c6e212e0bd7de9802af583e53009f4a6.zip", "https://github.com/NVIDIA/cuCollections/archive/1e3c5842c6e212e0bd7de9802af583e53009f4a6.zip", ], ) http_archive( name = "arrow", build_file = "//third_party:arrow.BUILD", sha256 = 
"57e13c62f27b710e1de54fd30faed612aefa22aa41fa2c0c3bacd204dd18a8f3", strip_prefix = "arrow-apache-arrow-7.0.0", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz", "https://github.com/apache/arrow/archive/apache-arrow-7.0.0.tar.gz", ], ) # extra dependencies of arrow begin http_archive( name = "rapidjson", build_file = "//third_party:rapidjson.BUILD", sha256 = "30bd2c428216e50400d493b38ca33a25efb1dd65f79dfc614ab0c957a3ac2c28", strip_prefix = "rapidjson-418331e99f859f00bdc8306f69eba67e8693c55e", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/miloyip/rapidjson/archive/418331e99f859f00bdc8306f69eba67e8693c55e.tar.gz", "https://github.com/miloyip/rapidjson/archive/418331e99f859f00bdc8306f69eba67e8693c55e.tar.gz", ], ) http_archive( name = "xsimd", build_file = "//third_party:xsimd.BUILD", sha256 = "21b4700e9ef70f6c9a86952047efd8272317df4e6fee35963de9394fd9c5677f", strip_prefix = "xsimd-8.0.1", urls = [ "https://github.com/xtensor-stack/xsimd/archive/refs/tags/8.0.1.tar.gz", ], ) http_archive( name = "brotli", build_file = "//third_party:brotli.BUILD", sha256 = "4c61bfb0faca87219ea587326c467b95acb25555b53d1a421ffa3c8a9296ee2c", strip_prefix = "brotli-1.0.7", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/brotli/archive/v1.0.7.tar.gz", "https://github.com/google/brotli/archive/v1.0.7.tar.gz", ], ) http_archive( name = "bzip2", build_file = "//third_party:bzip2.BUILD", sha256 = "ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269", strip_prefix = "bzip2-1.0.8", urls = [ "https://storage.googleapis.com/mirror.tensorflow.org/sourceware.org/pub/bzip2/bzip2-1.0.8.tar.gz", "https://sourceware.org/pub/bzip2/bzip2-1.0.8.tar.gz", ], ) http_archive( name = "thrift", build_file = "//third_party:thrift.BUILD", sha256 = "5da60088e60984f4f0801deeea628d193c33cec621e78c8a43a5d8c4055f7ad9", strip_prefix = "thrift-0.13.0", urls = [ 
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/thrift/archive/v0.13.0.tar.gz", "https://github.com/apache/thrift/archive/v0.13.0.tar.gz", ], ) # extra dependencies of arrow end ================================================ FILE: conf/BUILD ================================================ package( default_visibility = ["//visibility:public"], ) filegroup( name = "serving", srcs = glob([ "*.properties", "*.conf", "*.cfg", ]), ) ================================================ FILE: deploy/.dockerignore ================================================ # More info: https://docs.docker.com/engine/reference/builder/#dockerignore-file # Ignore all files which are not go type !**/*.go !**/*.mod !**/*.sum ================================================ FILE: deploy/.gitignore ================================================ # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib bin testbin/* # Test binary, build with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out # Kubernetes Generated files - skip generated files, except for vendored files !vendor/**/zz_generated.* # editor and IDE paraphernalia .idea *.swp *.swo *~ ================================================ FILE: deploy/.golangci.yaml ================================================ # This file contains all available configuration options # with their default values. # options for analysis running run: # default concurrency is a available CPU number concurrency: 4 # timeout for analysis, e.g. 30s, 5m, default is 1m timeout: 5m # exit code when at least one issue was found, default is 1 issues-exit-code: 1 # include test files or not, default is true tests: false # list of build tags, all linters use it. Default is empty list. 
build-tags: [] # which dirs to skip: issues from them won't be reported; # can use regexp here: generated.*, regexp is applied on full path; # default value is empty list, but default dirs are skipped independently # from this option's value (see skip-dirs-use-default). # "/" will be replaced by current OS file path separator to properly work # on Windows. skip-dirs: - common/model/api - test # default is true. Enables skipping of directories: # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ skip-dirs-use-default: true # which files to skip: they will be analyzed, but issues from them # won't be reported. Default value is empty list, but there is # no need to include all autogenerated files, we confidently recognize # autogenerated files. If it's not please let us know. # "/" will be replaced by current OS file path separator to properly work # on Windows. skip-files: - ".*_gen.go" - "k-*.go" # Allow multiple parallel golangci-lint instances running. # If false (default) - golangci-lint acquires file lock on start. 
allow-parallel-runners: false # output configuration options output: # colored-line-number|line-number|json|tab|checkstyle|code-climate|junit-xml|github-actions # default is "colored-line-number" format: colored-line-number # print lines of code with issue, default is true print-issued-lines: true # print linter name in the end of issue text, default is true print-linter-name: true # make issues output unique by line, default is true uniq-by-line: true # add a prefix to the output file references; default is no prefix path-prefix: "" # sorts results by: filepath, line and column sort-results: false linters-settings: dupl: threshold: 100 funlen: lines: 120 statements: 70 goconst: min-len: 2 min-occurrences: 2 gocritic: enabled-tags: - diagnostic - experimental - opinionated - performance - style disabled-checks: - dupImport # https://github.com/go-critic/go-critic/issues/845 - ifElseChain - octalLiteral - whyNoLint - wrapperFunc gocognit: min-complexity: 15 gomnd: settings: mnd: # don't include the "operation" and "assign" checks: [argument, case, condition, return] govet: check-shadowing: true lll: line-length: 120 misspell: locale: US nolintlint: allow-leading-space: true # don't require machine-readable nolint directives (i.e. with no leading space) allow-unused: false # report any unused nolint directives require-explanation: false # don't require an explanation for nolint directives linters: # please, do not use `enable-all`: it's deprecated and will be removed soon. 
# inverted configuration with `enable-all` and `disable` is not scalable during updates of golangci-lint disable-all: true enable: - bodyclose - deadcode - dogsled - dupl - errcheck - exportloopref - exhaustive - funlen - gochecknoinits - goconst - gocritic - gofmt - goimports - gomnd - goprintffuncname - gosec - gosimple - govet - ineffassign - lll - misspell - nakedret - noctx - nolintlint - rowserrcheck - staticcheck - structcheck - stylecheck - typecheck - unconvert - unparam - unused - varcheck - whitespace - wrapcheck - gocognit - asciicheck - nestif - sqlclosecheck - prealloc # don't enable: # - scopelint # - gochecknoglobals # - godot # - godox # - goerr113 # - golint # - interfacer # - maligned # - testpackage # - revive # - wsl issues: # Excluding configuration per-path, per-linter, per-text and per-source exclude-rules: - path: _test\.go linters: - gomnd # https://github.com/go-critic/go-critic/issues/926 - linters: - gocritic text: "unnecessaryDefer:" - linters: - stylecheck text: "ST1003:" ================================================ FILE: deploy/Dockerfile ================================================ # Build the manager binary FROM golang:1.15 as builder WORKDIR /workspace # Copy the Go Modules manifests COPY go.mod go.mod COPY go.sum go.sum # cache deps before building and copying source so that we don't need to re-download as much # and so that source changes don't invalidate our downloaded layer RUN go mod download # Copy the go source COPY main.go main.go COPY api/ api/ COPY controllers/ controllers/ # Build RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 GO111MODULE=on go build -a -o manager main.go # Use distroless as minimal base image to package the manager binary # Refer to https://github.com/GoogleContainerTools/distroless for more details FROM gcr.io/distroless/static:nonroot WORKDIR / COPY --from=builder /workspace/manager . 
USER 65532:65532 ENTRYPOINT ["/manager"] ================================================ FILE: deploy/Makefile ================================================ # Image URL to use all building/pushing image targets REGISTRY ?= ml-platform-cn-guilin-boe.cr.volces.com/ml-platform NAME ?= data.monolith.controller-manager # TAG ?= $(shell git describe --always --dirty) TAG ?= b85906ce01ef40a75ba48779efdd4e3f # IMG ?= controller:latest IMG ?= ${REGISTRY}/${NAME}:${TAG} # Produce CRDs that work back to Kubernetes 1.11 (no version conversion) CRD_OPTIONS ?= "crd:trivialVersions=true,preserveUnknownFields=false,generateEmbeddedObjectMeta=true" # Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set) ifeq (,$(shell go env GOBIN)) GOBIN=$(shell go env GOPATH)/bin else GOBIN=$(shell go env GOBIN) endif # Setting SHELL to bash allows bash commands to be executed by recipes. # This is a requirement for 'setup-envtest.sh' in the test target. # Options are set to exit when a recipe line exits non-zero or a piped command fails. SHELL = /usr/bin/env bash -o pipefail .SHELLFLAGS = -ec all: build ##@ General # The help target prints out all targets with their descriptions organized # beneath their categories. The categories are represented by '##@' and the # target descriptions by '##'. The awk commands is responsible for reading the # entire set of makefiles included in this invocation, looking for lines of the # file as xyz: ## something, and then pretty-format the target and help. Then, # if there's a line with ##@ something, that gets pretty-printed as a category. # More info on the usage of ANSI control characters for terminal formatting: # https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters # More info on the awk command: # http://linuxcommand.org/lc3_adv_awk.php help: ## Display this help. 
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) ##@ Development manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects. $(CONTROLLER_GEN) $(CRD_OPTIONS) rbac:roleName=manager-role webhook paths="./..." output:crd:artifacts:config=config/crd/bases generate: controller-gen ## Generate code containing DeepCopy, DeepCopyInto, and DeepCopyObject method implementations. $(CONTROLLER_GEN) object:headerFile="hack/boilerplate.go.txt" paths="./..." fmt: ## Run go fmt against code. go fmt ./... vet: ## Run go vet against code. go vet ./... ENVTEST_ASSETS_DIR=$(shell pwd)/testbin test: manifests generate fmt vet ## Run tests. mkdir -p ${ENVTEST_ASSETS_DIR} test -f ${ENVTEST_ASSETS_DIR}/setup-envtest.sh || curl -sSLo ${ENVTEST_ASSETS_DIR}/setup-envtest.sh https://raw.githubusercontent.com/kubernetes-sigs/controller-runtime/v0.7.2/hack/setup-envtest.sh source ${ENVTEST_ASSETS_DIR}/setup-envtest.sh; fetch_envtest_tools $(ENVTEST_ASSETS_DIR); setup_envtest_env $(ENVTEST_ASSETS_DIR); go test ./... -coverprofile cover.out ##@ Build build: generate fmt vet ## Build manager binary. go build -o bin/manager main.go run: manifests generate fmt vet ## Run a controller from your host. go run ./main.go docker-build: test ## Build docker image with the manager. docker build -t ${IMG} . docker-push: ## Push docker image with the manager. docker push ${IMG} ##@ Deployment install: manifests kustomize ## Install CRDs into the K8s cluster specified in ~/.kube/config. $(KUSTOMIZE) build config/crd | kubectl apply -f - uninstall: manifests kustomize ## Uninstall CRDs from the K8s cluster specified in ~/.kube/config. $(KUSTOMIZE) build config/crd | kubectl delete -f - deploy: manifests kustomize ## Deploy controller to the K8s cluster specified in ~/.kube/config.
cd config/manager && $(KUSTOMIZE) edit set image controller=${IMG} $(KUSTOMIZE) build config/default | kubectl apply -f - undeploy: ## Undeploy controller from the K8s cluster specified in ~/.kube/config. $(KUSTOMIZE) build config/default | kubectl delete -f - CONTROLLER_GEN = $(shell pwd)/bin/controller-gen controller-gen: ## Download controller-gen locally if necessary. $(call go-get-tool,$(CONTROLLER_GEN),sigs.k8s.io/controller-tools/cmd/controller-gen@v0.6.1) KUSTOMIZE = $(shell pwd)/bin/kustomize kustomize: ## Download kustomize locally if necessary. $(call go-get-tool,$(KUSTOMIZE),sigs.k8s.io/kustomize/kustomize/v3@v3.8.7) # go-get-tool will 'go get' any package $2 and install it to $1. PROJECT_DIR := $(shell dirname $(abspath $(lastword $(MAKEFILE_LIST)))) define go-get-tool @[ -f $(1) ] || { \ set -e ;\ TMP_DIR=$$(mktemp -d) ;\ cd $$TMP_DIR ;\ go mod init tmp ;\ echo "Downloading $(2)" ;\ GOBIN=$(PROJECT_DIR)/bin go get $(2) ;\ rm -rf $$TMP_DIR ;\ } endef ================================================ FILE: deploy/PROJECT ================================================ domain: volcengine.com layout: - go.kubebuilder.io/v3 projectName: deploy repo: code.byted.org/data/monolith/deploy resources: - api: crdVersion: v1 namespaced: true controller: true domain: volcengine.com group: mlplatform kind: MLService path: code.byted.org/data/monolith/deploy/api/v1 version: v1 version: "3" ================================================ FILE: deploy/README.md ================================================ # 项目介绍 项目初始结构通过kubebuilder(https://github.com/kubernetes-sigs/kubebuilder) 生成 kubebuilder使用文档:https://book.kubebuilder.io/cronjob-tutorial/cronjob-tutorial.html ``` kubebuilder init --domain volcengine.com --repo https://github.com/bytedance/monolith/blob/master/deploy --skip-go-version-check kubebuilder create api --group mlplatform --version v1 --kind MLService --controller --resource ``` # 安装部署 配置好集群kubeconfig后执行以下命令: ``` make deploy # 安装CRD,部署controller ``` 
================================================ FILE: deploy/api/v1/groupversion_info.go ================================================ /* Copyright 2023. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Package v1 contains API Schema definitions for the mlplatform v1 API group //+kubebuilder:object:generate=true //+groupName=mlplatform.volcengine.com package v1 import ( "k8s.io/apimachinery/pkg/runtime/schema" "sigs.k8s.io/controller-runtime/pkg/scheme" ) var ( // GroupVersion is group version used to register these objects GroupVersion = schema.GroupVersion{Group: "mlplatform.volcengine.com", Version: "v1"} // SchemeBuilder is used to add go types to the GroupVersionKind scheme SchemeBuilder = &scheme.Builder{GroupVersion: GroupVersion} // AddToScheme adds the types in this group-version to the given scheme. AddToScheme = SchemeBuilder.AddToScheme ) ================================================ FILE: deploy/api/v1/mlservice_types.go ================================================ /* Copyright 2023. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ package v1 import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) // NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized. // ServicePortType is the data type of ServicePort Type type ServicePortType string const ( ServicePortTypeHttp ServicePortType = "HTTP" ServicePortTypeRpc ServicePortType = "RPC" ServicePortTypeMetrics ServicePortType = "Metrics" ServicePortTypeOther ServicePortType = "Other" ) // DeploymentTemplateSpec defines the metadata and spec of a Deployment type DeploymentTemplateSpec struct { // Standard object metadata. // +optional metav1.ObjectMeta `json:"metadata,omitempty"` // Specification of the desired behavior of the Deployment. Spec appsv1.DeploymentSpec `json:"spec"` } // ServicePort contains information on service's port. type ServicePort struct { // The type of this port within the service. Type ServicePortType `json:"type,omitempty"` // The port that will be exposed by this service. Port int32 `json:"port"` } // ServiceSpec describes the attributes that a user creates on a service. type ServiceSpec struct { // ServiceType defines which type of service need to be created ServiceType corev1.ServiceType `json:"serviceType,omitempty"` // The list of ports that are exposed by this service. 
// More info: https://kubernetes.io/docs/concepts/services-networking/service/#virtual-ips-and-service-proxies Ports []ServicePort `json:"ports,omitempty"` } // RoleSpec defines the desired state of a role in MLService type RoleSpec struct { // Name of the role Name string `json:"name"` // Number of shards for the role, each shard associated with a Deployment ShardNum int32 `json:"shardNum,omitempty"` // Template of the DeploymentSpec Template DeploymentTemplateSpec `json:"template"` ServiceSpec *ServiceSpec `json:"serviceSpec,omitempty"` } // MLServiceSpec defines the desired state of MLService type MLServiceSpec struct { // selector is a label query over deployment. // It must match the deployment template's labels. Selector *metav1.LabelSelector `json:"selector"` // Roles defines desired state for each role in the service Roles []RoleSpec `json:"roles"` } // ServicePhase is a label for the condition of an MLService at the current time. type ServicePhase string const ( // ServiceQueuing means the service is queuing, waiting to be scheduled ServiceQueuing ServicePhase = "Queuing" // ServiceDeploying means pods of the service are scheduled and being initialized ServiceDeploying ServicePhase = "Deploying" // ServiceRunning means all pods of the service are running ServiceRunning ServicePhase = "Running" // ServiceAbnormal means some pods of the service are abnormal ServiceAbnormal ServicePhase = "Abnormal" // ServiceDeleting means the service is being deleted ServiceDeleting ServicePhase = "Deleting" // ServiceStopping means replicas of the service are being scaled down to 0 ServiceStopping ServicePhase = "Stopping" // ServiceStopped means replicas of the service have been scaled down to 0 ServiceStopped ServicePhase = "Stopped" ) // MLServiceStatus defines the observed state of MLService type MLServiceStatus struct { // Phase is a simple, high-level summary of where the Service is in its lifecycle.
// +optional Phase ServicePhase `json:"phase,omitempty"` // RoleShardStatusMap shows the current status for all Deployments. // The key is Deployment name, value is its status info RoleShardStatusMap map[string]appsv1.DeploymentStatus `json:"roleShardStatusMap,omitempty"` // RoleServiceStatusMap shows the current status for all Services. // The key is Service name, value is its status info RoleServiceStatusMap map[string]corev1.ServiceStatus `json:"roleServiceStatusMap,omitempty"` // RoleServiceClusterIps shows the cluster ip for all Services. // The key is Service name, value is its clusterIP RoleServiceClusterIps map[string]string `json:"roleServiceClusterIps,omitempty"` // LastTransitionTime is the time at which the Phase last transitioned to the current one. // +optional LastTransitionTime metav1.Time `json:"lastTransitionTime,omitempty"` // Unique, one-word, CamelCase reason for the phase's last transition. // +optional Reason string `json:"reason,omitempty"` // Human-readable message indicating details about last transition.
// +optional Message string `json:"message,omitempty"` } // +kubebuilder:object:root=true // +kubebuilder:subresource:status // +kubebuilder:printcolumn:name="Age",type=date,JSONPath=`.metadata.creationTimestamp` // +kubebuilder:printcolumn:name="Phase",type=string,JSONPath=`.status.phase` // +kubebuilder:resource:path=mlservices,shortName=mlsvc // MLService is the Schema for the mlservices API type MLService struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` Spec MLServiceSpec `json:"spec,omitempty"` Status MLServiceStatus `json:"status,omitempty"` } //+kubebuilder:object:root=true // MLServiceList contains a list of MLService type MLServiceList struct { metav1.TypeMeta `json:",inline"` metav1.ListMeta `json:"metadata,omitempty"` Items []MLService `json:"items"` } func init() { SchemeBuilder.Register(&MLService{}, &MLServiceList{}) } ================================================ FILE: deploy/api/v1/zz_generated.deepcopy.go ================================================ //go:build !ignore_autogenerated // +build !ignore_autogenerated /* Copyright 2023. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Code generated by controller-gen. DO NOT EDIT. package v1 import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" runtime "k8s.io/apimachinery/pkg/runtime" ) // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *DeploymentTemplateSpec) DeepCopyInto(out *DeploymentTemplateSpec) { *out = *in in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) in.Spec.DeepCopyInto(&out.Spec) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DeploymentTemplateSpec. func (in *DeploymentTemplateSpec) DeepCopy() *DeploymentTemplateSpec { if in == nil { return nil } out := new(DeploymentTemplateSpec) in.DeepCopyInto(out) return out } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MLService) DeepCopyInto(out *MLService) { *out = *in out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) in.Spec.DeepCopyInto(&out.Spec) in.Status.DeepCopyInto(&out.Status) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MLService. func (in *MLService) DeepCopy() *MLService { if in == nil { return nil } out := new(MLService) in.DeepCopyInto(out) return out } // DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. func (in *MLService) DeepCopyObject() runtime.Object { if c := in.DeepCopy(); c != nil { return c } return nil } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MLServiceList) DeepCopyInto(out *MLServiceList) { *out = *in out.TypeMeta = in.TypeMeta in.ListMeta.DeepCopyInto(&out.ListMeta) if in.Items != nil { in, out := &in.Items, &out.Items *out = make([]MLService, len(*in)) for i := range *in { (*in)[i].DeepCopyInto(&(*out)[i]) } } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MLServiceList. func (in *MLServiceList) DeepCopy() *MLServiceList { if in == nil { return nil } out := new(MLServiceList) in.DeepCopyInto(out) return out } // DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. 
func (in *MLServiceList) DeepCopyObject() runtime.Object { if c := in.DeepCopy(); c != nil { return c } return nil } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MLServiceSpec) DeepCopyInto(out *MLServiceSpec) { *out = *in if in.Selector != nil { in, out := &in.Selector, &out.Selector *out = new(metav1.LabelSelector) (*in).DeepCopyInto(*out) } if in.Roles != nil { in, out := &in.Roles, &out.Roles *out = make([]RoleSpec, len(*in)) for i := range *in { (*in)[i].DeepCopyInto(&(*out)[i]) } } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MLServiceSpec. func (in *MLServiceSpec) DeepCopy() *MLServiceSpec { if in == nil { return nil } out := new(MLServiceSpec) in.DeepCopyInto(out) return out } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MLServiceStatus) DeepCopyInto(out *MLServiceStatus) { *out = *in if in.RoleShardStatusMap != nil { in, out := &in.RoleShardStatusMap, &out.RoleShardStatusMap *out = make(map[string]appsv1.DeploymentStatus, len(*in)) for key, val := range *in { (*out)[key] = *val.DeepCopy() } } if in.RoleServiceStatusMap != nil { in, out := &in.RoleServiceStatusMap, &out.RoleServiceStatusMap *out = make(map[string]corev1.ServiceStatus, len(*in)) for key, val := range *in { (*out)[key] = *val.DeepCopy() } } if in.RoleServiceClusterIps != nil { in, out := &in.RoleServiceClusterIps, &out.RoleServiceClusterIps *out = make(map[string]string, len(*in)) for key, val := range *in { (*out)[key] = val } } in.LastTransitionTime.DeepCopyInto(&out.LastTransitionTime) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MLServiceStatus. 
func (in *MLServiceStatus) DeepCopy() *MLServiceStatus { if in == nil { return nil } out := new(MLServiceStatus) in.DeepCopyInto(out) return out } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *RoleSpec) DeepCopyInto(out *RoleSpec) { *out = *in in.Template.DeepCopyInto(&out.Template) if in.ServiceSpec != nil { in, out := &in.ServiceSpec, &out.ServiceSpec *out = new(ServiceSpec) (*in).DeepCopyInto(*out) } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RoleSpec. func (in *RoleSpec) DeepCopy() *RoleSpec { if in == nil { return nil } out := new(RoleSpec) in.DeepCopyInto(out) return out } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ServicePort) DeepCopyInto(out *ServicePort) { *out = *in } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServicePort. func (in *ServicePort) DeepCopy() *ServicePort { if in == nil { return nil } out := new(ServicePort) in.DeepCopyInto(out) return out } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ServiceSpec) DeepCopyInto(out *ServiceSpec) { *out = *in if in.Ports != nil { in, out := &in.Ports, &out.Ports *out = make([]ServicePort, len(*in)) copy(*out, *in) } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServiceSpec. 
func (in *ServiceSpec) DeepCopy() *ServiceSpec { if in == nil { return nil } out := new(ServiceSpec) in.DeepCopyInto(out) return out } ================================================ FILE: deploy/build.sh ================================================ #!/bin/bash mkdir output cd deploy && make build cp bin/manager ../output ================================================ FILE: deploy/config/crd/bases/mlplatform.volcengine.com_mlservices.yaml ================================================ --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: controller-gen.kubebuilder.io/version: v0.6.1 creationTimestamp: null name: mlservices.mlplatform.volcengine.com spec: group: mlplatform.volcengine.com names: kind: MLService listKind: MLServiceList plural: mlservices shortNames: - mlsvc singular: mlservice scope: Namespaced versions: - additionalPrinterColumns: - jsonPath: .metadata.creationTimestamp name: Age type: date - jsonPath: .status.phase name: Phase type: string name: v1 schema: openAPIV3Schema: description: MLService is the Schema for the mlservices API properties: apiVersion: description: 'APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' type: string kind: description: 'Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. 
More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' type: string metadata: type: object spec: description: MLServiceSpec defines the desired state of MLService properties: roles: description: Roles defines desired state for each role in the service items: description: RoleSpec defines the desired state of a role in MLService properties: name: description: Name of the role type: string serviceSpec: description: ServiceSpec describes the attributes that a user creates on a service. properties: ports: description: 'The list of ports that are exposed by this service. More info: https://kubernetes.io/docs/concepts/services-networking/service/#virtual-ips-and-service-proxies' items: description: ServicePort contains information on service's port. properties: port: description: The port that will be exposed by this service. format: int32 type: integer type: description: The type of this port within the service. type: string required: - port type: object type: array serviceType: description: ServiceType defines which type of service need to be created type: string type: object shardNum: description: Number of shards for the role, each shard associated with a Deployment format: int32 type: integer template: description: Template of the DeploymentSpec properties: metadata: description: Standard object metadata. properties: annotations: additionalProperties: type: string type: object finalizers: items: type: string type: array labels: additionalProperties: type: string type: object name: type: string namespace: type: string type: object spec: description: Specification of the desired behavior of the Deployment. properties: minReadySeconds: description: Minimum number of seconds for which a newly created pod should be ready without any of its container crashing, for it to be considered available. 
Defaults to 0 (pod will be considered available as soon as it is ready) format: int32 type: integer paused: description: Indicates that the deployment is paused. type: boolean progressDeadlineSeconds: description: The maximum time in seconds for a deployment to make progress before it is considered to be failed. The deployment controller will continue to process failed deployments and a condition with a ProgressDeadlineExceeded reason will be surfaced in the deployment status. Note that progress will not be estimated during the time a deployment is paused. Defaults to 600s. format: int32 type: integer replicas: description: Number of desired pods. This is a pointer to distinguish between explicit zero and not specified. Defaults to 1. format: int32 type: integer revisionHistoryLimit: description: The number of old ReplicaSets to retain to allow rollback. This is a pointer to distinguish between explicit zero and not specified. Defaults to 10. format: int32 type: integer selector: description: Label selector for pods. Existing ReplicaSets whose pods are selected by this will be the ones affected by this deployment. It must match the pod template's labels. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. 
items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. type: object type: object strategy: description: The deployment strategy to use to replace existing pods with new ones. properties: rollingUpdate: description: 'Rolling update config params. Present only if DeploymentStrategyType = RollingUpdate. --- TODO: Update this to follow our convention for oneOf, whatever we decide it to be.' properties: maxSurge: anyOf: - type: integer - type: string description: 'The maximum number of pods that can be scheduled above the desired number of pods. Value can be an absolute number (ex: 5) or a percentage of desired pods (ex: 10%). This can not be 0 if MaxUnavailable is 0. Absolute number is calculated from percentage by rounding up. Defaults to 25%. Example: when this is set to 30%, the new ReplicaSet can be scaled up immediately when the rolling update starts, such that the total number of old and new pods do not exceed 130% of desired pods. Once old pods have been killed, new ReplicaSet can be scaled up further, ensuring that total number of pods running at any time during the update is at most 130% of desired pods.' x-kubernetes-int-or-string: true maxUnavailable: anyOf: - type: integer - type: string description: 'The maximum number of pods that can be unavailable during the update. Value can be an absolute number (ex: 5) or a percentage of desired pods (ex: 10%). Absolute number is calculated from percentage by rounding down. This can not be 0 if MaxSurge is 0. Defaults to 25%. Example: when this is set to 30%, the old ReplicaSet can be scaled down to 70% of desired pods immediately when the rolling update starts. 
Once new pods are ready, old ReplicaSet can be scaled down further, followed by scaling up the new ReplicaSet, ensuring that the total number of pods available at all times during the update is at least 70% of desired pods.' x-kubernetes-int-or-string: true type: object type: description: Type of deployment. Can be "Recreate" or "RollingUpdate". Default is RollingUpdate. type: string type: object template: description: Template describes the pods that will be created. properties: metadata: description: 'Standard object''s metadata. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#metadata' properties: annotations: additionalProperties: type: string type: object finalizers: items: type: string type: array labels: additionalProperties: type: string type: object name: type: string namespace: type: string type: object spec: description: 'Specification of the desired behavior of the pod. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status' properties: activeDeadlineSeconds: description: Optional duration in seconds the pod may be active on the node relative to StartTime before the system will actively try to mark it failed and kill associated containers. Value must be a positive integer. format: int64 type: integer affinity: description: If specified, the pod's scheduling constraints properties: nodeAffinity: description: Describes node affinity scheduling rules for the pod. properties: preferredDuringSchedulingIgnoredDuringExecution: description: The scheduler will prefer to schedule pods to nodes that satisfy the affinity expressions specified by this field, but it may choose a node that violates one or more of the expressions. The node that is most preferred is the one with the greatest sum of weights, i.e. 
for each node that meets all of the scheduling requirements (resource request, requiredDuringScheduling affinity expressions, etc.), compute a sum by iterating through the elements of this field and adding "weight" to the sum if the node matches the corresponding matchExpressions; the node(s) with the highest sum are the most preferred. items: description: An empty preferred scheduling term matches all objects with implicit weight 0 (i.e. it's a no-op). A null preferred scheduling term matches no objects (i.e. is also a no-op). properties: preference: description: A node selector term, associated with the corresponding weight. properties: matchExpressions: description: A list of node selector requirements by node's labels. items: description: A node selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: The label key that the selector applies to. type: string operator: description: Represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. type: string values: description: An array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. If the operator is Gt or Lt, the values array must have a single element, which will be interpreted as an integer. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array matchFields: description: A list of node selector requirements by node's fields. items: description: A node selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: The label key that the selector applies to. type: string operator: description: Represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists, DoesNotExist. 
Gt, and Lt. type: string values: description: An array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. If the operator is Gt or Lt, the values array must have a single element, which will be interpreted as an integer. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array type: object weight: description: Weight associated with matching the corresponding nodeSelectorTerm, in the range 1-100. format: int32 type: integer required: - preference - weight type: object type: array requiredDuringSchedulingIgnoredDuringExecution: description: If the affinity requirements specified by this field are not met at scheduling time, the pod will not be scheduled onto the node. If the affinity requirements specified by this field cease to be met at some point during pod execution (e.g. due to an update), the system may or may not try to eventually evict the pod from its node. properties: nodeSelectorTerms: description: Required. A list of node selector terms. The terms are ORed. items: description: A null or empty node selector term matches no objects. The requirements of them are ANDed. The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. properties: matchExpressions: description: A list of node selector requirements by node's labels. items: description: A node selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: The label key that the selector applies to. type: string operator: description: Represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. type: string values: description: An array of string values. If the operator is In or NotIn, the values array must be non-empty. 
If the operator is Exists or DoesNotExist, the values array must be empty. If the operator is Gt or Lt, the values array must have a single element, which will be interpreted as an integer. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array matchFields: description: A list of node selector requirements by node's fields. items: description: A node selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: The label key that the selector applies to. type: string operator: description: Represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. type: string values: description: An array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. If the operator is Gt or Lt, the values array must have a single element, which will be interpreted as an integer. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array type: object type: array required: - nodeSelectorTerms type: object type: object podAffinity: description: Describes pod affinity scheduling rules (e.g. co-locate this pod in the same node, zone, etc. as some other pod(s)). properties: preferredDuringSchedulingIgnoredDuringExecution: description: The scheduler will prefer to schedule pods to nodes that satisfy the affinity expressions specified by this field, but it may choose a node that violates one or more of the expressions. The node that is most preferred is the one with the greatest sum of weights, i.e. 
for each node that meets all of the scheduling requirements (resource request, requiredDuringScheduling affinity expressions, etc.), compute a sum by iterating through the elements of this field and adding "weight" to the sum if the node has pods which match the corresponding podAffinityTerm; the node(s) with the highest sum are the most preferred. items: description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) properties: podAffinityTerm: description: Required. A pod affinity term, associated with the corresponding weight. properties: labelSelector: description: A label query over a set of resources, in this case pods. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. 
type: object type: object namespaces: description: namespaces specifies which namespaces the labelSelector applies to (matches against); null or empty list means "this pod's namespace" items: type: string type: array topologyKey: description: This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching the labelSelector in the specified namespaces, where co-located is defined as running on a node whose value of the label with key topologyKey matches that of any node on which any of the selected pods is running. Empty topologyKey is not allowed. type: string required: - topologyKey type: object weight: description: weight associated with matching the corresponding podAffinityTerm, in the range 1-100. format: int32 type: integer required: - podAffinityTerm - weight type: object type: array requiredDuringSchedulingIgnoredDuringExecution: description: If the affinity requirements specified by this field are not met at scheduling time, the pod will not be scheduled onto the node. If the affinity requirements specified by this field cease to be met at some point during pod execution (e.g. due to a pod label update), the system may or may not try to eventually evict the pod from its node. When there are multiple elements, the lists of nodes corresponding to each podAffinityTerm are intersected, i.e. all terms must be satisfied. items: description: Defines a set of pods (namely those matching the labelSelector relative to the given namespace(s)) that this pod should be co-located (affinity) or not co-located (anti-affinity) with, where co-located is defined as running on a node whose value of the label with key matches that of any node on which a pod of the set of pods is running properties: labelSelector: description: A label query over a set of resources, in this case pods. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. 
items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. type: object type: object namespaces: description: namespaces specifies which namespaces the labelSelector applies to (matches against); null or empty list means "this pod's namespace" items: type: string type: array topologyKey: description: This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching the labelSelector in the specified namespaces, where co-located is defined as running on a node whose value of the label with key topologyKey matches that of any node on which any of the selected pods is running. Empty topologyKey is not allowed. type: string required: - topologyKey type: object type: array type: object podAntiAffinity: description: Describes pod anti-affinity scheduling rules (e.g. avoid putting this pod in the same node, zone, etc. as some other pod(s)). 
properties: preferredDuringSchedulingIgnoredDuringExecution: description: The scheduler will prefer to schedule pods to nodes that satisfy the anti-affinity expressions specified by this field, but it may choose a node that violates one or more of the expressions. The node that is most preferred is the one with the greatest sum of weights, i.e. for each node that meets all of the scheduling requirements (resource request, requiredDuringScheduling anti-affinity expressions, etc.), compute a sum by iterating through the elements of this field and adding "weight" to the sum if the node has pods which match the corresponding podAffinityTerm; the node(s) with the highest sum are the most preferred. items: description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) properties: podAffinityTerm: description: Required. A pod affinity term, associated with the corresponding weight. properties: labelSelector: description: A label query over a set of resources, in this case pods. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. 
items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. type: object type: object namespaces: description: namespaces specifies which namespaces the labelSelector applies to (matches against); null or empty list means "this pod's namespace" items: type: string type: array topologyKey: description: This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching the labelSelector in the specified namespaces, where co-located is defined as running on a node whose value of the label with key topologyKey matches that of any node on which any of the selected pods is running. Empty topologyKey is not allowed. type: string required: - topologyKey type: object weight: description: weight associated with matching the corresponding podAffinityTerm, in the range 1-100. format: int32 type: integer required: - podAffinityTerm - weight type: object type: array requiredDuringSchedulingIgnoredDuringExecution: description: If the anti-affinity requirements specified by this field are not met at scheduling time, the pod will not be scheduled onto the node. If the anti-affinity requirements specified by this field cease to be met at some point during pod execution (e.g. due to a pod label update), the system may or may not try to eventually evict the pod from its node. When there are multiple elements, the lists of nodes corresponding to each podAffinityTerm are intersected, i.e. all terms must be satisfied. 
items: description: Defines a set of pods (namely those matching the labelSelector relative to the given namespace(s)) that this pod should be co-located (affinity) or not co-located (anti-affinity) with, where co-located is defined as running on a node whose value of the label with key matches that of any node on which a pod of the set of pods is running properties: labelSelector: description: A label query over a set of resources, in this case pods. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. 
type: object type: object namespaces: description: namespaces specifies which namespaces the labelSelector applies to (matches against); null or empty list means "this pod's namespace" items: type: string type: array topologyKey: description: This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching the labelSelector in the specified namespaces, where co-located is defined as running on a node whose value of the label with key topologyKey matches that of any node on which any of the selected pods is running. Empty topologyKey is not allowed. type: string required: - topologyKey type: object type: array type: object type: object automountServiceAccountToken: description: AutomountServiceAccountToken indicates whether a service account token should be automatically mounted. type: boolean containers: description: List of containers belonging to the pod. Containers cannot currently be added or removed. There must be at least one container in a Pod. Cannot be updated. items: description: A single application container that you want to run within a pod. properties: args: description: 'Arguments to the entrypoint. The docker image''s CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container''s environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Cannot be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell' items: type: string type: array command: description: 'Entrypoint array. Not executed within a shell. The docker image''s ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container''s environment. 
If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Cannot be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell' items: type: string type: array env: description: List of environment variables to set in the container. Cannot be updated. items: description: EnvVar represents an environment variable present in a Container. properties: name: description: Name of the environment variable. Must be a C_IDENTIFIER. type: string value: description: 'Variable references $(VAR_NAME) are expanded using the previous defined environment variables in the container and any service environment variables. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Defaults to "".' type: string valueFrom: description: Source for the environment variable's value. Cannot be used if value is not empty. properties: configMapKeyRef: description: Selects a key of a ConfigMap. properties: key: description: The key to select. type: string name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the ConfigMap or its key must be defined type: boolean required: - key type: object fieldRef: description: 'Selects a field of the pod: supports metadata.name, metadata.namespace, `metadata.labels['''']`, `metadata.annotations['''']`, spec.nodeName, spec.serviceAccountName, status.hostIP, status.podIP, status.podIPs.' 
properties: apiVersion: description: Version of the schema the FieldPath is written in terms of, defaults to "v1". type: string fieldPath: description: Path of the field to select in the specified API version. type: string required: - fieldPath type: object resourceFieldRef: description: 'Selects a resource of the container: only resources limits and requests (limits.cpu, limits.memory, limits.ephemeral-storage, requests.cpu, requests.memory and requests.ephemeral-storage) are currently supported.' properties: containerName: description: 'Container name: required for volumes, optional for env vars' type: string divisor: anyOf: - type: integer - type: string description: Specifies the output format of the exposed resources, defaults to "1" pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true resource: description: 'Required: resource to select' type: string required: - resource type: object secretKeyRef: description: Selects a key of a secret in the pod's namespace properties: key: description: The key of the secret to select from. Must be a valid secret key. type: string name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the Secret or its key must be defined type: boolean required: - key type: object type: object required: - name type: object type: array envFrom: description: List of sources to populate environment variables in the container. The keys defined within a source must be a C_IDENTIFIER. All invalid keys will be reported as an event when the container is starting. When a key exists in multiple sources, the value associated with the last source will take precedence. Values defined by an Env with a duplicate key will take precedence. Cannot be updated. 
items: description: EnvFromSource represents the source of a set of ConfigMaps properties: configMapRef: description: The ConfigMap to select from properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the ConfigMap must be defined type: boolean type: object prefix: description: An optional identifier to prepend to each key in the ConfigMap. Must be a C_IDENTIFIER. type: string secretRef: description: The Secret to select from properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the Secret must be defined type: boolean type: object type: object type: array image: description: 'Docker image name. More info: https://kubernetes.io/docs/concepts/containers/images This field is optional to allow higher level config management to default or override container images in workload controllers like Deployments and StatefulSets.' type: string imagePullPolicy: description: 'Image pull policy. One of Always, Never, IfNotPresent. Defaults to Always if :latest tag is specified, or IfNotPresent otherwise. Cannot be updated. More info: https://kubernetes.io/docs/concepts/containers/images#updating-images' type: string lifecycle: description: Actions that the management system should take in response to container lifecycle events. Cannot be updated. properties: postStart: description: 'PostStart is called immediately after a container is created. If the handler fails, the container is terminated and restarted according to its restart policy. Other management of the container blocks until the hook completes. 
More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' 
type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object type: object preStop: description: 'PreStop is called immediately before a container is terminated due to an API request or management event such as liveness/startup probe failure, preemption, resource contention, etc. The handler is not called if the container crashes or exits. The reason for termination is passed to the handler. The Pod''s termination grace period countdown begins before the PreStop hook is executed. Regardless of the outcome of the handler, the container will eventually terminate within the Pod''s termination grace period. Other management of the container blocks until the hook completes or until the termination grace period is reached. More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. 
items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object type: object type: object livenessProbe: description: 'Periodic probe of container liveness. Container will be restarted if the probe fails. Cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. 
items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Defaults to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. 
TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object name: description: Name of the container specified as a DNS_LABEL. Each container in a pod must have a unique name (DNS_LABEL). Cannot be updated. type: string ports: description: List of ports to expose from the container. Exposing a port here gives the system additional information about the network connections a container uses, but is primarily informational. Not specifying a port here DOES NOT prevent that port from being exposed. Any port which is listening on the default "0.0.0.0" address inside a container will be accessible from the network. Cannot be updated. items: description: ContainerPort represents a network port in a single container. properties: containerPort: description: Number of port to expose on the pod's IP address. This must be a valid port number, 0 < x < 65536. format: int32 type: integer hostIP: description: What host IP to bind the external port to. type: string hostPort: description: Number of port to expose on the host. If specified, this must be a valid port number, 0 < x < 65536. If HostNetwork is specified, this must match ContainerPort. Most containers do not need this. format: int32 type: integer name: description: If specified, this must be an IANA_SVC_NAME and unique within the pod. Each named port in a pod must have a unique name. 
Name for the port that can be referred to by services. type: string protocol: default: TCP description: Protocol for port. Must be UDP, TCP, or SCTP. Defaults to "TCP". type: string required: - containerPort type: object type: array x-kubernetes-list-map-keys: - containerPort - protocol x-kubernetes-list-type: map readinessProbe: description: 'Periodic probe of container service readiness. Container will be removed from service endpoints if the probe fails. Cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. 
type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Defaults to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object resources: description: 'Compute Resources required by this container. Cannot be updated. 
More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' properties: limits: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Limits describes the maximum amount of compute resources allowed. More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object requests: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Requests describes the minimum amount of compute resources required. If Requests is omitted for a container, it defaults to Limits if that is explicitly specified, otherwise to an implementation-defined value. More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object type: object securityContext: description: 'Security options the pod should run with. More info: https://kubernetes.io/docs/concepts/policy/security-context/ More info: https://kubernetes.io/docs/tasks/configure-pod-container/security-context/' properties: allowPrivilegeEscalation: description: 'AllowPrivilegeEscalation controls whether a process can gain more privileges than its parent process. This bool directly controls if the no_new_privs flag will be set on the container process. AllowPrivilegeEscalation is true always when the container is: 1) run as Privileged 2) has CAP_SYS_ADMIN' type: boolean capabilities: description: The capabilities to add/drop when running containers. Defaults to the default set of capabilities granted by the container runtime. 
properties: add: description: Added capabilities items: description: Capability represent POSIX capabilities type type: string type: array drop: description: Removed capabilities items: description: Capability represent POSIX capabilities type type: string type: array type: object privileged: description: Run container in privileged mode. Processes in privileged containers are essentially equivalent to root on the host. Defaults to false. type: boolean procMount: description: procMount denotes the type of proc mount to use for the containers. The default is DefaultProcMount which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. type: string readOnlyRootFilesystem: description: Whether this container has a read-only root filesystem. Default is false. type: boolean runAsGroup: description: The GID to run the entrypoint of the container process. Uses runtime default if unset. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. format: int64 type: integer runAsNonRoot: description: Indicates that the container must run as a non-root user. If true, the Kubelet will validate the image at runtime to ensure that it does not run as UID 0 (root) and fail to start the container if it does. If unset or false, no such validation will be performed. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: boolean runAsUser: description: The UID to run the entrypoint of the container process. Defaults to user specified in image metadata if unspecified. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. 
format: int64 type: integer seLinuxOptions: description: The SELinux context to be applied to the container. If unspecified, the container runtime will allocate a random SELinux context for each container. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. properties: level: description: Level is SELinux level label that applies to the container. type: string role: description: Role is a SELinux role label that applies to the container. type: string type: description: Type is a SELinux type label that applies to the container. type: string user: description: User is a SELinux user label that applies to the container. type: string type: object seccompProfile: description: The seccomp options to use by this container. If seccomp options are provided at both the pod & container level, the container options override the pod options. properties: localhostProfile: description: localhostProfile indicates a profile defined in a file on the node should be used. The profile must be preconfigured on the node to work. Must be a descending path, relative to the kubelet's configured seccomp profile location. Must only be set if type is "Localhost". type: string type: description: "type indicates which kind of seccomp profile will be applied. Valid options are: \n Localhost - a profile defined in a file on the node should be used. RuntimeDefault - the container runtime default profile should be used. Unconfined - no profile should be applied." type: string required: - type type: object windowsOptions: description: The Windows specific settings applied to all containers. If unspecified, the options from the PodSecurityContext will be used. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. 
properties: gmsaCredentialSpec: description: GMSACredentialSpec is where the GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa) inlines the contents of the GMSA credential spec named by the GMSACredentialSpecName field. type: string gmsaCredentialSpecName: description: GMSACredentialSpecName is the name of the GMSA credential spec to use. type: string runAsUserName: description: The UserName in Windows to run the entrypoint of the container process. Defaults to the user specified in image metadata if unspecified. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: string type: object type: object startupProbe: description: 'StartupProbe indicates that the Pod has successfully initialized. If specified, no other probes are executed until this completes successfully. If this probe fails, the Pod will be restarted, just as if the livenessProbe failed. This can be used to provide different probe parameters at the beginning of a Pod''s lifecycle, when it might take a long time to load data or warm a cache, than during steady-state operation. This cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. 
items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Default to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. 
TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object stdin: description: Whether this container should allocate a buffer for stdin in the container runtime. If this is not set, reads from stdin in the container will always result in EOF. Default is false. type: boolean stdinOnce: description: Whether the container runtime should close the stdin channel after it has been opened by a single attach. When stdin is true the stdin stream will remain open across multiple attach sessions. If stdinOnce is set to true, stdin is opened on container start, is empty until the first client attaches to stdin, and then remains open and accepts data until the client disconnects, at which time stdin is closed and remains closed until the container is restarted. If this flag is false, a container process that reads from stdin will never receive an EOF. Default is false. type: boolean terminationMessagePath: description: 'Optional: Path at which the file to which the container''s termination message will be written is mounted into the container''s filesystem. Message written is intended to be brief final status, such as an assertion failure message. Will be truncated by the node if greater than 4096 bytes. The total message length across all containers will be limited to 12kb. Defaults to /dev/termination-log. Cannot be updated.' 
type: string terminationMessagePolicy: description: Indicate how the termination message should be populated. File will use the contents of terminationMessagePath to populate the container status message on both success and failure. FallbackToLogsOnError will use the last chunk of container log output if the termination message file is empty and the container exited with an error. The log output is limited to 2048 bytes or 80 lines, whichever is smaller. Defaults to File. Cannot be updated. type: string tty: description: Whether this container should allocate a TTY for itself, also requires 'stdin' to be true. Default is false. type: boolean volumeDevices: description: volumeDevices is the list of block devices to be used by the container. items: description: volumeDevice describes a mapping of a raw block device within a container. properties: devicePath: description: devicePath is the path inside of the container that the device will be mapped to. type: string name: description: name must match the name of a persistentVolumeClaim in the pod type: string required: - devicePath - name type: object type: array volumeMounts: description: Pod volumes to mount into the container's filesystem. Cannot be updated. items: description: VolumeMount describes a mounting of a Volume within a container. properties: mountPath: description: Path within the container at which the volume should be mounted. Must not contain ':'. type: string mountPropagation: description: mountPropagation determines how mounts are propagated from the host to container and the other way around. When not set, MountPropagationNone is used. This field is beta in 1.10. type: string name: description: This must match the Name of a Volume. type: string readOnly: description: Mounted read-only if true, read-write otherwise (false or unspecified). Defaults to false. type: boolean subPath: description: Path within the volume from which the container's volume should be mounted. Defaults to "" (volume's root). 
type: string subPathExpr: description: Expanded path within the volume from which the container's volume should be mounted. Behaves similarly to SubPath but environment variable references $(VAR_NAME) are expanded using the container's environment. Defaults to "" (volume's root). SubPathExpr and SubPath are mutually exclusive. type: string required: - mountPath - name type: object type: array workingDir: description: Container's working directory. If not specified, the container runtime's default will be used, which might be configured in the container image. Cannot be updated. type: string required: - name type: object type: array dnsConfig: description: Specifies the DNS parameters of a pod. Parameters specified here will be merged to the generated DNS configuration based on DNSPolicy. properties: nameservers: description: A list of DNS name server IP addresses. This will be appended to the base nameservers generated from DNSPolicy. Duplicated nameservers will be removed. items: type: string type: array options: description: A list of DNS resolver options. This will be merged with the base options generated from DNSPolicy. Duplicated entries will be removed. Resolution options given in Options will override those that appear in the base DNSPolicy. items: description: PodDNSConfigOption defines DNS resolver options of a pod. properties: name: description: Required. type: string value: type: string type: object type: array searches: description: A list of DNS search domains for host-name lookup. This will be appended to the base search paths generated from DNSPolicy. Duplicated search paths will be removed. items: type: string type: array type: object dnsPolicy: description: Set DNS policy for the pod. Defaults to "ClusterFirst". Valid values are 'ClusterFirstWithHostNet', 'ClusterFirst', 'Default' or 'None'. DNS parameters given in DNSConfig will be merged with the policy selected with DNSPolicy. 
To have DNS options set along with hostNetwork, you have to specify DNS policy explicitly to 'ClusterFirstWithHostNet'. type: string enableServiceLinks: description: 'EnableServiceLinks indicates whether information about services should be injected into pod''s environment variables, matching the syntax of Docker links. Optional: Defaults to true.' type: boolean ephemeralContainers: description: List of ephemeral containers run in this pod. Ephemeral containers may be run in an existing pod to perform user-initiated actions such as debugging. This list cannot be specified when creating a pod, and it cannot be modified by updating the pod spec. In order to add an ephemeral container to an existing pod, use the pod's ephemeralcontainers subresource. This field is alpha-level and is only honored by servers that enable the EphemeralContainers feature. items: description: An EphemeralContainer is a container that may be added temporarily to an existing pod for user-initiated activities such as debugging. Ephemeral containers have no resource or scheduling guarantees, and they will not be restarted when they exit or when a pod is removed or restarted. If an ephemeral container causes a pod to exceed its resource allocation, the pod may be evicted. Ephemeral containers may not be added by directly updating the pod spec. They must be added via the pod's ephemeralcontainers subresource, and they will appear in the pod spec once added. This is an alpha feature enabled by the EphemeralContainers feature flag. properties: args: description: 'Arguments to the entrypoint. The docker image''s CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container''s environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Cannot be updated. 
More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell' items: type: string type: array command: description: 'Entrypoint array. Not executed within a shell. The docker image''s ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container''s environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Cannot be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell' items: type: string type: array env: description: List of environment variables to set in the container. Cannot be updated. items: description: EnvVar represents an environment variable present in a Container. properties: name: description: Name of the environment variable. Must be a C_IDENTIFIER. type: string value: description: 'Variable references $(VAR_NAME) are expanded using the previous defined environment variables in the container and any service environment variables. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Defaults to "".' type: string valueFrom: description: Source for the environment variable's value. Cannot be used if value is not empty. properties: configMapKeyRef: description: Selects a key of a ConfigMap. properties: key: description: The key to select. type: string name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' 
type: string optional: description: Specify whether the ConfigMap or its key must be defined type: boolean required: - key type: object fieldRef: description: 'Selects a field of the pod: supports metadata.name, metadata.namespace, `metadata.labels['''']`, `metadata.annotations['''']`, spec.nodeName, spec.serviceAccountName, status.hostIP, status.podIP, status.podIPs.' properties: apiVersion: description: Version of the schema the FieldPath is written in terms of, defaults to "v1". type: string fieldPath: description: Path of the field to select in the specified API version. type: string required: - fieldPath type: object resourceFieldRef: description: 'Selects a resource of the container: only resources limits and requests (limits.cpu, limits.memory, limits.ephemeral-storage, requests.cpu, requests.memory and requests.ephemeral-storage) are currently supported.' properties: containerName: description: 'Container name: required for volumes, optional for env vars' type: string divisor: anyOf: - type: integer - type: string description: Specifies the output format of the exposed resources, defaults to "1" pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true resource: description: 'Required: resource to select' type: string required: - resource type: object secretKeyRef: description: Selects a key of a secret in the pod's namespace properties: key: description: The key of the secret to select from. Must be a valid secret key. type: string name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' 
type: string optional: description: Specify whether the Secret or its key must be defined type: boolean required: - key type: object type: object required: - name type: object type: array envFrom: description: List of sources to populate environment variables in the container. The keys defined within a source must be a C_IDENTIFIER. All invalid keys will be reported as an event when the container is starting. When a key exists in multiple sources, the value associated with the last source will take precedence. Values defined by an Env with a duplicate key will take precedence. Cannot be updated. items: description: EnvFromSource represents the source of a set of ConfigMaps properties: configMapRef: description: The ConfigMap to select from properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the ConfigMap must be defined type: boolean type: object prefix: description: An optional identifier to prepend to each key in the ConfigMap. Must be a C_IDENTIFIER. type: string secretRef: description: The Secret to select from properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the Secret must be defined type: boolean type: object type: object type: array image: description: 'Docker image name. More info: https://kubernetes.io/docs/concepts/containers/images' type: string imagePullPolicy: description: 'Image pull policy. One of Always, Never, IfNotPresent. Defaults to Always if :latest tag is specified, or IfNotPresent otherwise. Cannot be updated. 
More info: https://kubernetes.io/docs/concepts/containers/images#updating-images' type: string lifecycle: description: Lifecycle is not allowed for ephemeral containers. properties: postStart: description: 'PostStart is called immediately after a container is created. If the handler fails, the container is terminated and restarted according to its restart policy. Other management of the container blocks until the hook completes. More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. 
x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object type: object preStop: description: 'PreStop is called immediately before a container is terminated due to an API request or management event such as liveness/startup probe failure, preemption, resource contention, etc. The handler is not called if the container crashes or exits. The reason for termination is passed to the handler. The Pod''s termination grace period countdown begins before the PreStop hook is executed. Regardless of the outcome of the handler, the container will eventually terminate within the Pod''s termination grace period. Other management of the container blocks until the hook completes or until the termination grace period is reached. More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. 
items: type: string type: array type: object httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object type: object type: object livenessProbe: description: Probes are not allowed for ephemeral containers. properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. 
The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Default to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. 
Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object name: description: Name of the ephemeral container specified as a DNS_LABEL. This name must be unique among all containers, init containers and ephemeral containers. type: string ports: description: Ports are not allowed for ephemeral containers. items: description: ContainerPort represents a network port in a single container. properties: containerPort: description: Number of port to expose on the pod's IP address. This must be a valid port number, 0 < x < 65536. format: int32 type: integer hostIP: description: What host IP to bind the external port to. type: string hostPort: description: Number of port to expose on the host. If specified, this must be a valid port number, 0 < x < 65536. If HostNetwork is specified, this must match ContainerPort. Most containers do not need this. format: int32 type: integer name: description: If specified, this must be an IANA_SVC_NAME and unique within the pod. Each named port in a pod must have a unique name. Name for the port that can be referred to by services. type: string protocol: default: TCP description: Protocol for port. Must be UDP, TCP, or SCTP. Defaults to "TCP". 
type: string required: - containerPort type: object type: array readinessProbe: description: Probes are not allowed for ephemeral containers. properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. 
type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Default to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object resources: description: Resources are not allowed for ephemeral containers. Ephemeral containers use spare resources already allocated to the pod. properties: limits: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Limits describes the maximum amount of compute resources allowed. 
More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object requests: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Requests describes the minimum amount of compute resources required. If Requests is omitted for a container, it defaults to Limits if that is explicitly specified, otherwise to an implementation-defined value. More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object type: object securityContext: description: SecurityContext is not allowed for ephemeral containers. properties: allowPrivilegeEscalation: description: 'AllowPrivilegeEscalation controls whether a process can gain more privileges than its parent process. This bool directly controls if the no_new_privs flag will be set on the container process. AllowPrivilegeEscalation is true always when the container is: 1) run as Privileged 2) has CAP_SYS_ADMIN' type: boolean capabilities: description: The capabilities to add/drop when running containers. Defaults to the default set of capabilities granted by the container runtime. properties: add: description: Added capabilities items: description: Capability represent POSIX capabilities type type: string type: array drop: description: Removed capabilities items: description: Capability represent POSIX capabilities type type: string type: array type: object privileged: description: Run container in privileged mode. Processes in privileged containers are essentially equivalent to root on the host. Defaults to false. type: boolean procMount: description: procMount denotes the type of proc mount to use for the containers. The default is DefaultProcMount which uses the container runtime defaults for readonly paths and masked paths. 
This requires the ProcMountType feature flag to be enabled. type: string readOnlyRootFilesystem: description: Whether this container has a read-only root filesystem. Default is false. type: boolean runAsGroup: description: The GID to run the entrypoint of the container process. Uses runtime default if unset. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. format: int64 type: integer runAsNonRoot: description: Indicates that the container must run as a non-root user. If true, the Kubelet will validate the image at runtime to ensure that it does not run as UID 0 (root) and fail to start the container if it does. If unset or false, no such validation will be performed. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: boolean runAsUser: description: The UID to run the entrypoint of the container process. Defaults to user specified in image metadata if unspecified. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. format: int64 type: integer seLinuxOptions: description: The SELinux context to be applied to the container. If unspecified, the container runtime will allocate a random SELinux context for each container. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. properties: level: description: Level is SELinux level label that applies to the container. type: string role: description: Role is a SELinux role label that applies to the container. type: string type: description: Type is a SELinux type label that applies to the container. type: string user: description: User is a SELinux user label that applies to the container. 
type: string type: object seccompProfile: description: The seccomp options to use by this container. If seccomp options are provided at both the pod & container level, the container options override the pod options. properties: localhostProfile: description: localhostProfile indicates a profile defined in a file on the node should be used. The profile must be preconfigured on the node to work. Must be a descending path, relative to the kubelet's configured seccomp profile location. Must only be set if type is "Localhost". type: string type: description: "type indicates which kind of seccomp profile will be applied. Valid options are: \n Localhost - a profile defined in a file on the node should be used. RuntimeDefault - the container runtime default profile should be used. Unconfined - no profile should be applied." type: string required: - type type: object windowsOptions: description: The Windows specific settings applied to all containers. If unspecified, the options from the PodSecurityContext will be used. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. properties: gmsaCredentialSpec: description: GMSACredentialSpec is where the GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa) inlines the contents of the GMSA credential spec named by the GMSACredentialSpecName field. type: string gmsaCredentialSpecName: description: GMSACredentialSpecName is the name of the GMSA credential spec to use. type: string runAsUserName: description: The UserName in Windows to run the entrypoint of the container process. Defaults to the user specified in image metadata if unspecified. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: string type: object type: object startupProbe: description: Probes are not allowed for ephemeral containers. 
properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. 
More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Default to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object stdin: description: Whether this container should allocate a buffer for stdin in the container runtime. If this is not set, reads from stdin in the container will always result in EOF. Default is false. type: boolean stdinOnce: description: Whether the container runtime should close the stdin channel after it has been opened by a single attach. When stdin is true the stdin stream will remain open across multiple attach sessions. If stdinOnce is set to true, stdin is opened on container start, is empty until the first client attaches to stdin, and then remains open and accepts data until the client disconnects, at which time stdin is closed and remains closed until the container is restarted. 
If this flag is false, a container process that reads from stdin will never receive an EOF. Default is false. type: boolean targetContainerName: description: If set, the name of the container from PodSpec that this ephemeral container targets. The ephemeral container will be run in the namespaces (IPC, PID, etc) of this container. If not set then the ephemeral container is run in whatever namespaces are shared for the pod. Note that the container runtime must support this feature. type: string terminationMessagePath: description: 'Optional: Path at which the file to which the container''s termination message will be written is mounted into the container''s filesystem. Message written is intended to be brief final status, such as an assertion failure message. Will be truncated by the node if greater than 4096 bytes. The total message length across all containers will be limited to 12kb. Defaults to /dev/termination-log. Cannot be updated.' type: string terminationMessagePolicy: description: Indicate how the termination message should be populated. File will use the contents of terminationMessagePath to populate the container status message on both success and failure. FallbackToLogsOnError will use the last chunk of container log output if the termination message file is empty and the container exited with an error. The log output is limited to 2048 bytes or 80 lines, whichever is smaller. Defaults to File. Cannot be updated. type: string tty: description: Whether this container should allocate a TTY for itself, also requires 'stdin' to be true. Default is false. type: boolean volumeDevices: description: volumeDevices is the list of block devices to be used by the container. items: description: volumeDevice describes a mapping of a raw block device within a container. properties: devicePath: description: devicePath is the path inside of the container that the device will be mapped to.
type: string name: description: name must match the name of a persistentVolumeClaim in the pod type: string required: - devicePath - name type: object type: array volumeMounts: description: Pod volumes to mount into the container's filesystem. Cannot be updated. items: description: VolumeMount describes a mounting of a Volume within a container. properties: mountPath: description: Path within the container at which the volume should be mounted. Must not contain ':'. type: string mountPropagation: description: mountPropagation determines how mounts are propagated from the host to container and the other way around. When not set, MountPropagationNone is used. This field is beta in 1.10. type: string name: description: This must match the Name of a Volume. type: string readOnly: description: Mounted read-only if true, read-write otherwise (false or unspecified). Defaults to false. type: boolean subPath: description: Path within the volume from which the container's volume should be mounted. Defaults to "" (volume's root). type: string subPathExpr: description: Expanded path within the volume from which the container's volume should be mounted. Behaves similarly to SubPath but environment variable references $(VAR_NAME) are expanded using the container's environment. Defaults to "" (volume's root). SubPathExpr and SubPath are mutually exclusive. type: string required: - mountPath - name type: object type: array workingDir: description: Container's working directory. If not specified, the container runtime's default will be used, which might be configured in the container image. Cannot be updated. type: string required: - name type: object type: array hostAliases: description: HostAliases is an optional list of hosts and IPs that will be injected into the pod's hosts file if specified. This is only valid for non-hostNetwork pods. items: description: HostAlias holds the mapping between IP and hostnames that will be injected as an entry in the pod's hosts file. 
properties: hostnames: description: Hostnames for the above IP address. items: type: string type: array ip: description: IP address of the host file entry. type: string type: object type: array hostIPC: description: 'Use the host''s ipc namespace. Optional: Default to false.' type: boolean hostNetwork: description: Host networking requested for this pod. Use the host's network namespace. If this option is set, the ports that will be used must be specified. Default to false. type: boolean hostPID: description: 'Use the host''s pid namespace. Optional: Default to false.' type: boolean hostname: description: Specifies the hostname of the Pod If not specified, the pod's hostname will be set to a system-defined value. type: string imagePullSecrets: description: 'ImagePullSecrets is an optional list of references to secrets in the same namespace to use for pulling any of the images used by this PodSpec. If specified, these secrets will be passed to individual puller implementations for them to use. For example, in the case of docker, only DockerConfig type secrets are honored. More info: https://kubernetes.io/docs/concepts/containers/images#specifying-imagepullsecrets-on-a-pod' items: description: LocalObjectReference contains enough information to let you locate the referenced object inside the same namespace. properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object type: array initContainers: description: 'List of initialization containers belonging to the pod. Init containers are executed in order prior to containers being started. If any init container fails, the pod is considered to have failed and is handled according to its restartPolicy. The name for an init container or normal container must be unique among all containers. 
Init containers may not have Lifecycle actions, Readiness probes, Liveness probes, or Startup probes. The resourceRequirements of an init container are taken into account during scheduling by finding the highest request/limit for each resource type, and then using the max of that value or the sum of the normal containers. Limits are applied to init containers in a similar fashion. Init containers cannot currently be added or removed. Cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/init-containers/' items: description: A single application container that you want to run within a pod. properties: args: description: 'Arguments to the entrypoint. The docker image''s CMD is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container''s environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Cannot be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell' items: type: string type: array command: description: 'Entrypoint array. Not executed within a shell. The docker image''s ENTRYPOINT is used if this is not provided. Variable references $(VAR_NAME) are expanded using the container''s environment. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Cannot be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell' items: type: string type: array env: description: List of environment variables to set in the container.
Cannot be updated. items: description: EnvVar represents an environment variable present in a Container. properties: name: description: Name of the environment variable. Must be a C_IDENTIFIER. type: string value: description: 'Variable references $(VAR_NAME) are expanded using the previous defined environment variables in the container and any service environment variables. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not. Defaults to "".' type: string valueFrom: description: Source for the environment variable's value. Cannot be used if value is not empty. properties: configMapKeyRef: description: Selects a key of a ConfigMap. properties: key: description: The key to select. type: string name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the ConfigMap or its key must be defined type: boolean required: - key type: object fieldRef: description: 'Selects a field of the pod: supports metadata.name, metadata.namespace, `metadata.labels['''']`, `metadata.annotations['''']`, spec.nodeName, spec.serviceAccountName, status.hostIP, status.podIP, status.podIPs.' properties: apiVersion: description: Version of the schema the FieldPath is written in terms of, defaults to "v1". type: string fieldPath: description: Path of the field to select in the specified API version. type: string required: - fieldPath type: object resourceFieldRef: description: 'Selects a resource of the container: only resources limits and requests (limits.cpu, limits.memory, limits.ephemeral-storage, requests.cpu, requests.memory and requests.ephemeral-storage) are currently supported.' 
properties: containerName: description: 'Container name: required for volumes, optional for env vars' type: string divisor: anyOf: - type: integer - type: string description: Specifies the output format of the exposed resources, defaults to "1" pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true resource: description: 'Required: resource to select' type: string required: - resource type: object secretKeyRef: description: Selects a key of a secret in the pod's namespace properties: key: description: The key of the secret to select from. Must be a valid secret key. type: string name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the Secret or its key must be defined type: boolean required: - key type: object type: object required: - name type: object type: array envFrom: description: List of sources to populate environment variables in the container. The keys defined within a source must be a C_IDENTIFIER. All invalid keys will be reported as an event when the container is starting. When a key exists in multiple sources, the value associated with the last source will take precedence. Values defined by an Env with a duplicate key will take precedence. Cannot be updated. items: description: EnvFromSource represents the source of a set of ConfigMaps properties: configMapRef: description: The ConfigMap to select from properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' 
type: string optional: description: Specify whether the ConfigMap must be defined type: boolean type: object prefix: description: An optional identifier to prepend to each key in the ConfigMap. Must be a C_IDENTIFIER. type: string secretRef: description: The Secret to select from properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the Secret must be defined type: boolean type: object type: object type: array image: description: 'Docker image name. More info: https://kubernetes.io/docs/concepts/containers/images This field is optional to allow higher level config management to default or override container images in workload controllers like Deployments and StatefulSets.' type: string imagePullPolicy: description: 'Image pull policy. One of Always, Never, IfNotPresent. Defaults to Always if :latest tag is specified, or IfNotPresent otherwise. Cannot be updated. More info: https://kubernetes.io/docs/concepts/containers/images#updating-images' type: string lifecycle: description: Actions that the management system should take in response to container lifecycle events. Cannot be updated. properties: postStart: description: 'PostStart is called immediately after a container is created. If the handler fails, the container is terminated and restarted according to its restart policy. Other management of the container blocks until the hook completes. More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. 
The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object type: object preStop: description: 'PreStop is called immediately before a container is terminated due to an API request or management event such as liveness/startup probe failure, preemption, resource contention, etc. 
The handler is not called if the container crashes or exits. The reason for termination is passed to the handler. The Pod''s termination grace period countdown begins before the PreStop hook is executed. Regardless of the outcome of the handler, the container will eventually terminate within the Pod''s termination grace period. Other management of the container blocks until the hook completes or until the termination grace period is reached. More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME.
x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object type: object type: object livenessProbe: description: 'Periodic probe of container liveness. Container will be restarted if the probe fails. Cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. 
HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Default to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. 
More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object name: description: Name of the container specified as a DNS_LABEL. Each container in a pod must have a unique name (DNS_LABEL). Cannot be updated. type: string ports: description: List of ports to expose from the container. Exposing a port here gives the system additional information about the network connections a container uses, but is primarily informational. Not specifying a port here DOES NOT prevent that port from being exposed. Any port which is listening on the default "0.0.0.0" address inside a container will be accessible from the network. Cannot be updated. items: description: ContainerPort represents a network port in a single container. properties: containerPort: description: Number of port to expose on the pod's IP address. This must be a valid port number, 0 < x < 65536. format: int32 type: integer hostIP: description: What host IP to bind the external port to. type: string hostPort: description: Number of port to expose on the host. If specified, this must be a valid port number, 0 < x < 65536. If HostNetwork is specified, this must match ContainerPort. Most containers do not need this. format: int32 type: integer name: description: If specified, this must be an IANA_SVC_NAME and unique within the pod. Each named port in a pod must have a unique name. Name for the port that can be referred to by services. type: string protocol: default: TCP description: Protocol for port. Must be UDP, TCP, or SCTP. Defaults to "TCP". type: string required: - containerPort type: object type: array x-kubernetes-list-map-keys: - containerPort - protocol x-kubernetes-list-type: map readinessProbe: description: 'Periodic probe of container service readiness. Container will be removed from service endpoints if the probe fails. Cannot be updated. 
More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. 
More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Default to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object resources: description: 'Compute Resources required by this container. Cannot be updated. More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' properties: limits: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Limits describes the maximum amount of compute resources allowed. 
More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object requests: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Requests describes the minimum amount of compute resources required. If Requests is omitted for a container, it defaults to Limits if that is explicitly specified, otherwise to an implementation-defined value. More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object type: object securityContext: description: 'Security options the pod should run with. More info: https://kubernetes.io/docs/concepts/policy/security-context/ More info: https://kubernetes.io/docs/tasks/configure-pod-container/security-context/' properties: allowPrivilegeEscalation: description: 'AllowPrivilegeEscalation controls whether a process can gain more privileges than its parent process. This bool directly controls if the no_new_privs flag will be set on the container process. AllowPrivilegeEscalation is true always when the container is: 1) run as Privileged 2) has CAP_SYS_ADMIN' type: boolean capabilities: description: The capabilities to add/drop when running containers. Defaults to the default set of capabilities granted by the container runtime. properties: add: description: Added capabilities items: description: Capability represent POSIX capabilities type type: string type: array drop: description: Removed capabilities items: description: Capability represent POSIX capabilities type type: string type: array type: object privileged: description: Run container in privileged mode. Processes in privileged containers are essentially equivalent to root on the host. Defaults to false. type: boolean procMount: description: procMount denotes the type of proc mount to use for the containers. 
The default is DefaultProcMount which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. type: string readOnlyRootFilesystem: description: Whether this container has a read-only root filesystem. Default is false. type: boolean runAsGroup: description: The GID to run the entrypoint of the container process. Uses runtime default if unset. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. format: int64 type: integer runAsNonRoot: description: Indicates that the container must run as a non-root user. If true, the Kubelet will validate the image at runtime to ensure that it does not run as UID 0 (root) and fail to start the container if it does. If unset or false, no such validation will be performed. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: boolean runAsUser: description: The UID to run the entrypoint of the container process. Defaults to user specified in image metadata if unspecified. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. format: int64 type: integer seLinuxOptions: description: The SELinux context to be applied to the container. If unspecified, the container runtime will allocate a random SELinux context for each container. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. properties: level: description: Level is SELinux level label that applies to the container. type: string role: description: Role is a SELinux role label that applies to the container. type: string type: description: Type is a SELinux type label that applies to the container. 
type: string user: description: User is a SELinux user label that applies to the container. type: string type: object seccompProfile: description: The seccomp options to use by this container. If seccomp options are provided at both the pod & container level, the container options override the pod options. properties: localhostProfile: description: localhostProfile indicates a profile defined in a file on the node should be used. The profile must be preconfigured on the node to work. Must be a descending path, relative to the kubelet's configured seccomp profile location. Must only be set if type is "Localhost". type: string type: description: "type indicates which kind of seccomp profile will be applied. Valid options are: \n Localhost - a profile defined in a file on the node should be used. RuntimeDefault - the container runtime default profile should be used. Unconfined - no profile should be applied." type: string required: - type type: object windowsOptions: description: The Windows specific settings applied to all containers. If unspecified, the options from the PodSecurityContext will be used. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. properties: gmsaCredentialSpec: description: GMSACredentialSpec is where the GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa) inlines the contents of the GMSA credential spec named by the GMSACredentialSpecName field. type: string gmsaCredentialSpecName: description: GMSACredentialSpecName is the name of the GMSA credential spec to use. type: string runAsUserName: description: The UserName in Windows to run the entrypoint of the container process. Defaults to the user specified in image metadata if unspecified. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. 
type: string type: object type: object startupProbe: description: 'StartupProbe indicates that the Pod has successfully initialized. If specified, no other probes are executed until this completes successfully. If this probe fails, the Pod will be restarted, just as if the livenessProbe failed. This can be used to provide different probe parameters at the beginning of a Pod''s lifecycle, when it might take a long time to load data or warm a cache, than during steady-state operation. This cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' properties: exec: description: One and only one of the following should be specified. Exec specifies the action to take. properties: command: description: Command is the command line to execute inside the container, the working directory for the command is root ('/') in the container's filesystem. The command is simply exec'd, it is not run inside a shell, so traditional shell instructions ('|', etc) won't work. To use a shell, you need to explicitly call out to that shell. Exit status of 0 is treated as live/healthy and non-zero is unhealthy. items: type: string type: array type: object failureThreshold: description: Minimum consecutive failures for the probe to be considered failed after having succeeded. Defaults to 3. Minimum value is 1. format: int32 type: integer httpGet: description: HTTPGet specifies the http request to perform. properties: host: description: Host name to connect to, defaults to the pod IP. You probably want to set "Host" in httpHeaders instead. type: string httpHeaders: description: Custom headers to set in the request. HTTP allows repeated headers. 
items: description: HTTPHeader describes a custom header to be used in HTTP probes properties: name: description: The header field name type: string value: description: The header field value type: string required: - name - value type: object type: array path: description: Path to access on the HTTP server. type: string port: anyOf: - type: integer - type: string description: Name or number of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true scheme: description: Scheme to use for connecting to the host. Defaults to HTTP. type: string required: - port type: object initialDelaySeconds: description: 'Number of seconds after the container has started before liveness probes are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer periodSeconds: description: How often (in seconds) to perform the probe. Default to 10 seconds. Minimum value is 1. format: int32 type: integer successThreshold: description: Minimum consecutive successes for the probe to be considered successful after having failed. Defaults to 1. Must be 1 for liveness and startup. Minimum value is 1. format: int32 type: integer tcpSocket: description: 'TCPSocket specifies an action involving a TCP port. TCP hooks not yet supported TODO: implement a realistic TCP lifecycle hook' properties: host: description: 'Optional: Host name to connect to, defaults to the pod IP.' type: string port: anyOf: - type: integer - type: string description: Number or name of the port to access on the container. Number must be in the range 1 to 65535. Name must be an IANA_SVC_NAME. x-kubernetes-int-or-string: true required: - port type: object timeoutSeconds: description: 'Number of seconds after which the probe times out. Defaults to 1 second. Minimum value is 1. 
More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes' format: int32 type: integer type: object stdin: description: Whether this container should allocate a buffer for stdin in the container runtime. If this is not set, reads from stdin in the container will always result in EOF. Default is false. type: boolean stdinOnce: description: Whether the container runtime should close the stdin channel after it has been opened by a single attach. When stdin is true the stdin stream will remain open across multiple attach sessions. If stdinOnce is set to true, stdin is opened on container start, is empty until the first client attaches to stdin, and then remains open and accepts data until the client disconnects, at which time stdin is closed and remains closed until the container is restarted. If this flag is false, a container process that reads from stdin will never receive an EOF. Default is false. type: boolean terminationMessagePath: description: 'Optional: Path at which the file to which the container''s termination message will be written is mounted into the container''s filesystem. Message written is intended to be brief final status, such as an assertion failure message. Will be truncated by the node if greater than 4096 bytes. The total message length across all containers will be limited to 12kb. Defaults to /dev/termination-log. Cannot be updated.' type: string terminationMessagePolicy: description: Indicate how the termination message should be populated. File will use the contents of terminationMessagePath to populate the container status message on both success and failure. FallbackToLogsOnError will use the last chunk of container log output if the termination message file is empty and the container exited with an error. The log output is limited to 2048 bytes or 80 lines, whichever is smaller. Defaults to File. Cannot be updated. 
type: string tty: description: Whether this container should allocate a TTY for itself, also requires 'stdin' to be true. Default is false. type: boolean volumeDevices: description: volumeDevices is the list of block devices to be used by the container. items: description: volumeDevice describes a mapping of a raw block device within a container. properties: devicePath: description: devicePath is the path inside of the container that the device will be mapped to. type: string name: description: name must match the name of a persistentVolumeClaim in the pod type: string required: - devicePath - name type: object type: array volumeMounts: description: Pod volumes to mount into the container's filesystem. Cannot be updated. items: description: VolumeMount describes a mounting of a Volume within a container. properties: mountPath: description: Path within the container at which the volume should be mounted. Must not contain ':'. type: string mountPropagation: description: mountPropagation determines how mounts are propagated from the host to container and the other way around. When not set, MountPropagationNone is used. This field is beta in 1.10. type: string name: description: This must match the Name of a Volume. type: string readOnly: description: Mounted read-only if true, read-write otherwise (false or unspecified). Defaults to false. type: boolean subPath: description: Path within the volume from which the container's volume should be mounted. Defaults to "" (volume's root). type: string subPathExpr: description: Expanded path within the volume from which the container's volume should be mounted. Behaves similarly to SubPath but environment variable references $(VAR_NAME) are expanded using the container's environment. Defaults to "" (volume's root). SubPathExpr and SubPath are mutually exclusive. type: string required: - mountPath - name type: object type: array workingDir: description: Container's working directory. 
If not specified, the container runtime's default will be used, which might be configured in the container image. Cannot be updated. type: string required: - name type: object type: array nodeName: description: NodeName is a request to schedule this pod onto a specific node. If it is non-empty, the scheduler simply schedules this pod onto that node, assuming that it fits resource requirements. type: string nodeSelector: additionalProperties: type: string description: 'NodeSelector is a selector which must be true for the pod to fit on a node. Selector which must match a node''s labels for the pod to be scheduled on that node. More info: https://kubernetes.io/docs/concepts/configuration/assign-pod-node/' type: object overhead: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Overhead represents the resource overhead associated with running a pod for a given RuntimeClass. This field will be autopopulated at admission time by the RuntimeClass admission controller. If the RuntimeClass admission controller is enabled, overhead must not be set in Pod create requests. The RuntimeClass admission controller will reject Pod create requests which have the overhead already set. If RuntimeClass is configured and selected in the PodSpec, Overhead will be set to the value defined in the corresponding RuntimeClass, otherwise it will remain unset and treated as zero. More info: https://git.k8s.io/enhancements/keps/sig-node/20190226-pod-overhead.md This field is alpha-level as of Kubernetes v1.16, and is only honored by servers that enable the PodOverhead feature.' type: object preemptionPolicy: description: PreemptionPolicy is the Policy for preempting pods with lower priority. One of Never, PreemptLowerPriority. Defaults to PreemptLowerPriority if unset. 
This field is beta-level, gated by the NonPreemptingPriority feature-gate. type: string priority: description: The priority value. Various system components use this field to find the priority of the pod. When Priority Admission Controller is enabled, it prevents users from setting this field. The admission controller populates this field from PriorityClassName. The higher the value, the higher the priority. format: int32 type: integer priorityClassName: description: If specified, indicates the pod's priority. "system-node-critical" and "system-cluster-critical" are two special keywords which indicate the highest priorities with the former being the highest priority. Any other name must be defined by creating a PriorityClass object with that name. If not specified, the pod priority will be default or zero if there is no default. type: string readinessGates: description: 'If specified, all readiness gates will be evaluated for pod readiness. A pod is ready when all its containers are ready AND all conditions specified in the readiness gates have status equal to "True" More info: https://git.k8s.io/enhancements/keps/sig-network/0007-pod-ready%2B%2B.md' items: description: PodReadinessGate contains the reference to a pod condition properties: conditionType: description: ConditionType refers to a condition in the pod's condition list with matching type. type: string required: - conditionType type: object type: array restartPolicy: description: 'Restart policy for all containers within the pod. One of Always, OnFailure, Never. Default to Always. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#restart-policy' type: string runtimeClassName: description: 'RuntimeClassName refers to a RuntimeClass object in the node.k8s.io group, which should be used to run this pod. If no RuntimeClass resource matches the named class, the pod will not be run. 
If unset or empty, the "legacy" RuntimeClass will be used, which is an implicit class with an empty definition that uses the default runtime handler. More info: https://git.k8s.io/enhancements/keps/sig-node/runtime-class.md This is a beta feature as of Kubernetes v1.14.' type: string schedulerName: description: If specified, the pod will be dispatched by specified scheduler. If not specified, the pod will be dispatched by default scheduler. type: string securityContext: description: 'SecurityContext holds pod-level security attributes and common container settings. Optional: Defaults to empty. See type description for default values of each field.' properties: fsGroup: description: "A special supplemental group that applies to all containers in a pod. Some volume types allow the Kubelet to change the ownership of that volume to be owned by the pod: \n 1. The owning GID will be the FSGroup 2. The setgid bit is set (new files created in the volume will be owned by FSGroup) 3. The permission bits are OR'd with rw-rw---- \n If unset, the Kubelet will not modify the ownership and permissions of any volume." format: int64 type: integer fsGroupChangePolicy: description: 'fsGroupChangePolicy defines behavior of changing ownership and permission of the volume before being exposed inside Pod. This field will only apply to volume types which support fsGroup based ownership(and permissions). It will have no effect on ephemeral volume types such as: secret, configmaps and emptydir. Valid values are "OnRootMismatch" and "Always". If not specified, "Always" is used.' type: string runAsGroup: description: The GID to run the entrypoint of the container process. Uses runtime default if unset. May also be set in SecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence for that container. format: int64 type: integer runAsNonRoot: description: Indicates that the container must run as a non-root user. 
If true, the Kubelet will validate the image at runtime to ensure that it does not run as UID 0 (root) and fail to start the container if it does. If unset or false, no such validation will be performed. May also be set in SecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: boolean runAsUser: description: The UID to run the entrypoint of the container process. Defaults to user specified in image metadata if unspecified. May also be set in SecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence for that container. format: int64 type: integer seLinuxOptions: description: The SELinux context to be applied to all containers. If unspecified, the container runtime will allocate a random SELinux context for each container. May also be set in SecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence for that container. properties: level: description: Level is SELinux level label that applies to the container. type: string role: description: Role is a SELinux role label that applies to the container. type: string type: description: Type is a SELinux type label that applies to the container. type: string user: description: User is a SELinux user label that applies to the container. type: string type: object seccompProfile: description: The seccomp options to use by the containers in this pod. properties: localhostProfile: description: localhostProfile indicates a profile defined in a file on the node should be used. The profile must be preconfigured on the node to work. Must be a descending path, relative to the kubelet's configured seccomp profile location. Must only be set if type is "Localhost". type: string type: description: "type indicates which kind of seccomp profile will be applied. 
Valid options are: \n Localhost - a profile defined in a file on the node should be used. RuntimeDefault - the container runtime default profile should be used. Unconfined - no profile should be applied." type: string required: - type type: object supplementalGroups: description: A list of groups applied to the first process run in each container, in addition to the container's primary GID. If unspecified, no groups will be added to any container. items: format: int64 type: integer type: array sysctls: description: Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported sysctls (by the container runtime) might fail to launch. items: description: Sysctl defines a kernel parameter to be set properties: name: description: Name of a property to set type: string value: description: Value of a property to set type: string required: - name - value type: object type: array windowsOptions: description: The Windows specific settings applied to all containers. If unspecified, the options within a container's SecurityContext will be used. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. properties: gmsaCredentialSpec: description: GMSACredentialSpec is where the GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa) inlines the contents of the GMSA credential spec named by the GMSACredentialSpecName field. type: string gmsaCredentialSpecName: description: GMSACredentialSpecName is the name of the GMSA credential spec to use. type: string runAsUserName: description: The UserName in Windows to run the entrypoint of the container process. Defaults to the user specified in image metadata if unspecified. May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. 
type: string type: object type: object serviceAccount: description: 'DeprecatedServiceAccount is a deprecated alias for ServiceAccountName. Deprecated: Use serviceAccountName instead.' type: string serviceAccountName: description: 'ServiceAccountName is the name of the ServiceAccount to use to run this pod. More info: https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/' type: string setHostnameAsFQDN: description: If true the pod's hostname will be configured as the pod's FQDN, rather than the leaf name (the default). In Linux containers, this means setting the FQDN in the hostname field of the kernel (the nodename field of struct utsname). In Windows containers, this means setting the registry value of hostname for the registry key HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters to FQDN. If a pod does not have FQDN, this has no effect. Default to false. type: boolean shareProcessNamespace: description: 'Share a single process namespace between all of the containers in a pod. When this is set containers will be able to view and signal processes from other containers in the same pod, and the first process in each container will not be assigned PID 1. HostPID and ShareProcessNamespace cannot both be set. Optional: Default to false.' type: boolean subdomain: description: If specified, the fully qualified Pod hostname will be "<hostname>.<subdomain>.<pod namespace>.svc.<cluster domain>". If not specified, the pod will not have a domainname at all. type: string terminationGracePeriodSeconds: description: Optional duration in seconds the pod needs to terminate gracefully. May be decreased in delete request. Value must be non-negative integer. The value zero indicates delete immediately. If this value is nil, the default grace period will be used instead. The grace period is the duration in seconds after the processes running in the pod are sent a termination signal and the time when the processes are forcibly halted with a kill signal. 
Set this value longer than the expected cleanup time for your process. Defaults to 30 seconds. format: int64 type: integer tolerations: description: If specified, the pod's tolerations. items: description: The pod this Toleration is attached to tolerates any taint that matches the triple <key,value,effect> using the matching operator <operator>. properties: effect: description: Effect indicates the taint effect to match. Empty means match all taint effects. When specified, allowed values are NoSchedule, PreferNoSchedule and NoExecute. type: string key: description: Key is the taint key that the toleration applies to. Empty means match all taint keys. If the key is empty, operator must be Exists; this combination means to match all values and all keys. type: string operator: description: Operator represents a key's relationship to the value. Valid operators are Exists and Equal. Defaults to Equal. Exists is equivalent to wildcard for value, so that a pod can tolerate all taints of a particular category. type: string tolerationSeconds: description: TolerationSeconds represents the period of time the toleration (which must be of effect NoExecute, otherwise this field is ignored) tolerates the taint. By default, it is not set, which means tolerate the taint forever (do not evict). Zero and negative values will be treated as 0 (evict immediately) by the system. format: int64 type: integer value: description: Value is the taint value the toleration matches to. If the operator is Exists, the value should be empty, otherwise just a regular string. type: string type: object type: array topologySpreadConstraints: description: TopologySpreadConstraints describes how a group of pods ought to spread across topology domains. Scheduler will schedule pods in a way which abides by the constraints. All topologySpreadConstraints are ANDed. items: description: TopologySpreadConstraint specifies how to spread matching pods among the given topology. 
properties: labelSelector: description: LabelSelector is used to find matching pods. Pods that match this label selector are counted to determine the number of pods in their corresponding topology domain. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. type: object type: object maxSkew: description: 'MaxSkew describes the degree to which pods may be unevenly distributed. When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference between the number of matching pods in the target topology and the global minimum. For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same labelSelector spread as 1/1/0: | zone1 | zone2 | zone3 | | P | P | | - if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 1/1/1; scheduling it onto zone1(zone2) would make the ActualSkew(2-0) on zone1(zone2) violate MaxSkew(1). 
- if MaxSkew is 2, incoming pod can be scheduled onto any zone. When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence to topologies that satisfy it. It''s a required field. Default value is 1 and 0 is not allowed.' format: int32 type: integer topologyKey: description: TopologyKey is the key of node labels. Nodes that have a label with this key and identical values are considered to be in the same topology. We consider each <key, value> as a "bucket", and try to put balanced number of pods into each bucket. It's a required field. type: string whenUnsatisfiable: description: 'WhenUnsatisfiable indicates how to deal with a pod if it doesn''t satisfy the spread constraint. - DoNotSchedule (default) tells the scheduler not to schedule it. - ScheduleAnyway tells the scheduler to schedule the pod in any location, but giving higher precedence to topologies that would help reduce the skew. A constraint is considered "Unsatisfiable" for an incoming pod if and only if every possible node assignment for that pod would violate "MaxSkew" on some topology. For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same labelSelector spread as 3/1/1: | zone1 | zone2 | zone3 | | P P P | P | P | If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler won''t make it *more* imbalanced. It''s a required field.' type: string required: - maxSkew - topologyKey - whenUnsatisfiable type: object type: array x-kubernetes-list-map-keys: - topologyKey - whenUnsatisfiable x-kubernetes-list-type: map volumes: description: 'List of volumes that can be mounted by containers belonging to the pod. More info: https://kubernetes.io/docs/concepts/storage/volumes' items: description: Volume represents a named volume in a pod that may be accessed by any container in the pod. 
properties: awsElasticBlockStore: description: 'AWSElasticBlockStore represents an AWS Disk resource that is attached to a kubelet''s host machine and then exposed to the pod. More info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore' properties: fsType: description: 'Filesystem type of the volume that you want to mount. Tip: Ensure that the filesystem type is supported by the host operating system. Examples: "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. More info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore TODO: how do we prevent errors in the filesystem from compromising the machine' type: string partition: description: 'The partition in the volume that you want to mount. If omitted, the default is to mount by volume name. Examples: For volume /dev/sda1, you specify the partition as "1". Similarly, the volume partition for /dev/sda is "0" (or you can leave the property empty).' format: int32 type: integer readOnly: description: 'Specify "true" to force and set the ReadOnly property in VolumeMounts to "true". If omitted, the default is "false". More info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore' type: boolean volumeID: description: 'Unique ID of the persistent disk resource in AWS (Amazon EBS volume). More info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore' type: string required: - volumeID type: object azureDisk: description: AzureDisk represents an Azure Data Disk mount on the host and bind mount to the pod. properties: cachingMode: description: 'Host Caching mode: None, Read Only, Read Write.' type: string diskName: description: The Name of the data disk in the blob storage type: string diskURI: description: The URI the data disk in the blob storage type: string fsType: description: Filesystem type to mount. Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs", "ntfs". 
Implicitly inferred to be "ext4" if unspecified. type: string kind: description: 'Expected values Shared: multiple blob disks per storage account Dedicated: single blob disk per storage account Managed: azure managed data disk (only in managed availability set). defaults to shared' type: string readOnly: description: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts. type: boolean required: - diskName - diskURI type: object azureFile: description: AzureFile represents an Azure File Service mount on the host and bind mount to the pod. properties: readOnly: description: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts. type: boolean secretName: description: the name of secret that contains Azure Storage Account Name and Key type: string shareName: description: Share Name type: string required: - secretName - shareName type: object cephfs: description: CephFS represents a Ceph FS mount on the host that shares a pod's lifetime properties: monitors: description: 'Required: Monitors is a collection of Ceph monitors More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it' items: type: string type: array path: description: 'Optional: Used as the mounted root, rather than the full Ceph tree, default is /' type: string readOnly: description: 'Optional: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts. More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it' type: boolean secretFile: description: 'Optional: SecretFile is the path to key ring for User, default is /etc/ceph/user.secret More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it' type: string secretRef: description: 'Optional: SecretRef is reference to the authentication secret for User, default is empty. More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it' properties: name: description: 'Name of the referent. 
More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object user: description: 'Optional: User is the rados user name, default is admin More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it' type: string required: - monitors type: object cinder: description: 'Cinder represents a cinder volume attached and mounted on kubelets host machine. More info: https://examples.k8s.io/mysql-cinder-pd/README.md' properties: fsType: description: 'Filesystem type to mount. Must be a filesystem type supported by the host operating system. Examples: "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. More info: https://examples.k8s.io/mysql-cinder-pd/README.md' type: string readOnly: description: 'Optional: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts. More info: https://examples.k8s.io/mysql-cinder-pd/README.md' type: boolean secretRef: description: 'Optional: points to a secret object containing parameters used to connect to OpenStack.' properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object volumeID: description: 'volume id used to identify the volume in cinder. More info: https://examples.k8s.io/mysql-cinder-pd/README.md' type: string required: - volumeID type: object configMap: description: ConfigMap represents a configMap that should populate this volume properties: defaultMode: description: 'Optional: mode bits used to set permissions on created files by default. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. Defaults to 0644. Directories within the path are not affected by this setting. 
This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer items: description: If unspecified, each key-value pair in the Data field of the referenced ConfigMap will be projected into the volume as a file whose name is the key and content is the value. If specified, the listed keys will be projected into the specified paths, and unlisted keys will not be present. If a key is specified which is not present in the ConfigMap, the volume setup will error unless it is marked optional. Paths must be relative and may not contain the '..' path or start with '..'. items: description: Maps a string key to a path within a volume. properties: key: description: The key to project. type: string mode: description: 'Optional: mode bits used to set permissions on this file. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. If not specified, the volume defaultMode will be used. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer path: description: The relative path of the file to map the key to. May not be an absolute path. May not contain the path element '..'. May not start with the string '..'. type: string required: - key - path type: object type: array name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the ConfigMap or its keys must be defined type: boolean type: object csi: description: CSI (Container Storage Interface) represents ephemeral storage that is handled by certain external CSI drivers (Beta feature). 
properties: driver: description: Driver is the name of the CSI driver that handles this volume. Consult with your admin for the correct name as registered in the cluster. type: string fsType: description: Filesystem type to mount. Ex. "ext4", "xfs", "ntfs". If not provided, the empty value is passed to the associated CSI driver which will determine the default filesystem to apply. type: string nodePublishSecretRef: description: NodePublishSecretRef is a reference to the secret object containing sensitive information to pass to the CSI driver to complete the CSI NodePublishVolume and NodeUnpublishVolume calls. This field is optional, and may be empty if no secret is required. If the secret object contains more than one secret, all secret references are passed. properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object readOnly: description: Specifies a read-only configuration for the volume. Defaults to false (read/write). type: boolean volumeAttributes: additionalProperties: type: string description: VolumeAttributes stores driver-specific properties that are passed to the CSI driver. Consult your driver's documentation for supported values. type: object required: - driver type: object downwardAPI: description: DownwardAPI represents downward API about the pod that should populate this volume properties: defaultMode: description: 'Optional: mode bits used to set permissions on created files by default. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. Defaults to 0644. Directories within the path are not affected by this setting. 
This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer items: description: Items is a list of downward API volume file items: description: DownwardAPIVolumeFile represents information to create the file containing the pod field properties: fieldRef: description: 'Required: Selects a field of the pod: only annotations, labels, name and namespace are supported.' properties: apiVersion: description: Version of the schema the FieldPath is written in terms of, defaults to "v1". type: string fieldPath: description: Path of the field to select in the specified API version. type: string required: - fieldPath type: object mode: description: 'Optional: mode bits used to set permissions on this file, must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. If not specified, the volume defaultMode will be used. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer path: description: 'Required: Path is the relative path name of the file to be created. Must not be absolute or contain the ''..'' path. Must be utf-8 encoded. The first item of the relative path must not start with ''..''' type: string resourceFieldRef: description: 'Selects a resource of the container: only resources limits and requests (limits.cpu, limits.memory, requests.cpu and requests.memory) are currently supported.' 
properties: containerName: description: 'Container name: required for volumes, optional for env vars' type: string divisor: anyOf: - type: integer - type: string description: Specifies the output format of the exposed resources, defaults to "1" pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true resource: description: 'Required: resource to select' type: string required: - resource type: object required: - path type: object type: array type: object emptyDir: description: 'EmptyDir represents a temporary directory that shares a pod''s lifetime. More info: https://kubernetes.io/docs/concepts/storage/volumes#emptydir' properties: medium: description: 'What type of storage medium should back this directory. The default is "" which means to use the node''s default medium. Must be an empty string (default) or Memory. More info: https://kubernetes.io/docs/concepts/storage/volumes#emptydir' type: string sizeLimit: anyOf: - type: integer - type: string description: 'Total amount of local storage required for this EmptyDir volume. The size limit is also applicable for memory medium. The maximum usage on memory medium EmptyDir would be the minimum value between the SizeLimit specified here and the sum of memory limits of all containers in a pod. The default is nil which means that the limit is undefined. More info: http://kubernetes.io/docs/user-guide/volumes#emptydir' pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true type: object ephemeral: description: "Ephemeral represents a volume that is handled by a cluster storage driver (Alpha feature). The volume's lifecycle is tied to the pod that defines it - it will be created before the pod starts, and deleted when the pod is removed. 
\n Use this if: a) the volume is only needed while the pod runs, b) features of normal volumes like restoring from snapshot or capacity tracking are needed, c) the storage driver is specified through a storage class, and d) the storage driver supports dynamic volume provisioning through a PersistentVolumeClaim (see EphemeralVolumeSource for more \ information on the connection between this volume type and PersistentVolumeClaim). \n Use PersistentVolumeClaim or one of the vendor-specific APIs for volumes that persist for longer than the lifecycle of an individual pod. \n Use CSI for light-weight local ephemeral volumes if the CSI driver is meant to be used that way - see the documentation of the driver for more information. \n A pod can use both types of ephemeral volumes and persistent volumes at the same time." properties: readOnly: description: Specifies a read-only configuration for the volume. Defaults to false (read/write). type: boolean volumeClaimTemplate: description: "Will be used to create a stand-alone PVC to provision the volume. The pod in which this EphemeralVolumeSource is embedded will be the owner of the PVC, i.e. the PVC will be deleted together with the pod. The name of the PVC will be `<pod name>-<volume name>` where `<volume name>` is the name from the `PodSpec.Volumes` array entry. Pod validation will reject the pod if the concatenated name is not valid for a PVC (for example, too long). \n An existing PVC with that name that is not owned by the pod will *not* be used for the pod to avoid using an unrelated volume by mistake. Starting the pod is then blocked until the unrelated PVC is removed. If such a pre-created PVC is meant to be used by the pod, the PVC has to be updated with an owner reference to the pod once the pod exists. Normally this should not be necessary, but it may be useful when manually reconstructing a broken cluster. \n This field is read-only and no changes will be made by Kubernetes to the PVC after it has been created. \n Required, must not be nil." 
properties: metadata: description: May contain labels and annotations that will be copied into the PVC when creating it. No other fields are allowed and will be rejected during validation. properties: annotations: additionalProperties: type: string type: object finalizers: items: type: string type: array labels: additionalProperties: type: string type: object name: type: string namespace: type: string type: object spec: description: The specification for the PersistentVolumeClaim. The entire content is copied unchanged into the PVC that gets created from this template. The same fields as in a PersistentVolumeClaim are also valid here. properties: accessModes: description: 'AccessModes contains the desired access modes the volume should have. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#access-modes-1' items: type: string type: array dataSource: description: 'This field can be used to specify either: * An existing VolumeSnapshot object (snapshot.storage.k8s.io/VolumeSnapshot) * An existing PVC (PersistentVolumeClaim) * An existing custom resource that implements data population (Alpha) In order to use custom resource types that implement data population, the AnyVolumeDataSource feature gate must be enabled. If the provisioner or an external controller can support the specified data source, it will create a new volume based on the contents of the specified data source.' properties: apiGroup: description: APIGroup is the group for the resource being referenced. If APIGroup is not specified, the specified Kind must be in the core API group. For any other third-party types, APIGroup is required. type: string kind: description: Kind is the type of resource being referenced type: string name: description: Name is the name of resource being referenced type: string required: - kind - name type: object resources: description: 'Resources represents the minimum resources the volume should have. 
More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#resources' properties: limits: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Limits describes the maximum amount of compute resources allowed. More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object requests: additionalProperties: anyOf: - type: integer - type: string pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true description: 'Requests describes the minimum amount of compute resources required. If Requests is omitted for a container, it defaults to Limits if that is explicitly specified, otherwise to an implementation-defined value. More info: https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/' type: object type: object selector: description: A label query over volumes to consider for binding. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. 
items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. type: object type: object storageClassName: description: 'Name of the StorageClass required by the claim. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#class-1' type: string volumeMode: description: volumeMode defines what type of volume is required by the claim. Value of Filesystem is implied when not included in claim spec. type: string volumeName: description: VolumeName is the binding reference to the PersistentVolume backing this claim. type: string type: object required: - spec type: object type: object fc: description: FC represents a Fibre Channel resource that is attached to a kubelet's host machine and then exposed to the pod. properties: fsType: description: 'Filesystem type to mount. Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. TODO: how do we prevent errors in the filesystem from compromising the machine' type: string lun: description: 'Optional: FC target lun number' format: int32 type: integer readOnly: description: 'Optional: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts.' type: boolean targetWWNs: description: 'Optional: FC target worldwide names (WWNs)' items: type: string type: array wwids: description: 'Optional: FC volume world wide identifiers (wwids) Either wwids or combination of targetWWNs and lun must be set, but not both simultaneously.' 
items: type: string type: array type: object flexVolume: description: FlexVolume represents a generic volume resource that is provisioned/attached using an exec based plugin. properties: driver: description: Driver is the name of the driver to use for this volume. type: string fsType: description: Filesystem type to mount. Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs", "ntfs". The default filesystem depends on FlexVolume script. type: string options: additionalProperties: type: string description: 'Optional: Extra command options if any.' type: object readOnly: description: 'Optional: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts.' type: boolean secretRef: description: 'Optional: SecretRef is reference to the secret object containing sensitive information to pass to the plugin scripts. This may be empty if no secret object is specified. If the secret object contains more than one secret, all secrets are passed to the plugin scripts.' properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object required: - driver type: object flocker: description: Flocker represents a Flocker volume attached to a kubelet's host machine. This depends on the Flocker control service being running properties: datasetName: description: Name of the dataset stored as metadata -> name on the dataset for Flocker should be considered as deprecated type: string datasetUUID: description: UUID of the dataset. This is unique identifier of a Flocker dataset type: string type: object gcePersistentDisk: description: 'GCEPersistentDisk represents a GCE Disk resource that is attached to a kubelet''s host machine and then exposed to the pod. 
More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk' properties: fsType: description: 'Filesystem type of the volume that you want to mount. Tip: Ensure that the filesystem type is supported by the host operating system. Examples: "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk TODO: how do we prevent errors in the filesystem from compromising the machine' type: string partition: description: 'The partition in the volume that you want to mount. If omitted, the default is to mount by volume name. Examples: For volume /dev/sda1, you specify the partition as "1". Similarly, the volume partition for /dev/sda is "0" (or you can leave the property empty). More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk' format: int32 type: integer pdName: description: 'Unique name of the PD resource in GCE. Used to identify the disk in GCE. More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk' type: string readOnly: description: 'ReadOnly here will force the ReadOnly setting in VolumeMounts. Defaults to false. More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk' type: boolean required: - pdName type: object gitRepo: description: 'GitRepo represents a git repository at a particular revision. DEPRECATED: GitRepo is deprecated. To provision a container with a git repo, mount an EmptyDir into an InitContainer that clones the repo using git, then mount the EmptyDir into the Pod''s container.' properties: directory: description: Target directory name. Must not contain or start with '..'. If '.' is supplied, the volume directory will be the git repository. Otherwise, if specified, the volume will contain the git repository in the subdirectory with the given name. 
type: string repository: description: Repository URL type: string revision: description: Commit hash for the specified revision. type: string required: - repository type: object glusterfs: description: 'Glusterfs represents a Glusterfs mount on the host that shares a pod''s lifetime. More info: https://examples.k8s.io/volumes/glusterfs/README.md' properties: endpoints: description: 'EndpointsName is the endpoint name that details Glusterfs topology. More info: https://examples.k8s.io/volumes/glusterfs/README.md#create-a-pod' type: string path: description: 'Path is the Glusterfs volume path. More info: https://examples.k8s.io/volumes/glusterfs/README.md#create-a-pod' type: string readOnly: description: 'ReadOnly here will force the Glusterfs volume to be mounted with read-only permissions. Defaults to false. More info: https://examples.k8s.io/volumes/glusterfs/README.md#create-a-pod' type: boolean required: - endpoints - path type: object hostPath: description: 'HostPath represents a pre-existing file or directory on the host machine that is directly exposed to the container. This is generally used for system agents or other privileged things that are allowed to see the host machine. Most containers will NOT need this. More info: https://kubernetes.io/docs/concepts/storage/volumes#hostpath --- TODO(jonesdl) We need to restrict who can use host directory mounts and who can/can not mount host directories as read/write.' properties: path: description: 'Path of the directory on the host. If the path is a symlink, it will follow the link to the real path. More info: https://kubernetes.io/docs/concepts/storage/volumes#hostpath' type: string type: description: 'Type for HostPath Volume Defaults to "" More info: https://kubernetes.io/docs/concepts/storage/volumes#hostpath' type: string required: - path type: object iscsi: description: 'ISCSI represents an ISCSI Disk resource that is attached to a kubelet''s host machine and then exposed to the pod. 
More info: https://examples.k8s.io/volumes/iscsi/README.md' properties: chapAuthDiscovery: description: whether support iSCSI Discovery CHAP authentication type: boolean chapAuthSession: description: whether support iSCSI Session CHAP authentication type: boolean fsType: description: 'Filesystem type of the volume that you want to mount. Tip: Ensure that the filesystem type is supported by the host operating system. Examples: "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. More info: https://kubernetes.io/docs/concepts/storage/volumes#iscsi TODO: how do we prevent errors in the filesystem from compromising the machine' type: string initiatorName: description: Custom iSCSI Initiator Name. If initiatorName is specified with iscsiInterface simultaneously, new iSCSI interface : will be created for the connection. type: string iqn: description: Target iSCSI Qualified Name. type: string iscsiInterface: description: iSCSI Interface Name that uses an iSCSI transport. Defaults to 'default' (tcp). type: string lun: description: iSCSI Target Lun number. format: int32 type: integer portals: description: iSCSI Target Portal List. The portal is either an IP or ip_addr:port if the port is other than default (typically TCP ports 860 and 3260). items: type: string type: array readOnly: description: ReadOnly here will force the ReadOnly setting in VolumeMounts. Defaults to false. type: boolean secretRef: description: CHAP Secret for iSCSI target and initiator authentication properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object targetPortal: description: iSCSI Target Portal. The Portal is either an IP or ip_addr:port if the port is other than default (typically TCP ports 860 and 3260). type: string required: - iqn - lun - targetPortal type: object name: description: 'Volume''s name. 
Must be a DNS_LABEL and unique within the pod. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names' type: string nfs: description: 'NFS represents an NFS mount on the host that shares a pod''s lifetime More info: https://kubernetes.io/docs/concepts/storage/volumes#nfs' properties: path: description: 'Path that is exported by the NFS server. More info: https://kubernetes.io/docs/concepts/storage/volumes#nfs' type: string readOnly: description: 'ReadOnly here will force the NFS export to be mounted with read-only permissions. Defaults to false. More info: https://kubernetes.io/docs/concepts/storage/volumes#nfs' type: boolean server: description: 'Server is the hostname or IP address of the NFS server. More info: https://kubernetes.io/docs/concepts/storage/volumes#nfs' type: string required: - path - server type: object persistentVolumeClaim: description: 'PersistentVolumeClaimVolumeSource represents a reference to a PersistentVolumeClaim in the same namespace. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#persistentvolumeclaims' properties: claimName: description: 'ClaimName is the name of a PersistentVolumeClaim in the same namespace as the pod using this volume. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#persistentvolumeclaims' type: string readOnly: description: Will force the ReadOnly setting in VolumeMounts. Default false. type: boolean required: - claimName type: object photonPersistentDisk: description: PhotonPersistentDisk represents a PhotonController persistent disk attached and mounted on kubelets host machine properties: fsType: description: Filesystem type to mount. Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. 
type: string pdID: description: ID that identifies Photon Controller persistent disk type: string required: - pdID type: object portworxVolume: description: PortworxVolume represents a portworx volume attached and mounted on kubelets host machine properties: fsType: description: FSType represents the filesystem type to mount Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs". Implicitly inferred to be "ext4" if unspecified. type: string readOnly: description: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts. type: boolean volumeID: description: VolumeID uniquely identifies a Portworx volume type: string required: - volumeID type: object projected: description: Items for all in one resources secrets, configmaps, and downward API properties: defaultMode: description: Mode bits used to set permissions on created files by default. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. Directories within the path are not affected by this setting. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set. format: int32 type: integer sources: description: list of volume projections items: description: Projection that may be projected along with other supported volume types properties: configMap: description: information about the configMap data to project properties: items: description: If unspecified, each key-value pair in the Data field of the referenced ConfigMap will be projected into the volume as a file whose name is the key and content is the value. If specified, the listed keys will be projected into the specified paths, and unlisted keys will not be present. If a key is specified which is not present in the ConfigMap, the volume setup will error unless it is marked optional. 
Paths must be relative and may not contain the '..' path or start with '..'. items: description: Maps a string key to a path within a volume. properties: key: description: The key to project. type: string mode: description: 'Optional: mode bits used to set permissions on this file. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. If not specified, the volume defaultMode will be used. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer path: description: The relative path of the file to map the key to. May not be an absolute path. May not contain the path element '..'. May not start with the string '..'. type: string required: - key - path type: object type: array name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the ConfigMap or its keys must be defined type: boolean type: object downwardAPI: description: information about the downwardAPI data to project properties: items: description: Items is a list of DownwardAPIVolume file items: description: DownwardAPIVolumeFile represents information to create the file containing the pod field properties: fieldRef: description: 'Required: Selects a field of the pod: only annotations, labels, name and namespace are supported.' properties: apiVersion: description: Version of the schema the FieldPath is written in terms of, defaults to "v1". type: string fieldPath: description: Path of the field to select in the specified API version. 
type: string required: - fieldPath type: object mode: description: 'Optional: mode bits used to set permissions on this file, must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. If not specified, the volume defaultMode will be used. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer path: description: 'Required: Path is the relative path name of the file to be created. Must not be absolute or contain the ''..'' path. Must be utf-8 encoded. The first item of the relative path must not start with ''..''' type: string resourceFieldRef: description: 'Selects a resource of the container: only resources limits and requests (limits.cpu, limits.memory, requests.cpu and requests.memory) are currently supported.' properties: containerName: description: 'Container name: required for volumes, optional for env vars' type: string divisor: anyOf: - type: integer - type: string description: Specifies the output format of the exposed resources, defaults to "1" pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ x-kubernetes-int-or-string: true resource: description: 'Required: resource to select' type: string required: - resource type: object required: - path type: object type: array type: object secret: description: information about the secret data to project properties: items: description: If unspecified, each key-value pair in the Data field of the referenced Secret will be projected into the volume as a file whose name is the key and content is the value. If specified, the listed keys will be projected into the specified paths, and unlisted keys will not be present. If a key is specified which is not present in the Secret, the volume setup will error unless it is marked optional. 
Paths must be relative and may not contain the '..' path or start with '..'. items: description: Maps a string key to a path within a volume. properties: key: description: The key to project. type: string mode: description: 'Optional: mode bits used to set permissions on this file. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. If not specified, the volume defaultMode will be used. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer path: description: The relative path of the file to map the key to. May not be an absolute path. May not contain the path element '..'. May not start with the string '..'. type: string required: - key - path type: object type: array name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string optional: description: Specify whether the Secret or its key must be defined type: boolean type: object serviceAccountToken: description: information about the serviceAccountToken data to project properties: audience: description: Audience is the intended audience of the token. A recipient of a token must identify itself with an identifier specified in the audience of the token, and otherwise should reject the token. The audience defaults to the identifier of the apiserver. type: string expirationSeconds: description: ExpirationSeconds is the requested duration of validity of the service account token. As the token approaches expiration, the kubelet volume plugin will proactively rotate the service account token. 
The kubelet will start trying to rotate the token if the token is older than 80 percent of its time to live or if the token is older than 24 hours. Defaults to 1 hour and must be at least 10 minutes. format: int64 type: integer path: description: Path is the path relative to the mount point of the file to project the token into. type: string required: - path type: object type: object type: array type: object quobyte: description: Quobyte represents a Quobyte mount on the host that shares a pod's lifetime properties: group: description: Group to map volume access to Default is no group type: string readOnly: description: ReadOnly here will force the Quobyte volume to be mounted with read-only permissions. Defaults to false. type: boolean registry: description: Registry represents a single or multiple Quobyte Registry services specified as a string as host:port pair (multiple entries are separated with commas) which acts as the central registry for volumes type: string tenant: description: Tenant owning the given Quobyte volume in the Backend Used with dynamically provisioned Quobyte volumes, value is set by the plugin type: string user: description: User to map volume access to Defaults to serviceaccount user type: string volume: description: Volume is a string that references an already created Quobyte volume by name. type: string required: - registry - volume type: object rbd: description: 'RBD represents a Rados Block Device mount on the host that shares a pod''s lifetime. More info: https://examples.k8s.io/volumes/rbd/README.md' properties: fsType: description: 'Filesystem type of the volume that you want to mount. Tip: Ensure that the filesystem type is supported by the host operating system. Examples: "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified.
More info: https://kubernetes.io/docs/concepts/storage/volumes#rbd TODO: how do we prevent errors in the filesystem from compromising the machine' type: string image: description: 'The rados image name. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it' type: string keyring: description: 'Keyring is the path to key ring for RBDUser. Default is /etc/ceph/keyring. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it' type: string monitors: description: 'A collection of Ceph monitors. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it' items: type: string type: array pool: description: 'The rados pool name. Default is rbd. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it' type: string readOnly: description: 'ReadOnly here will force the ReadOnly setting in VolumeMounts. Defaults to false. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it' type: boolean secretRef: description: 'SecretRef is name of the authentication secret for RBDUser. If provided overrides keyring. Default is nil. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it' properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object user: description: 'The rados user name. Default is admin. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it' type: string required: - image - monitors type: object scaleIO: description: ScaleIO represents a ScaleIO persistent volume attached and mounted on Kubernetes nodes. properties: fsType: description: Filesystem type to mount. Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs", "ntfs". Default is "xfs". type: string gateway: description: The host address of the ScaleIO API Gateway. 
type: string protectionDomain: description: The name of the ScaleIO Protection Domain for the configured storage. type: string readOnly: description: Defaults to false (read/write). ReadOnly here will force the ReadOnly setting in VolumeMounts. type: boolean secretRef: description: SecretRef references to the secret for ScaleIO user and other sensitive information. If this is not provided, Login operation will fail. properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object sslEnabled: description: Flag to enable/disable SSL communication with Gateway, default false type: boolean storageMode: description: Indicates whether the storage for a volume should be ThickProvisioned or ThinProvisioned. Default is ThinProvisioned. type: string storagePool: description: The ScaleIO Storage Pool associated with the protection domain. type: string system: description: The name of the storage system as configured in ScaleIO. type: string volumeName: description: The name of a volume already created in the ScaleIO system that is associated with this volume source. type: string required: - gateway - secretRef - system type: object secret: description: 'Secret represents a secret that should populate this volume. More info: https://kubernetes.io/docs/concepts/storage/volumes#secret' properties: defaultMode: description: 'Optional: mode bits used to set permissions on created files by default. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. Defaults to 0644. Directories within the path are not affected by this setting. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' 
format: int32 type: integer items: description: If unspecified, each key-value pair in the Data field of the referenced Secret will be projected into the volume as a file whose name is the key and content is the value. If specified, the listed keys will be projected into the specified paths, and unlisted keys will not be present. If a key is specified which is not present in the Secret, the volume setup will error unless it is marked optional. Paths must be relative and may not contain the '..' path or start with '..'. items: description: Maps a string key to a path within a volume. properties: key: description: The key to project. type: string mode: description: 'Optional: mode bits used to set permissions on this file. Must be an octal value between 0000 and 0777 or a decimal value between 0 and 511. YAML accepts both octal and decimal values, JSON requires decimal values for mode bits. If not specified, the volume defaultMode will be used. This might be in conflict with other options that affect the file mode, like fsGroup, and the result can be other mode bits set.' format: int32 type: integer path: description: The relative path of the file to map the key to. May not be an absolute path. May not contain the path element '..'. May not start with the string '..'. type: string required: - key - path type: object type: array optional: description: Specify whether the Secret or its keys must be defined type: boolean secretName: description: 'Name of the secret in the pod''s namespace to use. More info: https://kubernetes.io/docs/concepts/storage/volumes#secret' type: string type: object storageos: description: StorageOS represents a StorageOS volume attached and mounted on Kubernetes nodes. properties: fsType: description: Filesystem type to mount. Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. type: string readOnly: description: Defaults to false (read/write). 
ReadOnly here will force the ReadOnly setting in VolumeMounts. type: boolean secretRef: description: SecretRef specifies the secret to use for obtaining the StorageOS API credentials. If not specified, default values will be attempted. properties: name: description: 'Name of the referent. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names TODO: Add other useful fields. apiVersion, kind, uid?' type: string type: object volumeName: description: VolumeName is the human-readable name of the StorageOS volume. Volume names are only unique within a namespace. type: string volumeNamespace: description: VolumeNamespace specifies the scope of the volume within StorageOS. If no namespace is specified then the Pod's namespace will be used. This allows the Kubernetes name scoping to be mirrored within StorageOS for tighter integration. Set VolumeName to any name to override the default behaviour. Set to "default" if you are not using namespaces within StorageOS. Namespaces that do not pre-exist within StorageOS will be created. type: string type: object vsphereVolume: description: VsphereVolume represents a vSphere volume attached and mounted on kubelets host machine properties: fsType: description: Filesystem type to mount. Must be a filesystem type supported by the host operating system. Ex. "ext4", "xfs", "ntfs". Implicitly inferred to be "ext4" if unspecified. type: string storagePolicyID: description: Storage Policy Based Management (SPBM) profile ID associated with the StoragePolicyName. type: string storagePolicyName: description: Storage Policy Based Management (SPBM) profile name. 
type: string volumePath: description: Path that identifies vSphere volume vmdk type: string required: - volumePath type: object required: - name type: object type: array required: - containers type: object type: object required: - selector - template type: object required: - spec type: object required: - name - template type: object type: array selector: description: selector is a label query over deployment. It must match the deployment template's labels. properties: matchExpressions: description: matchExpressions is a list of label selector requirements. The requirements are ANDed. items: description: A label selector requirement is a selector that contains values, a key, and an operator that relates the key and values. properties: key: description: key is the label key that the selector applies to. type: string operator: description: operator represents a key's relationship to a set of values. Valid operators are In, NotIn, Exists and DoesNotExist. type: string values: description: values is an array of string values. If the operator is In or NotIn, the values array must be non-empty. If the operator is Exists or DoesNotExist, the values array must be empty. This array is replaced during a strategic merge patch. items: type: string type: array required: - key - operator type: object type: array matchLabels: additionalProperties: type: string description: matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels map is equivalent to an element of matchExpressions, whose key field is "key", the operator is "In", and the values array contains only "value". The requirements are ANDed. type: object type: object required: - roles - selector type: object status: description: MLServiceStatus defines the observed state of MLService properties: lastTransitionTime: description: LastTransitionTime is time the last Phase transitioned to current one. 
format: date-time type: string message: description: Human-readable message indicating details about last transition. type: string phase: description: Phase is a simple, high-level summary of where the Service is in its lifecycle. type: string reason: description: Unique, one-word, CamelCase reason for the phase's last transition. type: string roleServiceClusterIps: additionalProperties: type: string description: RoleServiceClusterIps shows the cluster ip for all Services. The key is Service name, value is its clusterIP type: object roleServiceStatusMap: additionalProperties: description: ServiceStatus represents the current status of a service. properties: conditions: description: Current service state items: description: "Condition contains details for one aspect of the current state of this API Resource. --- This struct is intended for direct use as an array at the field path .status.conditions. For example, type FooStatus struct{ \ // Represents the observations of a foo's current state. \ // Known .status.conditions.type are: \"Available\", \"Progressing\", and \"Degraded\" // +patchMergeKey=type \ // +patchStrategy=merge // +listType=map // +listMapKey=type Conditions []metav1.Condition `json:\"conditions,omitempty\" patchStrategy:\"merge\" patchMergeKey:\"type\" protobuf:\"bytes,1,rep,name=conditions\"` \n // other fields }" properties: lastTransitionTime: description: lastTransitionTime is the last time the condition transitioned from one status to another. This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. format: date-time type: string message: description: message is a human readable message indicating details about the transition. This may be an empty string. maxLength: 32768 type: string observedGeneration: description: observedGeneration represents the .metadata.generation that the condition was set based upon. 
For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date with respect to the current state of the instance. format: int64 minimum: 0 type: integer reason: description: reason contains a programmatic identifier indicating the reason for the condition's last transition. Producers of specific condition types may define expected values and meanings for this field, and whether the values are considered a guaranteed API. The value should be a CamelCase string. This field may not be empty. maxLength: 1024 minLength: 1 pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ type: string status: description: status of the condition, one of True, False, Unknown. enum: - "True" - "False" - Unknown type: string type: description: type of condition in CamelCase or in foo.example.com/CamelCase. --- Many .condition.type values are consistent across resources like Available, but because arbitrary conditions can be useful (see .node.status.conditions), the ability to deconflict is important. The regex it matches is (dns1123SubdomainFmt/)?(qualifiedNameFmt) maxLength: 316 pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ type: string required: - lastTransitionTime - message - reason - status - type type: object type: array x-kubernetes-list-map-keys: - type x-kubernetes-list-type: map loadBalancer: description: LoadBalancer contains the current status of the load-balancer, if one is present. properties: ingress: description: Ingress is a list containing ingress points for the load-balancer. Traffic intended for the service should be sent to these ingress points. items: description: 'LoadBalancerIngress represents the status of a load-balancer ingress point: traffic intended for the service should be sent to an ingress point.' 
properties: hostname: description: Hostname is set for load-balancer ingress points that are DNS based (typically AWS load-balancers) type: string ip: description: IP is set for load-balancer ingress points that are IP based (typically GCE or OpenStack load-balancers) type: string ports: description: Ports is a list of records of service ports If used, every port defined in the service should have an entry in it items: properties: error: description: 'Error is to record the problem with the service port The format of the error shall comply with the following rules: - built-in error values shall be specified in this file and those shall use CamelCase names - cloud provider specific error values must have names that comply with the format foo.example.com/CamelCase. --- The regex it matches is (dns1123SubdomainFmt/)?(qualifiedNameFmt)' maxLength: 316 pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ type: string port: description: Port is the port number of the service port of which status is recorded here format: int32 type: integer protocol: default: TCP description: 'Protocol is the protocol of the service port of which status is recorded here The supported values are: "TCP", "UDP", "SCTP"' type: string required: - port - protocol type: object type: array x-kubernetes-list-type: atomic type: object type: array type: object type: object description: RoleServiceStatusMap shows the current status for all Services. The key is Service name, value is its status info type: object roleShardStatusMap: additionalProperties: description: DeploymentStatus is the most recently observed status of the Deployment. properties: availableReplicas: description: Total number of available pods (ready for at least minReadySeconds) targeted by this deployment. format: int32 type: integer collisionCount: description: Count of hash collisions for the Deployment.
The Deployment controller uses this field as a collision avoidance mechanism when it needs to create the name for the newest ReplicaSet. format: int32 type: integer conditions: description: Represents the latest available observations of a deployment's current state. items: description: DeploymentCondition describes the state of a deployment at a certain point. properties: lastTransitionTime: description: Last time the condition transitioned from one status to another. format: date-time type: string lastUpdateTime: description: The last time this condition was updated. format: date-time type: string message: description: A human readable message indicating details about the transition. type: string reason: description: The reason for the condition's last transition. type: string status: description: Status of the condition, one of True, False, Unknown. type: string type: description: Type of deployment condition. type: string required: - status - type type: object type: array observedGeneration: description: The generation observed by the deployment controller. format: int64 type: integer readyReplicas: description: Total number of ready pods targeted by this deployment. format: int32 type: integer replicas: description: Total number of non-terminated pods targeted by this deployment (their labels match the selector). format: int32 type: integer unavailableReplicas: description: Total number of unavailable pods targeted by this deployment. This is the total number of pods that are still required for the deployment to have 100% available capacity. They may either be pods that are running but not yet available or pods that still have not been created. format: int32 type: integer updatedReplicas: description: Total number of non-terminated pods targeted by this deployment that have the desired template spec. format: int32 type: integer type: object description: RoleShardStatusMap shows the current status for all Deployments. 
The key is Deployment name, value is its status info type: object type: object type: object served: true storage: true subresources: status: {} status: acceptedNames: kind: "" plural: "" conditions: [] storedVersions: [] ================================================ FILE: deploy/config/crd/kustomization.yaml ================================================ # This kustomization.yaml is not intended to be run by itself, # since it depends on service name and namespace that are out of this kustomize package. # It should be run by config/default resources: - bases/mlplatform.volcengine.com_mlservices.yaml #+kubebuilder:scaffold:crdkustomizeresource patchesStrategicMerge: # [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix. # patches here are for enabling the conversion webhook for each CRD #- patches/webhook_in_mlservices.yaml #+kubebuilder:scaffold:crdkustomizewebhookpatch # [CERTMANAGER] To enable cert-manager, uncomment all the sections with [CERTMANAGER] prefix. # patches here are for enabling the CA injection for each CRD #- patches/cainjection_in_mlservices.yaml #+kubebuilder:scaffold:crdkustomizecainjectionpatch # the following config is for teaching kustomize how to do kustomization for CRDs.
configurations: - kustomizeconfig.yaml ================================================ FILE: deploy/config/crd/kustomizeconfig.yaml ================================================ # This file is for teaching kustomize how to substitute name and namespace reference in CRD nameReference: - kind: Service version: v1 fieldSpecs: - kind: CustomResourceDefinition version: v1 group: apiextensions.k8s.io path: spec/conversion/webhook/clientConfig/service/name namespace: - kind: CustomResourceDefinition version: v1 group: apiextensions.k8s.io path: spec/conversion/webhook/clientConfig/service/namespace create: false varReference: - path: metadata/annotations ================================================ FILE: deploy/config/crd/patches/cainjection_in_mlservices.yaml ================================================ # The following patch adds a directive for certmanager to inject CA into the CRD apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: cert-manager.io/inject-ca-from: $(CERTIFICATE_NAMESPACE)/$(CERTIFICATE_NAME) name: mlservices.mlplatform.volcengine.com ================================================ FILE: deploy/config/crd/patches/webhook_in_mlservices.yaml ================================================ # The following patch enables a conversion webhook for the CRD apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: name: mlservices.mlplatform.volcengine.com spec: conversion: strategy: Webhook webhook: clientConfig: service: namespace: system name: webhook-service path: /convert ================================================ FILE: deploy/config/default/kustomization.yaml ================================================ # Adds namespace to all resources. namespace: monolith-system # Value of this field is prepended to the # names of all resources, e.g. a deployment named # "wordpress" becomes "alices-wordpress". 
# Note that it should also match with the prefix (text before '-') of the namespace # field above. namePrefix: monolith- # Labels to add to all resources and selectors. #commonLabels: # someName: someValue bases: - ../crd - ../rbac - ../manager # [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in # crd/kustomization.yaml #- ../webhook # [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER'. 'WEBHOOK' components are required. #- ../certmanager # [PROMETHEUS] To enable prometheus monitor, uncomment all sections with 'PROMETHEUS'. #- ../prometheus patchesStrategicMerge: # Protect the /metrics endpoint by putting it behind auth. # If you want your controller-manager to expose the /metrics # endpoint w/o any authn/z, please comment the following line. - manager_auth_proxy_patch.yaml # Mount the controller config file for loading manager configurations # through a ComponentConfig type #- manager_config_patch.yaml # [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in # crd/kustomization.yaml #- manager_webhook_patch.yaml # [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER'. # Uncomment 'CERTMANAGER' sections in crd/kustomization.yaml to enable the CA injection in the admission webhooks. # 'CERTMANAGER' needs to be enabled to use ca injection #- webhookcainjection_patch.yaml # the following config is for teaching kustomize how to do var substitution vars: # [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER' prefix. 
#- name: CERTIFICATE_NAMESPACE # namespace of the certificate CR # objref: # kind: Certificate # group: cert-manager.io # version: v1 # name: serving-cert # this name should match the one in certificate.yaml # fieldref: # fieldpath: metadata.namespace #- name: CERTIFICATE_NAME # objref: # kind: Certificate # group: cert-manager.io # version: v1 # name: serving-cert # this name should match the one in certificate.yaml #- name: SERVICE_NAMESPACE # namespace of the service # objref: # kind: Service # version: v1 # name: webhook-service # fieldref: # fieldpath: metadata.namespace #- name: SERVICE_NAME # objref: # kind: Service # version: v1 # name: webhook-service ================================================ FILE: deploy/config/default/manager_auth_proxy_patch.yaml ================================================ # This patch inject a sidecar container which is a HTTP proxy for the # controller manager, it performs RBAC authorization against the Kubernetes API using SubjectAccessReviews. 
apiVersion: apps/v1 kind: Deployment metadata: name: controller-manager namespace: system spec: template: spec: containers: - name: kube-rbac-proxy image: ml-platform-cn-guilin-boe.cr.volces.com/ml-platform/kube-rbac-proxy:0.13.0 args: - "--secure-listen-address=0.0.0.0:8443" - "--upstream=http://127.0.0.1:8080/" - "--logtostderr=true" - "--v=10" ports: - containerPort: 8443 name: https - name: manager args: - "--health-probe-bind-address=:8081" - "--metrics-bind-address=127.0.0.1:8080" - "--leader-elect" ================================================ FILE: deploy/config/default/manager_config_patch.yaml ================================================ apiVersion: apps/v1 kind: Deployment metadata: name: controller-manager namespace: system spec: template: spec: containers: - name: manager args: - "--config=controller_manager_config.yaml" volumeMounts: - name: manager-config mountPath: /controller_manager_config.yaml subPath: controller_manager_config.yaml volumes: - name: manager-config configMap: name: manager-config ================================================ FILE: deploy/config/manager/controller_manager_config.yaml ================================================ apiVersion: controller-runtime.sigs.k8s.io/v1alpha1 kind: ControllerManagerConfig health: healthProbeBindAddress: :8081 metrics: bindAddress: 127.0.0.1:8080 webhook: port: 9443 leaderElection: leaderElect: true resourceName: 183d5a48.volcengine.com ================================================ FILE: deploy/config/manager/kustomization.yaml ================================================ resources: - manager.yaml generatorOptions: disableNameSuffixHash: true configMapGenerator: - files: - controller_manager_config.yaml name: manager-config apiVersion: kustomize.config.k8s.io/v1beta1 kind: Kustomization images: - name: controller newName: ml-platform-cn-guilin-boe.cr.volces.com/ml-platform/data.monolith.controller-manager newTag: b85906ce01ef40a75ba48779efdd4e3f 
================================================ FILE: deploy/config/manager/manager.yaml ================================================ apiVersion: v1 kind: Namespace metadata: labels: control-plane: controller-manager name: system --- apiVersion: apps/v1 kind: Deployment metadata: name: controller-manager namespace: system labels: control-plane: controller-manager spec: selector: matchLabels: control-plane: controller-manager replicas: 1 template: metadata: labels: control-plane: controller-manager spec: containers: - command: - ./manager args: - --leader-elect image: controller:latest name: manager securityContext: allowPrivilegeEscalation: false livenessProbe: httpGet: path: /healthz port: 8081 initialDelaySeconds: 15 periodSeconds: 20 readinessProbe: httpGet: path: /readyz port: 8081 initialDelaySeconds: 5 periodSeconds: 10 resources: limits: cpu: 100m memory: 30Mi requests: cpu: 100m memory: 20Mi serviceAccountName: controller-manager terminationGracePeriodSeconds: 10 ================================================ FILE: deploy/config/prometheus/kustomization.yaml ================================================ resources: - monitor.yaml ================================================ FILE: deploy/config/prometheus/monitor.yaml ================================================ # Prometheus Monitor Service (Metrics) apiVersion: monitoring.coreos.com/v1 kind: ServiceMonitor metadata: labels: control-plane: controller-manager name: controller-manager-metrics-monitor namespace: system spec: endpoints: - path: /metrics port: https scheme: https bearerTokenFile: /var/run/secrets/kubernetes.io/serviceaccount/token tlsConfig: insecureSkipVerify: true selector: matchLabels: control-plane: controller-manager ================================================ FILE: deploy/config/rbac/auth_proxy_client_clusterrole.yaml ================================================ apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: name: metrics-reader rules: - 
nonResourceURLs: - "/metrics" verbs: - get ================================================ FILE: deploy/config/rbac/auth_proxy_role.yaml ================================================ apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: name: proxy-role rules: - apiGroups: - authentication.k8s.io resources: - tokenreviews verbs: - create - apiGroups: - authorization.k8s.io resources: - subjectaccessreviews verbs: - create ================================================ FILE: deploy/config/rbac/auth_proxy_role_binding.yaml ================================================ apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: name: proxy-rolebinding roleRef: apiGroup: rbac.authorization.k8s.io kind: ClusterRole name: proxy-role subjects: - kind: ServiceAccount name: controller-manager namespace: system ================================================ FILE: deploy/config/rbac/auth_proxy_service.yaml ================================================ apiVersion: v1 kind: Service metadata: labels: control-plane: controller-manager name: controller-manager-metrics-service namespace: system spec: ports: - name: https port: 8443 targetPort: https selector: control-plane: controller-manager ================================================ FILE: deploy/config/rbac/kustomization.yaml ================================================ resources: # All RBAC will be applied under this service account in # the deployment namespace. You may comment out this resource # if your manager will use a service account that exists at # runtime. Be sure to update RoleBinding and ClusterRoleBinding # subjects if changing service account names. - service_account.yaml - role.yaml - role_binding.yaml - leader_election_role.yaml - leader_election_role_binding.yaml # Comment the following 4 lines if you want to disable # the auth proxy (https://github.com/brancz/kube-rbac-proxy) # which protects your /metrics endpoint. 
- auth_proxy_service.yaml - auth_proxy_role.yaml - auth_proxy_role_binding.yaml - auth_proxy_client_clusterrole.yaml ================================================ FILE: deploy/config/rbac/leader_election_role.yaml ================================================ # permissions to do leader election. apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: leader-election-role rules: - apiGroups: - "" resources: - configmaps verbs: - get - list - watch - create - update - patch - delete - apiGroups: - coordination.k8s.io resources: - leases verbs: - get - list - watch - create - update - patch - delete - apiGroups: - "" resources: - events verbs: - create - patch ================================================ FILE: deploy/config/rbac/leader_election_role_binding.yaml ================================================ apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: leader-election-rolebinding roleRef: apiGroup: rbac.authorization.k8s.io kind: Role name: leader-election-role subjects: - kind: ServiceAccount name: controller-manager namespace: system ================================================ FILE: deploy/config/rbac/mlservice_editor_role.yaml ================================================ # permissions for end users to edit mlservices. apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: name: mlservice-editor-role rules: - apiGroups: - mlplatform.volcengine.com resources: - mlservices verbs: - create - delete - get - list - patch - update - watch - apiGroups: - mlplatform.volcengine.com resources: - mlservices/status verbs: - get ================================================ FILE: deploy/config/rbac/mlservice_viewer_role.yaml ================================================ # permissions for end users to view mlservices. 
apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: name: mlservice-viewer-role rules: - apiGroups: - mlplatform.volcengine.com resources: - mlservices verbs: - get - list - watch - apiGroups: - mlplatform.volcengine.com resources: - mlservices/status verbs: - get ================================================ FILE: deploy/config/rbac/role.yaml ================================================ --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: creationTimestamp: null name: manager-role rules: - apiGroups: - "" resources: - pods verbs: - get - list - apiGroups: - apps resources: - deployments verbs: - create - delete - get - list - patch - update - watch - apiGroups: - apps resources: - replicasets verbs: - get - list - apiGroups: - "" resources: - services verbs: - create - delete - get - list - patch - update - watch - apiGroups: - mlplatform.volcengine.com resources: - mlservices verbs: - create - delete - get - list - patch - update - watch - apiGroups: - mlplatform.volcengine.com resources: - mlservices/finalizers verbs: - update - apiGroups: - mlplatform.volcengine.com resources: - mlservices/status verbs: - get - patch - update ================================================ FILE: deploy/config/rbac/role_binding.yaml ================================================ apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: name: manager-rolebinding roleRef: apiGroup: rbac.authorization.k8s.io kind: ClusterRole name: manager-role subjects: - kind: ServiceAccount name: controller-manager namespace: system ================================================ FILE: deploy/config/rbac/service_account.yaml ================================================ apiVersion: v1 kind: ServiceAccount metadata: name: controller-manager namespace: system ================================================ FILE: deploy/config/samples/mlplatform_v1_mlservice.yaml ================================================ apiVersion: 
mlplatform.volcengine.com/v1 kind: MLService metadata: name: mlservice-demo namespace: mlplatform-service spec: selector: matchLabels: app: mlservice-demo roles: - name: "Entry" shardNum: 1 serviceSpec: serviceType: "ClusterIP" template: metadata: labels: app: mlservice-demo spec: progressDeadlineSeconds: 600 replicas: 1 selector: matchLabels: app: mlservice-demo strategy: rollingUpdate: maxSurge: 25% maxUnavailable: 25% type: RollingUpdate template: metadata: labels: app: mlservice-demo spec: containers: - command: - sleep - infinity env: - name: TEST value: "1" image: cr-cn-guilin-boe.volces.com/ml_platform/tfserving:8bf6def4f68f89fd07bce144723f7a97 imagePullPolicy: Always name: mlservice-demo ports: - containerPort: 8500 protocol: TCP - containerPort: 8501 protocol: TCP resources: limits: cpu: "500m" memory: 1Gi requests: cpu: "500m" memory: 1Gi terminationMessagePath: /dev/termination-log terminationMessagePolicy: File dnsPolicy: ClusterFirst restartPolicy: Always securityContext: runAsNonRoot: false terminationGracePeriodSeconds: 30 - name: "PS" shardNum: 2 template: metadata: labels: app: mlservice-demo spec: progressDeadlineSeconds: 600 replicas: 1 selector: matchLabels: app: mlservice-demo strategy: rollingUpdate: maxSurge: 25% maxUnavailable: 25% type: RollingUpdate template: metadata: labels: app: mlservice-demo spec: containers: - command: - sleep - infinity env: - name: TEST value: "1" image: cr-cn-guilin-boe.volces.com/ml_platform/tfserving:8bf6def4f68f89fd07bce144723f7a97 imagePullPolicy: Always name: mlservice-demo ports: - containerPort: 8500 protocol: TCP - containerPort: 8501 protocol: TCP resources: limits: cpu: "500m" memory: 1Gi requests: cpu: "500m" memory: 1Gi terminationMessagePath: /dev/termination-log terminationMessagePolicy: File dnsPolicy: ClusterFirst restartPolicy: Always securityContext: runAsNonRoot: false terminationGracePeriodSeconds: 30 ================================================ FILE: deploy/controllers/constants.go 
================================================
package controllers

// Building blocks for the label/annotation keys defined below.
const (
	ModuleInference      = "inference"
	MLPlatformVolcPrefix = "mlplatform.volcengine.com"
)

// Label/annotation keys of the form "inference.mlplatform.volcengine.com/<suffix>".
// SetAdditionalKeyValuePairs writes these keys into the maps it is given.
const (
	ImmutableLabelServiceId = ModuleInference + "." + MLPlatformVolcPrefix + "/service-id"
	ImmutableLabelRoleName  = ModuleInference + "." + MLPlatformVolcPrefix + "/role-name"
	ImmutableLabelShardId   = ModuleInference + "." + MLPlatformVolcPrefix + "/shard-id"
	ImmutableLabelShardNum  = ModuleInference + "." + MLPlatformVolcPrefix + "/shard-num"
)

// Names of the environment variables that AdditionalEnvs injects into containers.
// EnvPort is a format string; AdditionalEnvs fills %s with the upper-cased port name.
// EnvIdc is also used by DeploymentHandler as the MLService annotation key to read the IDC value from.
const (
	EnvShardId     = "MLP_SHARD_ID"
	EnvPodName     = "MLP_POD_NAME"
	EnvHostIp      = "MLP_HOST_IP"
	EnvShardNum    = "MLP_SHARD_NUM"
	EnvIdc         = "MLP_IDC"
	EnvServiceName = "MLP_SERVICE_NAME"
	EnvRoleName    = "MLP_ROLE_NAME"
	EnvPort        = "MLP_%s_PORT"
)

// kubelet
// NOTE(review): presumably container waiting reasons reported by the kubelet;
// their consumers are not in this part of the package - confirm.
const (
	PodInitializing   = "PodInitializing"
	ContainerCreating = "ContainerCreating"
)

// Default ports that GetServicePorts falls back to when a role does not
// configure an rpc or http port explicitly.
const (
	DefaultRpcPort  = 8500
	DefaultHttpPort = 8501
)

const (
	ContainerEvicted = "Evicted"
)

// reason for pod
// NOTE(review): presumably surfaced through the MLService status reason;
// the code that sets them is not in this part of the package - confirm.
const (
	ReasonInsufficientClusterResources = "InsufficientClusterResources"
	ReasonInProgress                   = ""
	ReasonStatusNotFound               = "StatusNotFound"
	ReasonEvicted                      = "Evicted"
	ReasonServiceExceptionExited       = "ExceptionExited"
)

================================================
FILE: deploy/controllers/deployment_handler.go
================================================
package controllers

import (
	"context"
	"fmt"
	"strconv"
	"strings"

	monolithv1 "code.byted.org/data/monolith/deploy/api/v1"
	appsv1 "k8s.io/api/apps/v1"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/equality"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
	"sigs.k8s.io/controller-runtime/pkg/log"
)

// getDeploymentName returns the Deployment name with pattern {mlsvcName}-{role}-{shardIdx}.
// The role name is lower-cased before it is put into the name.
func getDeploymentName(mlsvcName, role string, shardIdx int) string {
	return fmt.Sprintf("%s-%s-%d", mlsvcName, strings.ToLower(role), shardIdx)
}

//
DeploymentHandler handles with k8s Deployment resource, // make sure deployments owned by MLService in cluster match the desired state the MLService spec defines. func (r *MLServiceReconciler) DeploymentHandler(ctx context.Context, mlsvc *monolithv1.MLService) error { if mlsvc == nil { return nil } log := log.FromContext(ctx).WithName("DeploymentHandler") // delete deployments if MLService is deleted mlsvcDeleting := !mlsvc.GetDeletionTimestamp().IsZero() if mlsvcDeleting { return r.cleanOwnedDeployments(ctx, mlsvc) } for roleIdx, role := range mlsvc.Spec.Roles { shardNum := int(role.ShardNum) if shardNum == 0 { // default value of ShardNum is 1 shardNum = 1 } for shardIdx := 1; shardIdx <= shardNum; shardIdx++ { deploy := &appsv1.Deployment{ ObjectMeta: metav1.ObjectMeta{ Name: getDeploymentName(mlsvc.Name, mlsvc.Spec.Roles[roleIdx].Name, shardIdx), Namespace: mlsvc.Namespace, }, } if _, err := ctrl.CreateOrUpdate(ctx, r.Client, deploy, func() error { template := mlsvc.Spec.Roles[roleIdx].Template.DeepCopy() // set additional labels,annotations,label selector for deployment if template.ObjectMeta.Labels == nil { template.ObjectMeta.Labels = make(map[string]string, 0) } if template.ObjectMeta.Annotations == nil { template.ObjectMeta.Annotations = make(map[string]string, 0) } if template.Spec.Selector.MatchLabels == nil { template.Spec.Selector.MatchLabels = make(map[string]string, 0) } SetAdditionalKeyValuePairs(template.ObjectMeta.Labels, mlsvc.Name, role.Name, &shardIdx, &shardNum) SetAdditionalKeyValuePairs(template.ObjectMeta.Annotations, mlsvc.Name, role.Name, &shardIdx, &shardNum) SetAdditionalKeyValuePairs(template.Spec.Selector.MatchLabels, mlsvc.Name, role.Name, &shardIdx, &shardNum) // set additional labels for pod if template.Spec.Template.ObjectMeta.Labels == nil { template.Spec.Template.ObjectMeta.Labels = make(map[string]string, 0) } SetAdditionalKeyValuePairs(template.Spec.Template.ObjectMeta.Labels, mlsvc.Name, role.Name, &shardIdx, &shardNum) // 
set additional Env to the container idc, _ := mlsvc.GetAnnotations()[EnvIdc] var ports []corev1.ServicePort if mlsvc.Spec.Roles[roleIdx].ServiceSpec != nil { ports = GetServicePorts(mlsvc.Spec.Roles[roleIdx].ServiceSpec.Ports) } for i := range template.Spec.Template.Spec.Containers { template.Spec.Template.Spec.Containers[i].Env = append(template.Spec.Template.Spec.Containers[i].Env, AdditionalEnvs(mlsvc.Name, role.Name, idc, shardIdx, int(shardNum), ports)..., ) } deploy.ResourceVersion = "" // set ObjectMeta.Labels if deploy.ObjectMeta.Labels == nil { deploy.ObjectMeta.Labels = make(map[string]string) } for k, v := range template.ObjectMeta.Labels { deploy.ObjectMeta.Labels[k] = v } // set ObjectMeta.Annotations if deploy.ObjectMeta.Annotations == nil { deploy.ObjectMeta.Annotations = make(map[string]string) } for k, v := range template.ObjectMeta.Annotations { deploy.ObjectMeta.Annotations[k] = v } // set Finalizers for _, finalizer := range template.ObjectMeta.Finalizers { controllerutil.AddFinalizer(deploy, finalizer) } // set Spec deploy.Spec = template.Spec // set the owner so that garbage collection can kicks in if err := ctrl.SetControllerReference(mlsvc, deploy, r.Scheme); err != nil { log.Error(err, "unable to set ownerReference from MLService to Deployment") return err } // end of ctrl.CreateOrUpdate return nil }); err != nil { // error handling of ctrl.CreateOrUpdate log.Error(err, "unable to ensure deployment is correct") return err } } } return nil } func (r *MLServiceReconciler) createDeployment(ctx context.Context, dp *appsv1.Deployment) error { log := log.FromContext(ctx).WithValues("DeploymentName", dp.Name) if err := r.Client.Create(ctx, dp); err != nil { log.Error(err, "failed to create Deployment resource") return err } log.Info("created Deployment resource for MLService") return nil } func (r *MLServiceReconciler) updateDeployment(ctx context.Context, desired, existing *appsv1.Deployment) error { log := 
log.FromContext(ctx).WithValues("DeploymentName", existing.Name) if equality.Semantic.DeepEqual(existing, desired) { return nil } if err := r.Client.Update(ctx, desired); err != nil { log.Error(err, "failed to update Deployment resource") return err } log.Info("update Deployment resource for MLService") return nil } func (r *MLServiceReconciler) deleteDeployment(ctx context.Context, dp *appsv1.Deployment) error { log := log.FromContext(ctx).WithValues("DeploymentName", dp.Name) if err := r.Client.Delete(ctx, dp); err != nil { log.Error(err, "failed to delete Deployment resource") return err } log.Info("delete deployment resource: " + dp.Name) return nil } // cleanOwnedDeployments will delete any existing Deployment resources that // were created for the given MLService func (r *MLServiceReconciler) cleanOwnedDeployments(ctx context.Context, mlsvc *monolithv1.MLService) error { log := log.FromContext(ctx).WithValues("MLService", mlsvc.Name) log.Info("finding existing Deployments for MLService resource") // list all deployment resources owned by this MLService deployments, err := r.getOwnedDeployments(ctx, mlsvc) if err != nil { return err } for _, deployment := range deployments.Items { if !deployment.GetDeletionTimestamp().IsZero() { // deployment already deleted, ignore. 
continue } // delete deployment if err := r.Delete(ctx, &deployment); err != nil { log.Error(err, "failed to delete Deployment resource: "+deployment.Name) return err } log.Info("delete deployment resource: " + deployment.Name) } return nil } // getOwnedDeployments return all deployments owned by the MLService func (r *MLServiceReconciler) getOwnedDeployments(ctx context.Context, mlsvc *monolithv1.MLService) (*appsv1.DeploymentList, error) { var deployments appsv1.DeploymentList if err := r.List(ctx, &deployments, client.InNamespace(mlsvc.Namespace), client.MatchingLabels(mlsvc.Spec.Selector.MatchLabels)); err != nil { return nil, err } return &deployments, nil } // AdditionalEnvs return a list of EnvVar, these Envs will be injected to pod container func AdditionalEnvs(mlsvcName, roleName, idc string, shardIdx, shardNum int, ports []corev1.ServicePort) []corev1.EnvVar { envs := []corev1.EnvVar{ { Name: EnvShardId, Value: strconv.Itoa(shardIdx), }, { Name: EnvShardNum, Value: strconv.Itoa(int(shardNum)), }, { Name: EnvServiceName, Value: mlsvcName, }, { Name: EnvRoleName, Value: roleName, }, { Name: EnvIdc, Value: idc, }, { Name: EnvPodName, ValueFrom: &corev1.EnvVarSource{ FieldRef: &corev1.ObjectFieldSelector{ FieldPath: "metadata.name", }, }, }, { Name: EnvHostIp, ValueFrom: &corev1.EnvVarSource{ FieldRef: &corev1.ObjectFieldSelector{ FieldPath: "status.podIP", }, }, }, } for _, port := range ports { envs = append(envs, corev1.EnvVar{ Name: fmt.Sprintf(EnvPort, strings.ToUpper(string(port.Name))), Value: strconv.Itoa(int(port.Port)), }) } return envs } // SetAdditionalKeyValuePairs inserts additional labels to the existing Labels map func SetAdditionalKeyValuePairs(existing map[string]string, mlsvcName, roleName string, shardIdx, shardNum *int) { additional := map[string]string{ ImmutableLabelServiceId: mlsvcName, ImmutableLabelRoleName: roleName, } if shardIdx != nil { additional[ImmutableLabelShardId] = strconv.Itoa(*shardIdx) } if shardNum != nil { 
additional[ImmutableLabelShardNum] = strconv.Itoa(*shardNum) } for k, v := range additional { existing[k] = v } } ================================================ FILE: deploy/controllers/mlservice_controller.go ================================================ /* Copyright 2023. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package controllers import ( "context" "reflect" "time" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" appsclient "k8s.io/client-go/kubernetes/typed/apps/v1" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" monolithv1 "code.byted.org/data/monolith/deploy/api/v1" ) type MLSvcHandler func(ctx context.Context, mlsvc *monolithv1.MLService) error var handlers []MLSvcHandler // MLServiceReconciler reconciles a MLService object type MLServiceReconciler struct { appsclient.AppsV1Client client.Client Scheme *runtime.Scheme } //+kubebuilder:rbac:groups=mlplatform.volcengine.com,resources=mlservices,verbs=get;list;watch;create;update;patch;delete //+kubebuilder:rbac:groups=mlplatform.volcengine.com,resources=mlservices/status,verbs=get;update;patch //+kubebuilder:rbac:groups=mlplatform.volcengine.com,resources=mlservices/finalizers,verbs=update //+kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete 
//+kubebuilder:rbac:groups=core,resources=services,verbs=get;list;watch;create;update;patch;delete //+kubebuilder:rbac:groups=apps,resources=replicasets,verbs=get;list //+kubebuilder:rbac:groups="",resources=pods,verbs=get;list // Reconcile is part of the main kubernetes reconciliation loop which aims to // move the current state of the cluster closer to the desired state. // TODO(user): Modify the Reconcile function to compare the state specified by // the MLService object against the actual cluster state, and then // perform operations to make the cluster state reflect the state specified by // the user. // // For more details, check Reconcile and its Result here: // - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.7.2/pkg/reconcile func (r *MLServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { log := log.FromContext(ctx).WithName("MLService").WithValues("mlservice", req.NamespacedName) var mlsvc monolithv1.MLService if err := r.Get(ctx, req.NamespacedName, &mlsvc); err != nil { log.Error(err, "unable to fetch MLService") return ctrl.Result{}, client.IgnoreNotFound(err) } for _, h := range handlers { err := h(ctx, &mlsvc) if err != nil { log.Error(err, "handler failed") return ctrl.Result{}, err } } if err := r.updateStatus(ctx, &mlsvc); err != nil { log.Error(err, "unable to update status") return ctrl.Result{}, err } // if phase is Stopping, chances are that there will be no events to trigger the Reconcile, // so requeue is needed. if mlsvc.Status.Phase == monolithv1.ServiceStopping { return ctrl.Result{RequeueAfter: 2 * time.Second}, nil } return ctrl.Result{}, nil } // SetupWithManager sets up the controller with the Manager. func (r *MLServiceReconciler) SetupWithManager(mgr ctrl.Manager) error { handlers = []MLSvcHandler{ r.DeploymentHandler, r.ServiceHandler, } return ctrl.NewControllerManagedBy(mgr). For(&monolithv1.MLService{}). Owns(&appsv1.Deployment{}). Owns(&corev1.Service{}). 
Complete(r) } // updateStatus update the status of MLService according to status of resources owned by this MLService func (r *MLServiceReconciler) updateStatus(ctx context.Context, mlsvc *monolithv1.MLService) error { log := log.FromContext(ctx).WithName("MLService").WithValues("mlservice", mlsvc.Name) // List all deployment resources owned by this MLService deployments, err := r.getOwnedDeployments(ctx, mlsvc) if err != nil { log.Error(err, "get owned deployments failed") return err } // Shard status map var newRoleShardStatusMap map[string]appsv1.DeploymentStatus if len(deployments.Items) > 0 { newRoleShardStatusMap = make(map[string]appsv1.DeploymentStatus) } for _, dp := range deployments.Items { newRoleShardStatusMap[dp.Name] = *dp.Status.DeepCopy() } // List all service resources owned by this MLService services, err := r.getOwnedServices(ctx, mlsvc) if err != nil { log.Error(err, "get owned services failed") return err } // Service status map var newRoleServiceStatusMap map[string]corev1.ServiceStatus if len(services.Items) > 0 { newRoleServiceStatusMap = make(map[string]corev1.ServiceStatus) } for _, svc := range services.Items { newRoleServiceStatusMap[svc.Name] = *svc.Status.DeepCopy() } // Service ClusterIps var newRoleServiceClusterIps map[string]string if len(services.Items) > 0 { newRoleServiceClusterIps = make(map[string]string) } for _, svc := range services.Items { newRoleServiceClusterIps[svc.Name] = svc.Spec.ClusterIP } // phase, reason, message phase, reason, message, err := r.getMLServiceStatus(ctx, mlsvc) if err != nil { log.Error(err, "get MLService status failed") return err } if mlsvc.Status.Phase == phase && mlsvc.Status.Reason == reason && mlsvc.Status.Message == message && reflect.DeepEqual(newRoleShardStatusMap, mlsvc.Status.RoleShardStatusMap) && reflect.DeepEqual(newRoleServiceStatusMap, mlsvc.Status.RoleServiceStatusMap) && reflect.DeepEqual(newRoleServiceClusterIps, mlsvc.Status.RoleServiceClusterIps) { log.Info("no changes of 
MLService status") return nil } // update MLService status log.Info("MLService status", "phase", phase, "reason", reason, "message", message) mlsvc.Status.RoleShardStatusMap = newRoleShardStatusMap mlsvc.Status.RoleServiceStatusMap = newRoleServiceStatusMap mlsvc.Status.RoleServiceClusterIps = newRoleServiceClusterIps mlsvc.Status.LastTransitionTime = metav1.Now() mlsvc.Status.Phase = phase mlsvc.Status.Reason = reason mlsvc.Status.Message = message if err := r.Status().Update(ctx, mlsvc); err != nil { log.Error(err, "unable to update MLService status") return err } return nil } ================================================ FILE: deploy/controllers/service_handler.go ================================================ package controllers import ( "context" "errors" "fmt" "strings" monolithv1 "code.byted.org/data/monolith/deploy/api/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" ) // getServiceName returns Service name with pattern {mlsvcName}-{role} func getServiceName(mlsvcName, role string) string { return fmt.Sprintf("%s-%s", mlsvcName, strings.ToLower(role)) } // ServiceHandler handles with k8s Service resource, // make sure k8s service owned by MLService in cluster match the desired state the MLService spec defines. 
func (r *MLServiceReconciler) ServiceHandler(ctx context.Context, mlsvc *monolithv1.MLService) error { if mlsvc == nil { return nil } log := log.FromContext(ctx).WithName("ServiceHandler") // delete sesrvice if MLService is deleted mlsvcDeleting := !mlsvc.GetDeletionTimestamp().IsZero() if mlsvcDeleting { return r.cleanOwnedServices(ctx, mlsvc) } for roleIdx, role := range mlsvc.Spec.Roles { if role.ServiceSpec == nil { continue } if role.ServiceSpec.ServiceType != corev1.ServiceTypeClusterIP { mlsvc.Status.Phase = monolithv1.ServiceAbnormal mlsvc.Status.Message = "Currently only ClusterIP type is supported" log.Info("invalid service type, set status to abnormal", "ServiceType", role.ServiceSpec.ServiceType) if err := r.Status().Update(ctx, mlsvc); err != nil { log.Error(err, "unable to update MLService status") return err } return errors.New(mlsvc.Status.Message) } svc := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: getServiceName(mlsvc.Name, mlsvc.Spec.Roles[roleIdx].Name), Namespace: mlsvc.Namespace, }, } if _, err := ctrl.CreateOrUpdate(ctx, r.Client, svc, func() error { svc.Spec = corev1.ServiceSpec{ Ports: GetServicePorts(role.ServiceSpec.Ports), Selector: map[string]string{}, Type: corev1.ServiceTypeClusterIP, } // set service labels svc.ObjectMeta.Labels = make(map[string]string, 0) SetAdditionalKeyValuePairs(svc.ObjectMeta.Labels, mlsvc.Name, role.Name, nil, nil) for k, v := range mlsvc.Spec.Selector.MatchLabels { svc.ObjectMeta.Labels[k] = v } // set selector for pods SetAdditionalKeyValuePairs(svc.Spec.Selector, mlsvc.Name, role.Name, nil, nil) // set the owner so that garbage collection can kicks in if err := ctrl.SetControllerReference(mlsvc, svc, r.Scheme); err != nil { log.Error(err, "unable to set ownerReference from MLService to Service") return err } // end of ctrl.CreateOrUpdate return nil }); err != nil { // error handling of ctrl.CreateOrUpdate log.Error(err, "unable to ensure service is correct") return err } } return nil } // 
cleanOwnedServices will delete any existing Service resources that // were created for the given MLService func (r *MLServiceReconciler) cleanOwnedServices(ctx context.Context, mlsvc *monolithv1.MLService) error { log := log.FromContext(ctx).WithValues("MLService", mlsvc.Name) log.Info("finding existing Service for MLService resource") // List all service resources owned by this MLService services, err := r.getOwnedServices(ctx, mlsvc) if err != nil { return err } for _, svc := range services.Items { if !svc.GetDeletionTimestamp().IsZero() { // Service already deleted, ignore. continue } // Delete service if err := r.Delete(ctx, &svc); err != nil { log.Error(err, "failed to delete Service resource: "+svc.Name) return err } log.Info("delete service resource: " + svc.Name) } return nil } // getOwnedServices return all services owned by the MLService func (r *MLServiceReconciler) getOwnedServices(ctx context.Context, mlsvc *monolithv1.MLService) (*corev1.ServiceList, error) { var services corev1.ServiceList if err := r.List(ctx, &services, client.InNamespace(mlsvc.Namespace), client.MatchingLabels(mlsvc.Spec.Selector.MatchLabels)); err != nil { return nil, err } return &services, nil } // GetServicePorts return HTTP port and gRPC port func GetServicePorts(ports []monolithv1.ServicePort) []corev1.ServicePort { var httpPort int32 = DefaultHttpPort var rpcPort int32 = DefaultRpcPort for _, port := range ports { if port.Type == monolithv1.ServicePortTypeHttp { httpPort = port.Port } else if port.Type == monolithv1.ServicePortTypeRpc { rpcPort = port.Port } else { // ignore non-http and non-rpc port continue } } return []corev1.ServicePort{ { Name: strings.ToLower(string(monolithv1.ServicePortTypeHttp)), Port: httpPort, TargetPort: intstr.IntOrString{ Type: intstr.Int, IntVal: int32(httpPort), }, }, { Name: strings.ToLower(string(monolithv1.ServicePortTypeRpc)), Port: rpcPort, TargetPort: intstr.IntOrString{ Type: intstr.Int, IntVal: int32(rpcPort), }, }, } } 
================================================ FILE: deploy/controllers/status.go ================================================ package controllers import ( "context" "fmt" monolithv1 "code.byted.org/data/monolith/deploy/api/v1" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" utildeployment "k8s.io/kubectl/pkg/util/deployment" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" ) // getMLServiceStatus return the status of MLService along with a reason and message. // Queuing: at least one deployment owned by this MLService is in status Queuing // Deploying: at least one deployment owned by this MLService is in status Deploying // Running: all deployments owned by this MLService is in status Running // Abnormal: at least one deployment owned by this MLService is in status Abnormal // Deleting: MLService DeletionTimestamp is not zero // Stopping: at least one deployment owned by this MLService is in status Stopping // Stopped: all deployments owned by this MLService is in status Stopped func (r *MLServiceReconciler) getMLServiceStatus(ctx context.Context, mlsvc *monolithv1.MLService) (phase monolithv1.ServicePhase, reason, message string, err error) { log := log.FromContext(ctx).WithName("MLService").WithValues("mlservice", mlsvc.Name) // the status of Deleting if !mlsvc.GetDeletionTimestamp().IsZero() { phase = monolithv1.ServiceDeleting return } // List all deployment resources owned by this MLService deployments, err := r.getOwnedDeployments(ctx, mlsvc) if err != nil { log.Error(err, "get owned deployments failed") return } deploymentStatusCount := make(map[monolithv1.ServicePhase]int, 0) for _, deployment := range deployments.Items { deploymentPhase, deploymentReason, deploymentMessage, dErr := r.getDeploymentStatus(ctx, &deployment) if dErr != nil { err = dErr log.Error(err, "get status of deployment failed", "deployment", deployment.Name) return } deploymentStatusCount[deploymentPhase]++ // at least one is abnormal 
if deploymentPhase == monolithv1.ServiceAbnormal { phase = deploymentPhase reason = deploymentReason if deploymentMessage != "" { message = fmt.Sprintf("[Deployment %s] %s", deployment.Name, deploymentMessage) } log.Info("at least one deployment is abnormal", "deployment", deployment.Name, "phase", phase, "reason", reason, "message", message) return } // at least one is queuing if deploymentPhase == monolithv1.ServiceQueuing { phase = deploymentPhase log.Info("at least one deployment is queuing", "deployment", deployment.Name, "phase", phase) return } // at least one is stopping if deploymentPhase == monolithv1.ServiceStopping { phase = deploymentPhase log.Info("at least one deployment is stopping", "deployment", deployment.Name, "phase", phase) return } } // all deployments Running if count, ok := deploymentStatusCount[monolithv1.ServiceRunning]; ok && count == len(deployments.Items) { phase = monolithv1.ServiceRunning log.Info("all deployments are running", "phase", phase) return } // all deployments Stopped if count, ok := deploymentStatusCount[monolithv1.ServiceStopped]; ok && count == len(deployments.Items) { phase = monolithv1.ServiceStopped log.Info("all deployments are stopped", "phase", phase) return } phase = monolithv1.ServiceDeploying return } // getDeploymentStatus return the status of Deployment along with a reason and message. 
// status is generated based on the latest replicaset // Queuing: the latest ReplicaSet not exists, or it's status is Queuing // Deploying: status of the latest ReplicaSet is Deploying // Running: status of the latest ReplicaSet is Running // Abnormal: status of the latest ReplicaSet is Abnormal // Deleting: Deployment DeletionTimestamp is not zero // Stopping: the latest ReplicaSet is Stopping // Stopped: the latest ReplicaSet is Stopped func (r *MLServiceReconciler) getDeploymentStatus(ctx context.Context, deployment *appsv1.Deployment) (phase monolithv1.ServicePhase, reason, message string, err error) { log := log.FromContext(ctx).WithName("Deployment").WithValues("deployment", deployment.Name) // the status of Deleting if !deployment.GetDeletionTimestamp().IsZero() { phase = monolithv1.ServiceDeleting return } // get all replicaset var replicasets appsv1.ReplicaSetList if err = r.List(ctx, &replicasets, client.InNamespace(deployment.Namespace), client.MatchingLabels(deployment.Spec.Selector.MatchLabels)); err != nil { log.Error(err, "list replicasets failed") return } // get latest replicaset _, _, latest, err := utildeployment.GetAllReplicaSets(deployment, &r.AppsV1Client) if latest == nil { log.Info("latest replicaset not found, set phase to queuing") phase = monolithv1.ServiceQueuing return } return r.getReplicaSetStatus(ctx, latest) } // getReplicaSetStatus return the status of ReplicaSet along with a reason and message. 
// Queuing: 1) All Pods are in status Queuing 2)PodGroup is Pending; // Deploying: at least one Pod is in status Deploying // Running: at least one Pod is in status Running // Abnormal: all Pods are in status Abnormal // Deleting: ReplicaSet DeletionTimestamp is not zero // Stopping: replicas is 0 but pods exits // Stopped: replicas is 0 and no pods exits func (r *MLServiceReconciler) getReplicaSetStatus(ctx context.Context, replicaset *appsv1.ReplicaSet) (phase monolithv1.ServicePhase, reason, message string, err error) { log := log.FromContext(ctx).WithName("ReplicaSet").WithValues("replicaset", replicaset.Name) // list all pods of the replicaset var podList corev1.PodList if err = r.List(ctx, &podList, client.InNamespace(replicaset.Namespace), client.MatchingLabels(replicaset.Spec.Selector.MatchLabels)); client.IgnoreNotFound(err) != nil { log.Error(err, "list pods failed") return } // Stopping if *replicaset.Spec.Replicas == 0 && len(podList.Items) != 0 { phase = monolithv1.ServiceStopping return } // Stopped if *replicaset.Spec.Replicas == 0 && len(podList.Items) == 0 { phase = monolithv1.ServiceStopped return } podStatusCount := make(map[monolithv1.ServicePhase]int, 0) var abnormalReason, abnormalMessage string for _, pod := range podList.Items { podPhase, podReason, podMessage, dErr := r.getPodStatus(ctx, &pod) if dErr != nil { err = dErr log.Error(err, "get pod status failed") return } log.Info("pod status", "pod", pod.Name, "phase", podPhase, "reason", podReason, "message", podMessage) podStatusCount[podPhase]++ // at least one pod is deploying or running if podPhase == monolithv1.ServiceDeploying || podPhase == monolithv1.ServiceRunning { phase = podPhase reason = podReason message = fmt.Sprintf("[Pod %s] %s", pod.Name, podMessage) log.Info(fmt.Sprintf("at least one pod is %s", phase), "pod", pod.Name) return } if podPhase == monolithv1.ServiceAbnormal && (podReason != "" || podMessage != "") { abnormalReason = podReason abnormalMessage = podMessage } } 
// all pods Queuing if count, ok := podStatusCount[monolithv1.ServiceQueuing]; ok && count == len(podList.Items) { log.Info("all pods are queuing") phase = monolithv1.ServiceQueuing return } // all pods Abnormal if count, ok := podStatusCount[monolithv1.ServiceAbnormal]; ok && count == len(podList.Items) { log.Info("all pods are abnormal") phase = monolithv1.ServiceAbnormal reason = abnormalReason message = abnormalMessage return } return } // getPodStatus return the status of Pod along with a reason and message. // Queuing: condition PodScheduled is False // Deploying: condition PodScheduled is True // Running: condition Ready is True // Abnormal: Pod phase Succeeded、Failed、Unknown, Pending or Running but crash // Deleting: ReplicaSet DeletionTimestamp is not zero // ref: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase func (r *MLServiceReconciler) getPodStatus(ctx context.Context, pod *corev1.Pod) (phase monolithv1.ServicePhase, reason, message string, err error) { log := log.FromContext(ctx).WithName("Pod").WithValues("pod", pod.Name) // the status of Deleting if !pod.GetDeletionTimestamp().IsZero() { phase = monolithv1.ServiceDeleting return } // Queuing if cond := getPodCondition(pod, corev1.PodScheduled); cond != nil && cond.Status != corev1.ConditionTrue { log.Info("PodScheduled condition is false, set phase to queuing") phase = monolithv1.ServiceQueuing if cond.Message == "" { reason = ReasonInProgress } else { reason = ReasonInsufficientClusterResources } return } // Running if cond := getPodCondition(pod, corev1.PodReady); cond != nil && cond.Status == corev1.ConditionTrue { log.Info("PodReady condition is true, set phase to running") phase = monolithv1.ServiceRunning return } // Abnormal // Abnormal case 1: pod Failure if pod.Status.Phase == corev1.PodFailed || pod.Status.Phase == corev1.PodSucceeded { log.Info(fmt.Sprintf("pod Phase is %s, set phase to abnormal", pod.Status.Phase)) phase = monolithv1.ServiceAbnormal if 
pod.Status.Reason == ContainerEvicted { reason = ReasonEvicted message = "pod evicted" } else { reason = ReasonServiceExceptionExited message = "pod exited unexpectedly" } return } // Abnormal case 2: pod status unknown if pod.Status.Phase == corev1.PodUnknown { log.Info("pod Phase is Unknown, set phase to abnormal") phase = monolithv1.ServiceAbnormal reason = ReasonStatusNotFound message = "pod in status Unknown" return } // Abnormal case 3: container creating error or exited if pod.Status.Phase == corev1.PodRunning || pod.Status.Phase == corev1.PodPending { for _, status := range pod.Status.InitContainerStatuses { if tmpReason, tmpMessage := getContainerAbnormalMessage(status, true); tmpReason != "" { phase = monolithv1.ServiceAbnormal reason = tmpReason message = tmpMessage log.Info(fmt.Sprintf("pod Phase %s, but InitContainer is abnormal", pod.Status.Phase), "reason", reason, "message", message) return } } for _, status := range pod.Status.ContainerStatuses { if tmpReason, tmpMessage := getContainerAbnormalMessage(status, false); tmpReason != "" { phase = monolithv1.ServiceAbnormal reason = tmpReason message = tmpMessage log.Info(fmt.Sprintf("pod Phase %s, but container is abnormal", pod.Status.Phase), "reason", reason, "message", message) return } } } log.Info("assume phase is deploying in other cases") phase = monolithv1.ServiceDeploying return } func getContainerAbnormalMessage(status corev1.ContainerStatus, isInitContainer bool) (reason, message string) { waiting, terminated := status.State.Waiting, status.State.Terminated if waiting != nil && waiting.Reason != PodInitializing && waiting.Reason != ContainerCreating { return waiting.Reason, waiting.Message } if terminated != nil { if isInitContainer && terminated.ExitCode != 0 { reason = terminated.Reason message = terminated.Message } if !isInitContainer { reason = terminated.Reason if terminated.Message == "" { message = "Container terminated." 
} else { message = terminated.Message } } return } return } func getPodCondition(pod *corev1.Pod, conditionType corev1.PodConditionType) *corev1.PodCondition { for _, cond := range pod.Status.Conditions { if cond.Type == conditionType { return &cond } } return nil } ================================================ FILE: deploy/go.mod ================================================ module code.byted.org/data/monolith/deploy go 1.15 require ( github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/gomega v1.18.1 // indirect k8s.io/api v0.23.5 k8s.io/apimachinery v0.23.5 k8s.io/client-go v0.23.5 k8s.io/kubectl v0.20.6 sigs.k8s.io/controller-runtime v0.10.2 ) replace ( sigs.k8s.io/controller-runtime => sigs.k8s.io/controller-runtime v0.8.3 k8s.io/api => k8s.io/api v0.20.6 k8s.io/apimachinery => k8s.io/apimachinery v0.20.6 k8s.io/client-go => k8s.io/client-go v0.20.6 ) ================================================ FILE: deploy/go.sum ================================================ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= cloud.google.com/go v0.54.0 h1:3ithwDMr7/3vpAMXiH+ZQnYbuIsh+OPhUPMFC9enmn0= cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= 
cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI= cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmWk= cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= cloud.google.com/go v0.81.0 h1:at8Tk2zUz63cLPR0JPWm5vp77pEZmzxEQBEfRKn1VV8= cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= cloud.google.com/go v0.83.0/go.mod h1:Z7MJUsANfY0pYPdw0lbnivPx4/vhy/e2FEkSkF7vAVY= cloud.google.com/go v0.84.0/go.mod h1:RazrYuxIK6Kb7YrzzhPoLmCVzl7Sup4NrbKPg8KHSUM= cloud.google.com/go v0.87.0/go.mod h1:TpDYlFy7vuLzZMMZ+B6iRiELaY7z/gJPaqbMx6mlWcY= cloud.google.com/go v0.90.0/go.mod h1:kRX0mNRHe0e2rC6oNakvwQqzyDmg57xJ+SZU1eT2aDQ= cloud.google.com/go v0.93.3/go.mod h1:8utlLll2EF5XMAV15woO4lSbWQlk8rer9aLOfLh7+YI= cloud.google.com/go v0.94.1/go.mod h1:qAlAugsXlC+JWO+Bke5vCtc9ONxjQT3drlTTnAplMW4= cloud.google.com/go v0.97.0/go.mod h1:GF7l59pYBVlXQIBLx3a761cZ41F9bBH3JUlihCt2Udc= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/datastore v1.0.0/go.mod 
h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= github.com/Azure/go-autorest/autorest v0.11.1 h1:eVvIXUKiTgv++6YnWb42DUA1YL7qDugnKP0HljexdnQ= github.com/Azure/go-autorest/autorest v0.11.1/go.mod h1:JFgpikqFJ/MleTTxwepExTKnFUKKszPS8UavbQYUMuw= github.com/Azure/go-autorest/autorest v0.11.18 h1:90Y4srNYrwOtAgVo3ndrQkTYn6kf1Eg/AjTFJ8Is2aM= github.com/Azure/go-autorest/autorest v0.11.18/go.mod h1:dSiJPy22c3u0OtOKDNttNgqpNFY/GeWa7GH/Pz56QRA= 
github.com/Azure/go-autorest/autorest/adal v0.9.0/go.mod h1:/c022QCutn2P7uY+/oQWWNcK9YU+MH96NgK+jErpbcg= github.com/Azure/go-autorest/autorest/adal v0.9.5 h1:Y3bBUV4rTuxenJJs41HU3qmqsb+auo+a3Lz+PlJPpL0= github.com/Azure/go-autorest/autorest/adal v0.9.5/go.mod h1:B7KF7jKIeC9Mct5spmyCB/A8CG/sEz1vwIRGv/bbw7A= github.com/Azure/go-autorest/autorest/adal v0.9.13 h1:Mp5hbtOePIzM8pJVRa3YLrWWmZtoxRXqUEzCfJt3+/Q= github.com/Azure/go-autorest/autorest/adal v0.9.13/go.mod h1:W/MM4U6nLxnIskrw4UwWzlHfGjwUS50aOsc/I3yuU8M= github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw= github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= github.com/Azure/go-autorest/autorest/mocks v0.4.0/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk= github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= github.com/Azure/go-autorest/logger v0.2.0 h1:e4RVHVZKC5p6UANLJHkM4OfR1UKZPj8Wt8Pcx+3oqrE= github.com/Azure/go-autorest/logger v0.2.0/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg= github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/MakeNowJust/heredoc v0.0.0-20170808103936-bb23615498cd/go.mod 
h1:64YHyfSL2R96J44Nlwm39UHepQbyR5q10x7iYa1ks2E= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20220418222510-f25a4f6275ed/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/antlr/antlr4/runtime/Go/antlr v1.4.10/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod 
h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= 
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chai2010/gettext-go v0.0.0-20160711120539-c6fed771bfd5/go.mod h1:/iP1qXHoty45bqomnu2LM+VVyAEdWN+vtSHGlQgyxbw= github.com/chai2010/gettext-go v1.0.2/go.mod h1:y+wnP2cHYaVj19NZhYKAwEMH2CI1gNHeQQ+5AjwawxA= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5/go.mod h1:h6jFvWxBdQXxjopDMZyH2UVceIRfR84bdzbkoKrsWNo= 
github.com/cockroachdb/errors v1.2.4/go.mod h1:rQD95gz6FARkaKkQXUksEje/d9a6wBJoCr5oaCLELYA= github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f/go.mod h1:i/u985jwjWRlyHXQbwatDASoW0RMlZ/3i9yJHE2xLkI= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/daviddengcn/go-colortext v0.0.0-20160507010035-511bcaf42ccd/go.mod h1:dv4zxwHi5C/8AeI+4gX4dCWOIvNi7I6JCSX0HvlKPgE= github.com/daviddengcn/go-colortext v1.0.0/go.mod h1:zDqEI5NVUop5QPpVJUxE9UO10hRnmkD5G4Pmri9+m4c= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/docker/distribution v2.7.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/distribution v2.8.1+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633 h1:H2pdYOb3KQ1/YsqVWoWNLQO+fusocsw354rqGTZtAgw= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/emicklei/go-restful v2.9.5+incompatible/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/emicklei/go-restful/v3 v3.8.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emicklei/go-restful/v3 v3.9.0 h1:XwGDlfxEnQZzuopoqxwSEllNcCOM9DhhFyhFIIGKwxE= github.com/emicklei/go-restful/v3 v3.9.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= 
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/evanphx/json-patch v4.5.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v4.9.0+incompatible h1:kLcOMZeuLAJvL2BPWLMIj5oaZQobrkAqrL+WFZwQses= github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v4.11.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v4.12.0+incompatible h1:4onqiflcdA9EOZ4RxV643DvftH5pOlLGNtQ5lPWQu84= github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.6.0 h1:b91NhWfaz02IuVxO9faSllyAtNXHMPkC5J8sJCLunww= github.com/evanphx/json-patch/v5 v5.6.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= 
github.com/exponent-io/jsonpath v0.0.0-20151013193312-d6023ce2651d/go.mod h1:ZZMPRZwes7CROmyNKgQzC3XPs6L/G2EJLHddWejkmf4= github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c= github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fvbommel/sortorder v1.0.1/go.mod h1:uk88iVf1ovNn1iLfgUVU2F9o5eO30ui720w+kxuqRs0= github.com/getkin/kin-openapi v0.76.0/go.mod h1:660oXbgy5JFMKreazJaQTw7o+X00qeSyhcnluiMv+Xg= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod 
h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= github.com/go-logr/logr v0.3.0 h1:q4c+kbcR0d5rSurhBR8dIgieOaYpXtsdTYfx22Cu6rs= github.com/go-logr/logr v0.3.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= github.com/go-logr/logr v1.2.0 h1:QK40JKJyMdUDz+h+xvCsru/bJhvG0UxvePV0ufL/AcE= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/zapr v0.2.0 h1:v6Ji8yBW77pva6NkJKQdHLAJKrIJKRHz0RXwPqCHSR4= github.com/go-logr/zapr v0.2.0/go.mod h1:qhKdvif7YF5GI9NWEpyxTSSBdGmzkNguibrdCNVPunU= 
github.com/go-logr/zapr v1.2.3 h1:a9vnzlIBPQBBkeaR9IuMUfmVOrQlkoC4YfPoFkX3T7A= github.com/go-logr/zapr v1.2.3/go.mod h1:eIauM6P8qSvTw5o2ez6UEAfGjQKrxQTl5EoK+Qa2oG4= github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc= github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL98+wF9xc8zWvFonSJ8= github.com/go-openapi/jsonreference v0.19.5/go.mod h1:RdybgQwPxbL4UEjuAruzK1x3nE69AqPYEJeo/TWfEeg= github.com/go-openapi/jsonreference v0.20.0 h1:MYlu0sBgChmCfJxxUKZ8g1cPWFOB37YSZqewK7OKeyA= github.com/go-openapi/jsonreference v0.20.0/go.mod h1:Ag74Ico3lPc+zR+qjn4XBUmXymS4zJbYVCZmcgkasdo= github.com/go-openapi/spec v0.19.3/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.14 h1:gm3vOOXfiuw5i9p5N9xJvfjvuofpyvLA9Wr6QfK5Fng= github.com/go-openapi/swag v0.19.14/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.3.1/go.mod 
h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/mock v1.5.0/go.mod 
h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.1/go.mod h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golangplus/bytes v0.0.0-20160111154220-45c989fe5450/go.mod 
h1:Bk6SMAONeMXrxql8uvOKuAZSu8aM5RUGv+1C6IJaEho= github.com/golangplus/bytes v1.0.0/go.mod h1:AdRaCFwmc/00ZzELMWb01soso6W1R/++O1XL80yAn+A= github.com/golangplus/fmt v0.0.0-20150411045040-2a5d6d7d2995/go.mod h1:lJgMEyOkYFkPcDKwRXegd+iM6E7matEszMG5HhwytU8= github.com/golangplus/fmt v1.0.0/go.mod h1:zpM0OfbMCjPtd2qkTD/jX2MgiFCqklhSUFyDW44gVQE= github.com/golangplus/testing v0.0.0-20180327235837-af21d9c3145e/go.mod h1:0AA//k/eakGydO4jKRoRL2j92ZKSzTgj9tclaCrvXHk= github.com/golangplus/testing v1.0.0/go.mod h1:ZDreixUV3YzhoVraIDyOzHrr76p6NUh6k/pPg/Q3gYA= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/cel-go v0.12.5/go.mod h1:Jk7ljRzLBhkmiAwBoUxB1sZSCVBAzkqPF25olK/iRDw= github.com/google/gnostic v0.5.7-v3refs h1:FhTMOKj2VhjpouxvWJAV1TL304uMlb9zcDqkl6cEI54= github.com/google/gnostic v0.5.7-v3refs/go.mod h1:73MKFl6jIHelAJNaBGFzt3SPtZULs9dYrGFt8OiIsHQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod 
h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0= github.com/googleapis/gnostic v0.4.1/go.mod h1:LRhVm6pbyptWbWbuZ38d1eyptfvIytN3ir6b65WBswg= github.com/googleapis/gnostic v0.5.1 h1:A8Yhf6EtqTv9RMsU6MQTyrtV1TjWlR6xU9BsZIwuTCM= github.com/googleapis/gnostic v0.5.1/go.mod 
h1:6U4PtQXGIEt/Z3h5MAT7FNofLnw9vXk2cUuW7uA/OeU= github.com/googleapis/gnostic v0.5.5 h1:9fHAtK0uDfpveeqqo1hkEZJcFvYXAiCN3UutL8F9xHw= github.com/googleapis/gnostic v0.5.5/go.mod h1:7+EbHbldMins07ALC74bsA81Ovc97DwqyJO1AENw9kA= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod 
h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle 
v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/imdario/mergo v0.3.6 h1:xTNEAn+kxVO7dTZGu0CegyqKZmoWFI0rF8UxjlB2d28= github.com/imdario/mergo v0.3.6/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/imdario/mergo v0.3.10 h1:6q5mVkdH/vYmqngx7kZQTjJ5HRsx+ImorDIEQ+beJgc= github.com/imdario/mergo v0.3.10/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= 
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de/go.mod h1:zAbeS9B/r2mtpb6U+EI2rYA5OAXxsYw6wTamcNW+zcE= github.com/lithammer/dedent v1.1.0/go.mod h1:jrXYCQtgg0nJiN+StA2KgR7w6CiQNv9Fd/Z9BP0jIOc= github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson 
v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.0/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7ldAVICs= github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/matttproud/golang_protobuf_extensions v1.0.2 h1:hAHbPm5IJGijwng3PWk09JkG9WeqChjprR5s9bBZ+OM= github.com/matttproud/golang_protobuf_extensions v1.0.2/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= 
github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c= github.com/moby/term v0.0.0-20200312100748-672ec06f55cd/go.mod h1:DdlQx2hp0Ss5/fLikoLlEeIYiATotOjgB//nb973jeo= github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae/go.mod h1:E2VnQOmVuvZB6UYnnDB0qG5Nq/1tD9acaOpo6xmt0Kw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 h1:n6/2gBQ3RWajuToeY6ZtZTIKv2v7ThUy5KKusIT0yc0= github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00/go.mod h1:Pm3mSP3c5uWn86xMLZ5Sa7JB9GsEZySvHYXCTK4E9q4= 
github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v0.0.4/go.mod h1:zq6QwlOf5SlnkVbMSr5EoBv3636FWnp+qbPhuoO21uA= github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.11.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.14.1/go.mod 
h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= github.com/onsi/ginkgo/v2 v2.1.6/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs= github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo= github.com/onsi/ginkgo/v2 v2.5.0/go.mod h1:Luc4sArBICYCS8THh8v3i3i5CuSZO+RaQRaJoeNwomw= github.com/onsi/ginkgo/v2 v2.6.0/go.mod h1:63DOGlLAH8+REH8jUGdL3YpCpu7JODesutUjdENfUAc= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.10.2/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= github.com/onsi/gomega v1.20.1/go.mod 
h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= github.com/onsi/gomega v1.21.1/go.mod h1:iYAIXgPSaDHak0LCMA+AWBpIKBr8WZicMxnE8luStNc= github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys= github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= github.com/onsi/gomega v1.24.1/go.mod h1:3AOiACssS3/MajrniINInwbfOOtfZvplPzuRSmvt1jM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/pquerna/cachecontrol v0.0.0-20171018203845-0dec1b30a021/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/pquerna/cachecontrol v0.1.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= 
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1 h1:NTGy1Ja9pByO+xAeH/qiWnLrKtr3hJPNjaVUwnjpdpA= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw= github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4= github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0 
h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lNawc= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE= github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.2.0 h1:wH4vA7pcjKuZzjF7lM8awk4fnuJO6idemZXoKnULUx4= github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod 
h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v1.1.1/go.mod h1:WnodtKOvamDL/PwE2M4iKs8aMDBZ5Q5klgD3qfVJQMI= github.com/spf13/cobra v1.1.3/go.mod 
h1:pGADOWyqRD/YMrPZigI/zbliZ2wVD/23d+is3pSWzOo= github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g= github.com/spf13/cobra v1.6.0/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xlab/treeprint v1.1.0 h1:G/1DjNkPpfZCFt9CSh6b5/nY4VimlbHF3Rh4obvtzDk= github.com/xlab/treeprint v1.1.0/go.mod h1:gj5Gd3gPdKtR1ikdDK6fnFLdmIS0X30kTTuNd/WEJu0= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.etcd.io/etcd v0.5.0-alpha.5.0.20200910180754-dd1b699fc489/go.mod h1:yVHk9ub3CSBatqGNg7GRmsnfLWtoW60w4eDYfh7vHDg= go.etcd.io/etcd/api/v3 
v3.5.5/go.mod h1:KFtNaxGDw4Yx/BA4iPPwevUTAuqcsPxzyX8PHydchN8= go.etcd.io/etcd/client/pkg/v3 v3.5.5/go.mod h1:ggrwbk069qxpKPq8/FKkQ3Xq9y39kbFR4LnKszpRXeQ= go.etcd.io/etcd/client/v2 v2.305.5/go.mod h1:zQjKllfqfBVyVStbt4FaosoX2iYd8fV/GRy/PbowgP4= go.etcd.io/etcd/client/v3 v3.5.5/go.mod h1:aApjR4WGlSumpnJ2kloS75h6aHUmAyaPLjHMxpc7E7c= go.etcd.io/etcd/pkg/v3 v3.5.5/go.mod h1:6ksYFxttiUGzC2uxyqiyOEvhAiD0tuIqSZkX3TyPdaE= go.etcd.io/etcd/raft/v3 v3.5.5/go.mod h1:76TA48q03g1y1VpTue92jZLr9lIHKUNcYdZOOGyx8rI= go.etcd.io/etcd/server/v3 v3.5.5/go.mod h1:rZ95vDw/jrvsbj9XpTqPrTAB9/kzchVdhRirySPkUBc= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.25.0/go.mod h1:E5NNboN0UqSAki0Atn9kVwaN7I+l25gGxDqBueo/74E= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.35.0/go.mod h1:h8TWwRAhQpOd0aM5nYsRD8+flnkj+526GEIVlarH7eY= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.35.0/go.mod h1:9NiG9I2aHTKkcxqCILhjtyNA1QEiCjdBACv4IvrFQ+c= go.opentelemetry.io/otel v1.0.1/go.mod h1:OPEOD4jIT2SlZPMmwT6FqZz2C0ZNdQqiWcoK6M0SNFU= go.opentelemetry.io/otel v1.8.0/go.mod h1:2pkj+iMj0o03Y+cW6/m8Y4WkRdYN3AvCXCnzRMp9yvM= go.opentelemetry.io/otel v1.10.0/go.mod h1:NbvWjCthWHKBEUMpf0/v8ZRZlni86PpGFEMA9pnQSnQ= go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.10.0/go.mod h1:78XhIg8Ht9vR4tbLNUhXsiOnE2HOuSeKAiAcoVQEpOY= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.0.1/go.mod 
h1:Kv8liBeVNFkkkbilbgWRpV+wWuu+H5xdOT6HAgd30iw= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.10.0/go.mod h1:Krqnjl22jUJ0HgMzw5eveuCvFDXY4nSYb4F8t5gdrag= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.0.1/go.mod h1:xOvWoTOrQjxjW61xtOmD/WKGRYb/P4NzRo3bs65U6Rk= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.10.0/go.mod h1:OfUCyyIiDvNXHWpcWgbF+MWvqPZiNa3YDEnivcnYsV0= go.opentelemetry.io/otel/metric v0.31.0/go.mod h1:ohmwj9KTSIeBnDBm/ZwH2PSZxZzoOaG2xZeekTRzL5A= go.opentelemetry.io/otel/sdk v1.0.1/go.mod h1:HrdXne+BiwsOHYYkBE5ysIcv2bvdZstxzmCQhxTcZkI= go.opentelemetry.io/otel/sdk v1.10.0/go.mod h1:vO06iKzD5baltJz1zarxMCNHFpUlUiOy4s65ECtn6kE= go.opentelemetry.io/otel/trace v1.0.1/go.mod h1:5g4i4fKLaX2BQpSBsxw8YYcgKpMMSW3x7ZTuYBr3sUk= go.opentelemetry.io/otel/trace v1.8.0/go.mod h1:0Bt3PXY8w+3pheS3hQUt+wow8b1ojPaTBoTCh2zIFI4= go.opentelemetry.io/otel/trace v1.10.0/go.mod h1:Sij3YYczqAdz+EhmGhE6TpTxUO5/F/AzrK+kxfGqySM= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.9.0/go.mod h1:1vKfU9rv61e9EVGthD1zNvUbiwPcimSsOPU9brfSHJg= go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= go.starlark.net v0.0.0-20200306205701-8dd3e2ee1dd5 h1:+FNtrFTmVw0YZGpBGX56XDee331t6JAXeK2bcyhLOOc= go.starlark.net v0.0.0-20200306205701-8dd3e2ee1dd5/go.mod h1:nmDLcffg48OtT/PSW0Hg7FvpRQsQh5OSqIylirxKC7o= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10 
h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.8.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM= go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod 
h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 h1:hb9wdF1z5waM+dSIICn1l0DkLVDT3hqhhQsDNUmHPRE= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod 
h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 h1:2M3HP5CCK1Si9FQhwnzYhXdG6DXeebvUHFpre8QvbyI= golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= 
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod 
h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net 
v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod 
h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211209124913-491a49abca63 h1:iocB37TsdFuN6IBRZ+ry36wrkoV51/tl5vOWqkcPGvY= golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 h1:Frnccbp+ok2GkUS2tC84yAq/U9Vg+0sIO7aRL3T4Xnc= golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 
v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210805134026-6f1e6394065a/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f h1:Qmd2pbz05z7z6lm0DrgQVVPuBm92jqujBKMHMOlOQEw= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg= golang.org/x/oauth2 
v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191002063906-3421d5a6bb1c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201112073958-5cba982894dd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b h1:9zKuko04nR4gjZ4+DNjHqRlAJqbJETHwiNKDqTfOjfE= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190614205625-5aca471b1d59/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190624222133-a101b041ded4/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod 
h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod 
h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200505023115-26f46d2f7ef8/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200616133436-c1934b75d054/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE= golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod 
h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gomodules.xyz/jsonpatch/v2 v2.1.0 h1:Phva6wqu+xR//Njw6iorylFFgn/z547tw5Ne3HZPQ+k= gomodules.xyz/jsonpatch/v2 v2.1.0/go.mod h1:IhYNNY4jnS53ZnfE4PAmpKtDpTCj1JFXc+3mwe7XcUU= gomodules.xyz/jsonpatch/v2 v2.2.0 h1:4pT439QV83L+G9FkcCriY6EkpcK6r6bK+A5FBUMI7qY= 
gomodules.xyz/jsonpatch/v2 v2.2.0/go.mod h1:WXp+iVDkoLQqPudfQ9GBlwB2eZ5DKOnjQZCYdOS8GPY= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE= google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8= google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU= google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94= google.golang.org/api v0.47.0/go.mod h1:Wbvgpq1HddcWVtzsVLyfLp8lDg6AA241LmgIL59tHXo= google.golang.org/api v0.48.0/go.mod 
h1:71Pr1vy+TAZRPkPs/xlCf5SsU8WjuAWv1Pfjbtukyy4= google.golang.org/api v0.50.0/go.mod h1:4bNT5pAuq5ji4SRZm+5QIkjny9JAyVD/3gaSihNefaw= google.golang.org/api v0.51.0/go.mod h1:t4HdrdoNgyN5cbEfm7Lum0lcLDLiise1F8qDKX00sOU= google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6z3k= google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6 h1:lMO5rYAqUxkmaj76jAkRUvt5JZgFymx/+Q5Mzfivuhc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto 
v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto 
v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201110150050-8816d57aaa9a/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210222152913-aa3ee6e6a81c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210303154014-9728d6b83eeb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto 
v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210513213006-bf773b8c8384/go.mod h1:P3QM42oQyzQSnHPnZ/vqoCdDmzH28fzWByN9asMeM8A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20210604141403-392c879c8b08/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20210608205507-b6d2f5bf0d7d/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24= google.golang.org/genproto v0.0.0-20210713002101-d411969a0d9a/go.mod h1:AxrInvYm1dci+enl5hChSFPOmmUF1+uAa/UsgNRWd7k= google.golang.org/genproto v0.0.0-20210716133855-ce7ef5c701ea/go.mod h1:AxrInvYm1dci+enl5hChSFPOmmUF1+uAa/UsgNRWd7k= google.golang.org/genproto v0.0.0-20210728212813-7823e685a01f/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48= google.golang.org/genproto v0.0.0-20210805201207-89edb61ffb67/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48= google.golang.org/genproto v0.0.0-20210813162853-db860fec028c/go.mod h1:cFeNkxwySK631ADgubI+/XFU/xp8FD5KIVV4rj8UC5w= google.golang.org/genproto v0.0.0-20210821163610-241b8fcbd6c8/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210828152312-66f60bf46e71/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210831024726-fe130286e0e2/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto 
v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.37.0/go.mod 
h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.37.1/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k= google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc v1.46.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.46.2/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf 
v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 
v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools v2.2.0+incompatible/go.mod 
h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.20.1/go.mod h1:KqwcCVogGxQY3nBlRpwt+wpAMF/KjaCc7RpywacvqUo= k8s.io/api v0.20.2/go.mod h1:d7n6Ehyzx+S+cE3VhTGfVNNqtGc/oL9DCdYYahlurV8= k8s.io/api v0.20.6 h1:bgdZrW++LqgrLikWYNruIKAtltXbSCX2l5mJu11hrVE= k8s.io/api v0.20.6/go.mod h1:X9e8Qag6JV/bL5G6bU8sdVRltWKmdHsFUGS3eVndqE8= k8s.io/api v0.23.5 h1:zno3LUiMubxD/V1Zw3ijyKO3wxrhbUF1Ck+VjBvfaoA= k8s.io/api v0.23.5/go.mod h1:Na4XuKng8PXJ2JsploYYrivXrINeTaycCGcYgF91Xm8= k8s.io/api v0.26.0 h1:IpPlZnxBpV1xl7TGk/X6lFtpgjgntCg8PJ+qrPHAC7I= k8s.io/api v0.26.0/go.mod h1:k6HDTaIFC8yn1i6pSClSqIwLABIcLV9l5Q4EcngKnQg= k8s.io/apiextensions-apiserver v0.20.1 h1:ZrXQeslal+6zKM/HjDXLzThlz/vPSxrfK3OqL8txgVQ= k8s.io/apiextensions-apiserver v0.20.1/go.mod h1:ntnrZV+6a3dB504qwC5PN/Yg9PBiDNt1EVqbW2kORVk= k8s.io/apiextensions-apiserver v0.26.0 h1:Gy93Xo1eg2ZIkNX/8vy5xviVSxwQulsnUdQ00nEdpDo= k8s.io/apiextensions-apiserver v0.26.0/go.mod h1:7ez0LTiyW5nq3vADtK6C3kMESxadD51Bh6uz3JOlqWQ= k8s.io/apimachinery v0.20.1/go.mod 
h1:WlLqWAHZGg07AeltaI0MV5uk1Omp8xaN0JGLY6gkRpU= k8s.io/apimachinery v0.20.2/go.mod h1:WlLqWAHZGg07AeltaI0MV5uk1Omp8xaN0JGLY6gkRpU= k8s.io/apimachinery v0.20.6 h1:R5p3SlhaABYShQSO6LpPsYHjV05Q+79eBUR0Ut/f4tk= k8s.io/apimachinery v0.20.6/go.mod h1:ejZXtW1Ra6V1O5H8xPBGz+T3+4gfkTCeExAHKU57MAc= k8s.io/apimachinery v0.23.5 h1:Va7dwhp8wgkUPWsEXk6XglXWU4IKYLKNlv8VkX7SDM0= k8s.io/apimachinery v0.23.5/go.mod h1:BEuFMMBaIbcOqVIJqNZJXGFTP4W6AycEpb5+m/97hrM= k8s.io/apimachinery v0.26.0 h1:1feANjElT7MvPqp0JT6F3Ss6TWDwmcjLypwoPpEf7zg= k8s.io/apimachinery v0.26.0/go.mod h1:tnPmbONNJ7ByJNz9+n9kMjNP8ON+1qoAIIC70lztu74= k8s.io/apiserver v0.20.1/go.mod h1:ro5QHeQkgMS7ZGpvf4tSMx6bBOgPfE+f52KwvXfScaU= k8s.io/apiserver v0.26.0/go.mod h1:aWhlLD+mU+xRo+zhkvP/gFNbShI4wBDHS33o0+JGI84= k8s.io/cli-runtime v0.20.6/go.mod h1:JVERW478qcxWrUjJuWQSqyJeiz9QC4T6jmBznHFBC8w= k8s.io/cli-runtime v0.26.0 h1:aQHa1SyUhpqxAw1fY21x2z2OS5RLtMJOCj7tN4oq8mw= k8s.io/cli-runtime v0.26.0/go.mod h1:o+4KmwHzO/UK0wepE1qpRk6l3o60/txUZ1fEXWGIKTY= k8s.io/client-go v0.20.1/go.mod h1:/zcHdt1TeWSd5HoUe6elJmHSQ6uLLgp4bIJHVEuy+/Y= k8s.io/client-go v0.20.2/go.mod h1:kH5brqWqp7HDxUFKoEgiI4v8G1xzbe9giaCenUWJzgE= k8s.io/client-go v0.20.6 h1:nJZOfolnsVtDtbGJNCxzOtKUAu7zvXjB8+pMo9UNxZo= k8s.io/client-go v0.20.6/go.mod h1:nNQMnOvEUEsOzRRFIIkdmYOjAZrC8bgq0ExboWSU1I0= k8s.io/client-go v0.23.5 h1:zUXHmEuqx0RY4+CsnkOn5l0GU+skkRXKGJrhmE2SLd8= k8s.io/client-go v0.23.5/go.mod h1:flkeinTO1CirYgzMPRWxUCnV0G4Fbu2vLhYCObnt/r4= k8s.io/client-go v0.26.0 h1:lT1D3OfO+wIi9UFolCrifbjUUgu7CpLca0AD8ghRLI8= k8s.io/client-go v0.26.0/go.mod h1:I2Sh57A79EQsDmn7F7ASpmru1cceh3ocVT9KlX2jEZg= k8s.io/code-generator v0.20.1/go.mod h1:UsqdF+VX4PU2g46NC2JRs4gc+IfrctnwHb76RNbWHJg= k8s.io/code-generator v0.20.6/go.mod h1:i6FmG+QxaLxvJsezvZp0q/gAEzzOz3U53KFibghWToU= k8s.io/code-generator v0.26.0/go.mod h1:OMoJ5Dqx1wgaQzKgc+ZWaZPfGjdRq/Y3WubFrZmeI3I= k8s.io/component-base v0.20.1/go.mod h1:guxkoJnNoh8LNrbtiQOlyp2Y2XFCZQmrcg2n/DeYNLk= k8s.io/component-base 
v0.20.2/go.mod h1:pzFtCiwe/ASD0iV7ySMu8SYVJjCapNM9bjvk7ptpKh0= k8s.io/component-base v0.20.6 h1:G0inASS5vAqCpzs7M4Sp9dv9d0aElpz39zDHbSB4f4g= k8s.io/component-base v0.20.6/go.mod h1:6f1MPBAeI+mvuts3sIdtpjljHWBQ2cIy38oBIWMYnrM= k8s.io/component-base v0.26.0 h1:0IkChOCohtDHttmKuz+EP3j3+qKmV55rM9gIFTXA7Vs= k8s.io/component-base v0.26.0/go.mod h1:lqHwlfV1/haa14F/Z5Zizk5QmzaVf23nQzCwVOQpfC8= k8s.io/component-helpers v0.20.6/go.mod h1:d4rFhZS/wxrZCxRiJJiWf1mVGVeMB5/ey3Yv8/rOp78= k8s.io/component-helpers v0.26.0/go.mod h1:jHN01qS/Jdj95WCbTe9S2VZ9yxpxXNY488WjF+yW4fo= k8s.io/gengo v0.0.0-20200413195148-3a45101e95ac/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0= k8s.io/gengo v0.0.0-20201113003025-83324d819ded/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= k8s.io/gengo v0.0.0-20210813121822-485abfe95c7c/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= k8s.io/gengo v0.0.0-20220902162205-c0856e24416d/go.mod h1:FiNAH4ZV3gBg2Kwh89tzAEV2be7d5xI0vBa/VySYy3E= k8s.io/klog/v2 v2.0.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE= k8s.io/klog/v2 v2.2.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/klog/v2 v2.4.0 h1:7+X0fUguPyrKEC4WjH8iGDg3laWgMo5tMnRTIGTTxGQ= k8s.io/klog/v2 v2.4.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/klog/v2 v2.30.0 h1:bUO6drIvCIsvZ/XFgfxoGFQU/a4Qkh0iAlvUR7vlHJw= k8s.io/klog/v2 v2.30.0/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/klog/v2 v2.80.1 h1:atnLQ121W371wYYFawwYx1aEY2eUfs4l3J72wtgAwV4= k8s.io/klog/v2 v2.80.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/kms v0.26.0/go.mod h1:ReC1IEGuxgfN+PDCIpR6w8+XMmDE7uJhxcCwMZFdIYc= k8s.io/kube-openapi v0.0.0-20201113171705-d219536bb9fd h1:sOHNzJIkytDF6qadMNKhhDRpc6ODik8lVC6nOur7B2c= k8s.io/kube-openapi v0.0.0-20201113171705-d219536bb9fd/go.mod h1:WOJ3KddDSol4tAGcJo0Tvi+dK12EcqSLqcWsryKMpfM= k8s.io/kube-openapi v0.0.0-20211115234752-e816edb12b65 h1:E3J9oCLlaobFUqsjG9DfKbP2BmgwBL2p7pn0A3dG9W4= k8s.io/kube-openapi 
v0.0.0-20211115234752-e816edb12b65/go.mod h1:sX9MT8g7NVZM5lVL/j8QyCCJe8YSMW30QvGZWaCIDIk= k8s.io/kube-openapi v0.0.0-20220401212409-b28bf2818661/go.mod h1:daOouuuwd9JXpv1L7Y34iV3yf6nxzipkKMWWlqlvK9M= k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 h1:+70TFaan3hfJzs+7VK2o+OGxg8HsuBr/5f6tVAjDu6E= k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280/go.mod h1:+Axhij7bCpeqhklhUTe3xmOn6bWxolyZEeyaFpjGtl4= k8s.io/kubectl v0.20.6 h1:G0a3fJXvypzN1fDcO+clH131rpDxNtDZIgSuogSCtng= k8s.io/kubectl v0.20.6/go.mod h1:yTCGVrlkBuQhFbKA1R65+lQ9hH7XeyOqUd0FUPFicPg= k8s.io/kubectl v0.26.0 h1:xmrzoKR9CyNdzxBmXV7jW9Ln8WMrwRK6hGbbf69o4T0= k8s.io/kubectl v0.26.0/go.mod h1:eInP0b+U9XUJWSYeU9XZnTA+cVYuWyl3iYPGtru0qhQ= k8s.io/metrics v0.20.6/go.mod h1:d+OAIaXutom9kGWcBit/M8OkDpIzBKTsm47+KcUt7VI= k8s.io/metrics v0.26.0/go.mod h1:cf5MlG4ZgWaEFZrR9+sOImhZ2ICMpIdNurA+D8snIs8= k8s.io/utils v0.0.0-20201110183641-67b214c5f920/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= k8s.io/utils v0.0.0-20210111153108-fddb29f9d009 h1:0T5IaWHO3sJTEmCP6mUlBvMukxPKUQWqiI/YuiBNMiQ= k8s.io/utils v0.0.0-20210111153108-fddb29f9d009/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= k8s.io/utils v0.0.0-20210802155522-efc7438f0176/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= k8s.io/utils v0.0.0-20211116205334-6203023598ed h1:ck1fRPWPJWsMd8ZRFsWc6mh/zHp5fZ/shhbrgPUxDAE= k8s.io/utils v0.0.0-20211116205334-6203023598ed/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= k8s.io/utils v0.0.0-20221107191617-1a15be271d1d/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= k8s.io/utils v0.0.0-20221128185143-99ec85e7a448 h1:KTgPnR10d5zhztWptI952TNtt/4u5h3IzDXkdIMuo2Y= k8s.io/utils v0.0.0-20221128185143-99ec85e7a448/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod 
h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.0.14/go.mod h1:LEScyzhFmoF5pso/YSeBstl57mOzx9xlU9n85RGrDQg= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.0.33/go.mod h1:soWkSNf2tZC7aMibXEqVhCd73GOY5fJikn8qbdzemB0= sigs.k8s.io/controller-runtime v0.8.3 h1:GMHvzjTmaWHQB8HadW+dIvBoJuLvZObYJ5YoZruPRao= sigs.k8s.io/controller-runtime v0.8.3/go.mod h1:U/l+DUopBc1ecfRZ5aviA9JDmGFQKvLf5YkZNx2e0sU= sigs.k8s.io/controller-runtime v0.14.1 h1:vThDes9pzg0Y+UbCPY3Wj34CGIYPgdmspPm2GIpxpzM= sigs.k8s.io/controller-runtime v0.14.1/go.mod h1:GaRkrY8a7UZF0kqFFbUKG7n9ICiTY5T55P1RiE3UZlU= sigs.k8s.io/json v0.0.0-20211020170558-c049b76a60c6 h1:fD1pz4yfdADVNfFmcP2aBEtudwUQ1AlLnRBALr33v3s= sigs.k8s.io/json v0.0.0-20211020170558-c049b76a60c6/go.mod h1:p4QtZmO4uMYipTQNzagwnNoseA6OxSUutVw05NhYDRs= sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 h1:iXTIw73aPyC+oRdyqqvVJuloN1p0AC/kzH07hu3NE+k= sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= sigs.k8s.io/kustomize v2.0.3+incompatible/go.mod h1:MkjgH3RdOWrievjo6c9T245dYlB5QeXV4WCbnt/PEpU= sigs.k8s.io/kustomize/api v0.12.1 h1:7YM7gW3kYBwtKvoY216ZzY+8hM+lV53LUayghNRJ0vM= sigs.k8s.io/kustomize/api v0.12.1/go.mod h1:y3JUhimkZkR6sbLNwfJHxvo1TCLwuwm14sCYnkH6S1s= sigs.k8s.io/kustomize/cmd/config v0.10.9/go.mod h1:T0s850zPV3wKfBALA0dyeP/K74jlJcoP8Pr9ZWwE3MQ= sigs.k8s.io/kustomize/kustomize/v4 v4.5.7/go.mod h1:VSNKEH9D9d9bLiWEGbS6Xbg/Ih0tgQalmPvntzRxZ/Q= sigs.k8s.io/kustomize/kyaml v0.13.9 h1:Qz53EAaFFANyNgyOEJbT/yoIHygK40/ZcvU3rgry2Tk= sigs.k8s.io/kustomize/kyaml v0.13.9/go.mod h1:QsRbD0/KcU+wdk0/L0fIp2KLnohkVzs6fQ85/nOXac4= sigs.k8s.io/structured-merge-diff/v4 v4.0.2/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw= sigs.k8s.io/structured-merge-diff/v4 v4.0.3 h1:4oyYo8NREp49LBBhKxEqCulFjg26rawYKrnCmg+Sr6c= 
sigs.k8s.io/structured-merge-diff/v4 v4.0.3/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw= sigs.k8s.io/structured-merge-diff/v4 v4.2.1 h1:bKCqE9GvQ5tiVHn5rfn1r+yao3aLQEaLzkkmAkf+A6Y= sigs.k8s.io/structured-merge-diff/v4 v4.2.1/go.mod h1:j/nl6xW8vLS49O8YvXW1ocPhZawJtm+Yrr7PPRQ0Vg4= sigs.k8s.io/structured-merge-diff/v4 v4.2.3 h1:PRbqxJClWWYMNV1dhaG4NsibJbArud9kFxnAMREiWFE= sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= ================================================ FILE: deploy/hack/boilerplate.go.txt ================================================ /* Copyright 2023. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ ================================================ FILE: deploy/main.go ================================================ /* Copyright 2023. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
	"flag"
	"os"

	// Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.)
	// to ensure that exec-entrypoint and run can make use of them.
	_ "k8s.io/client-go/plugin/pkg/client/auth"

	"k8s.io/apimachinery/pkg/runtime"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/healthz"
	"sigs.k8s.io/controller-runtime/pkg/log/zap"

	mlplatformv1 "code.byted.org/data/monolith/deploy/api/v1"
	"code.byted.org/data/monolith/deploy/controllers"
	appsclient "k8s.io/client-go/kubernetes/typed/apps/v1"
	//+kubebuilder:scaffold:imports
)

var (
	// scheme is the runtime.Scheme passed to the manager; init() registers the API types into it.
	scheme = runtime.NewScheme()
	// setupLog is the logger (named "setup") used for every start-up message in this file.
	setupLog = ctrl.Log.WithName("setup")
)

// init registers the client-go built-in types and the mlplatformv1 API types into scheme.
// Both registrations are wrapped in utilruntime.Must, so a registration error aborts start-up.
// The kubebuilder scaffold marker below must stay in place: code generation inserts new
// AddToScheme calls at that position.
func init() {
	utilruntime.Must(clientgoscheme.AddToScheme(scheme))

	utilruntime.Must(mlplatformv1.AddToScheme(scheme))
	//+kubebuilder:scaffold:scheme
}

// main wires up and runs the controller manager:
//  1. parse the metrics / health-probe / leader-election flags plus the zap logging flags;
//  2. build the manager from the config returned by ctrl.GetConfigOrDie();
//  3. create a typed apps/v1 client from the manager's REST config and hand it, together with
//     the manager's client and scheme, to controllers.MLServiceReconciler;
//  4. register the "healthz" and "readyz" ping checks;
//  5. start the manager with the context returned by ctrl.SetupSignalHandler().
//
// Every set-up failure is logged through setupLog and terminates the process with exit code 1.
func main() {
	var metricsAddr string
	var enableLeaderElection bool
	var probeAddr string
	flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.")
	flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
	flag.BoolVar(&enableLeaderElection, "leader-elect", false,
		"Enable leader election for controller manager. "+
			"Enabling this will ensure there is only one active controller manager.")
	// Development-mode zap defaults; BindFlags exposes the zap options on the command line,
	// so it has to run before flag.Parse().
	opts := zap.Options{
		Development: true,
	}
	opts.BindFlags(flag.CommandLine)
	flag.Parse()

	ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts)))

	mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{
		Scheme:                 scheme,
		MetricsBindAddress:     metricsAddr,
		Port:                   9443,
		HealthProbeBindAddress: probeAddr,
		LeaderElection:         enableLeaderElection,
		LeaderElectionID:       "183d5a48.volcengine.com",
	})
	if err != nil {
		setupLog.Error(err, "unable to start manager")
		os.Exit(1)
	}

	// Typed apps/v1 clientset built from the same REST config the manager uses.
	appsClient, err := appsclient.NewForConfig(mgr.GetConfig())
	if err != nil {
		setupLog.Error(err, "unable to create AppsV1Client")
		os.Exit(1)
	}

	// The reconciler stores the apps/v1 client by value (note the dereference).
	if err = (&controllers.MLServiceReconciler{
		AppsV1Client: *appsClient,
		Client:       mgr.GetClient(),
		Scheme:       mgr.GetScheme(),
	}).SetupWithManager(mgr); err != nil {
		setupLog.Error(err, "unable to create controller", "controller", "MLService")
		os.Exit(1)
	}
	//+kubebuilder:scaffold:builder

	// Both probes use healthz.Ping, i.e. they do not inspect controller state.
	if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil {
		setupLog.Error(err, "unable to set up health check")
		os.Exit(1)
	}
	if err := mgr.AddReadyzCheck("readyz", healthz.Ping); err != nil {
		setupLog.Error(err, "unable to set up ready check")
		os.Exit(1)
	}

	setupLog.Info("starting manager")
	// mgr.Start is only expected to return when the manager stops or fails.
	if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
		setupLog.Error(err, "problem running manager")
		os.Exit(1)
	}
}

================================================
FILE: deploy/serving/agent.conf
================================================
bzid {{bzid}}
base_name {{base_name}}
base_path {{base_path}}
num_ps {{num_ps}}
server_type {{server_type}}
zk_servers {{zk_servers}}
dense_alone {{dense_alone}}
update_model_status_interval 10
enable_batching {{enable_batching}}
tensorflow_session_parallelism 0
tensorflow_intra_op_parallelism 0
tensorflow_inter_op_parallelism 0
per_process_gpu_memory_fraction 0
num_load_threads 0
num_unload_threads 0
max_num_load_retries 5
load_retry_interval_micros 60 * 1000 * 1000
file_system_poll_wait_seconds 1 file_system_poll_wait_seconds_ps 0 flush_filesystem_caches true saved_model_tags none grpc_channel_arguments none grpc_max_threads 0 enable_model_warmup true enable_signature_method_name_check false xla_cpu_compilation_enabled false enable_profiler true aio_thread_num 20 dc_aware true ================================================ FILE: deploy/serving/docker/Dockerfile ================================================ FROM debian:buster-20221219 LABEL maintainer="Monolith" ARG PYPI_SOURCE=https://pypi.tuna.tsinghua.edu.cn/simple # pre install for tsinghua apt source RUN set -eux; \ apt-get update; \ apt-get install -y --no-install-recommends \ apt-transport-https ca-certificates wget dirmngr gnupg \ software-properties-common \ ; # install java RUN set -eux; \ wget -qO - https://adoptopenjdk.jfrog.io/adoptopenjdk/api/gpg/key/public | apt-key add - ; \ add-apt-repository --yes https://adoptopenjdk.jfrog.io/adoptopenjdk/deb/; \ apt-get update; \ apt-get install -y --no-install-recommends adoptopenjdk-8-hotspot ENV JAVA_HOME /usr/lib/jvm/adoptopenjdk-8-hotspot-amd64/ ENV JRE_HOME ${JAVA_HOME}/jre ENV CLASSPATH .:${JAVA_HOME}/lib:${JRE_HOME}/lib ENV PATH ${JAVA_HOME}/bin:$PATH # copy assets ADD deploy/serving/docker/assets /tmp/assets # Copy the service of dumping environment variables COPY deploy/serving/docker/assets/configurator_dumpenv.sh /root/.system/ COPY deploy/serving/docker/assets/configurator_dumpenv.service /etc/systemd/system/ # Configure bashrc COPY deploy/serving/docker/assets/bashrc /root/.bashrc # Change mirrors in /etc/apt/sources.list to tsinghua mirrors RUN set -eux; \ { \ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ buster main contrib non-free"; \ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ buster-updates main contrib non-free"; \ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ buster-backports main contrib non-free"; \ echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security 
buster/updates main contrib non-free"; \ } > /etc/apt/sources.list; \ \ # Install required packages apt-get update; \ apt-get upgrade -y; \ DEBIAN_FRONTEND=noninteractive \ apt-get install -y --no-install-recommends \ curl lsof procps locales tzdata less vim python \ build-essential autoconf automake bzip2 file imagemagick libbz2-dev \ libc6-dev openssl libcurl4-openssl-dev libdb-dev libevent-dev \ libffi-dev libgdbm-dev libgeoip-dev libglib2.0-dev libjpeg-dev \ libkrb5-dev liblzma-dev libmagickcore-dev libmagickwand-dev \ libncurses-dev libpng-dev libpq-dev libreadline-dev libsqlite3-dev \ default-libmysqlclient-dev libssl-dev libtool libwebp-dev libxml2 \ libxml2-dev libxslt-dev libyaml-dev make patch xz-utils zlib1g-dev \ tcl tk git rsync ssh net-tools iputils-ping pbzip2 python-dev netcat \ libxslt1-dev libcap2-bin libjemalloc2 libjemalloc-dev libsnappy1v5 \ libtcmalloc-minimal4 libzookeeper-mt2 lldpd libnss3 pv gnupg2 libaio1 \ systemd systemd-sysv libsystemd0 netcat-openbsd rsyslog unscd \ apt-utils cgroup-bin cmake gdb libncurses5-dev libnss3-dev \ libprotobuf-dev linux-base linux-libc-dev linux-perf \ openssh-client openssh-server protobuf-compiler \ python-pip python-pycurl python-setuptools python-wheel python3-dev \ python3-pip python3-pycurl python3-setuptools python3-wheel sysstat \ telnet tree libunwind-dev numactl unzip \ ; \ # Remove apt cache rm -rf /var/lib/apt/lists/* # RDMA Essentials # download from https://linux.mellanox.com/public/repo/mlnx_rdma_minimal/5.0-1.0.0.0/debian10.0/ RUN set -eux; \ apt-get update && apt-get install -y --no-install-recommends \ libnl-3-200=3.4.0-1 libnl-3-dev libnl-route-3-200=3.4.0-1 libnl-route-3-dev; \ dpkg -i /tmp/assets/rdma/librdmacm1_50mlnx1-1.50100.0_amd64.deb \ /tmp/assets/rdma/rdmacm-utils_50mlnx1-1.50100.0_amd64.deb \ /tmp/assets/rdma/libibverbs1_50mlnx1-1.50100.0_amd64.deb \ /tmp/assets/rdma/ibverbs-utils_50mlnx1-1.50100.0_amd64.deb \ /tmp/assets/rdma/libibverbs-dev_50mlnx1-1.50100.0_amd64.deb \ 
/tmp/assets/rdma/ibverbs-providers_50mlnx1-1.50100.0_amd64.deb \
    /tmp/assets/rdma/libibumad3_50mlnx1-1.50100.0_amd64.deb

# Base system set-up in a single layer: service account, timezone, locales, a source build of
# Python 3.8.6, the asset build script, and a trimmed-down systemd.
# NOTE(review): the Python section below is one long `&&` list. Under `set -e` a failure in the
# middle of an `&&` list does not abort the shell, so a broken Python build would only surface in
# a later layer that calls python3.8 - confirm this is acceptable.
RUN set -eux; \
    # Create tiger account (uid 1000) unless it already exists, plus its service/log directories
    if ! id -u tiger >/dev/null 2>&1 ; then \
        groupadd -f tiger; \
        useradd -u 1000 -g tiger -d /home/tiger -m -s /bin/bash tiger; \
    fi; \
    mkdir -p /home/tiger/.service/ /opt/tiger /opt/log/tiger /var/log/tiger; \
    chown tiger:tiger /home/tiger/.service/ /opt/tiger /opt/log/tiger /var/log/tiger; \
    # Change timezone
    echo "Asia/Shanghai" > /etc/timezone; \
    ln -sfn /usr/share/zoneinfo/Asia/Shanghai /etc/localtime; \
    \
    # Generate locales (uncomment en_US.UTF-8 and zh_CN.UTF-8 in locale.gen, then build them)
    sed -i 's/# en_US.UTF-8 UTF-8/en_US.UTF-8 UTF-8/' /etc/locale.gen; \
    sed -i 's/# zh_CN.UTF-8 UTF-8/zh_CN.UTF-8 UTF-8/' /etc/locale.gen; \
    locale-gen; \
    # Configure permissions
    chmod 700 /root/.system/configurator_dumpenv.sh; \
    # Build CPython 3.8.6 from the bundled tarball and install it with `make altinstall`,
    # register /usr/local/bin/python3.8 as a python3 alternative, expose pip3.8 as pip3,
    # install pip.conf system-wide and for root, make both bashrc files source /etc/profile,
    # then run the asset build script. The final `rm -rf` also wipes /tmp/assets, so nothing
    # after this layer can read the copied assets.
    tar -xf /tmp/assets/Python-3.8.6.tar.xz && cd Python-3.8.6 && \
    ./configure --enable-optimizations && make -j8 build_all && make altinstall && \
    update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.8 3 && \
    ln -s /usr/share/pyshared/lsb_release.py /usr/local/lib/python3.8/site-packages/lsb_release.py && \
    cd .. && rm -rf Python-3.8.* && \
    cd /usr/local/bin && ln -s pip3.8 pip3 && \
    cp /tmp/assets/pip.conf /etc/ && mkdir ~/.pip && cp /tmp/assets/pip.conf ~/.pip/ && \
    printf "\n. /etc/profile\n" >> /root/.bashrc && printf "\n. /etc/profile\n" >> /home/tiger/.bashrc && \
    sh /tmp/assets/build.sh && apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* ~/.cache/pip && \
    # Enable the service of dumping environment variables
    systemctl enable configurator_dumpenv.service; \
    # Configure systemd: raise the default open-file limit and drop the listed default
    # unit links/files from the image
    sed -i '/#DefaultLimitNOFILE=/c\DefaultLimitNOFILE=1048576:1048576' /etc/systemd/system.conf; \
    rm -fr /lib/systemd/system/multi-user.target.wants/* \
        /etc/systemd/system/*.wants/* \
        /lib/systemd/system/local-fs.target.wants/* \
        /lib/systemd/system/sockets.target.wants/*udev* \
        /lib/systemd/system/sockets.target.wants/*initctl* \
        /lib/systemd/system/sysinit.target.wants/systemd-tmpfiles-setup* \
        /lib/systemd/system/systemd-update-utmp* \
        /etc/systemd/system/-.mount \
        /lib/systemd/system/user-.slice.d/; \
    # Remove cron.daily
    rm -f /etc/cron.daily/*

# Set locales to en_US.UTF-8
ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LC_CTYPE=en_US.UTF-8
ENV IS_DOCKER_ENV=true

# Set the default mount path, the path /sys/fs/cgroup is necessary to systemd;
VOLUME ["/sys/fs/cgroup","/run","/run/lock","/tmp"]

# Currently the service is managed by systemd. So systemd needs to be installed and set as the default command.
CMD ["/lib/systemd/systemd"]

# CUDA Essentials
# Adds the NVIDIA CUDA and machine-learning apt repositories, installs the pinned CUDA 11.0 /
# cuDNN 8 / NCCL 2.8 / TensorRT 7 packages, then strips static libraries and removes the
# repository definitions again to keep the image small.
# The signing keys and the repositories are fetched over HTTPS: a key downloaded over plain HTTP
# and piped straight into `apt-key add` could be replaced in transit, which would let an attacker
# sign arbitrary packages for this image.
RUN apt-get update && \
    apt-get install -yq --no-install-recommends \
        gawk \
        binutils-dev \
        && \
    curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub | apt-key add - && \
    echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
    curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
    echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
    apt-get update && \
    apt-get install -yq --no-install-recommends --fix-missing \
        cuda-cudart-11-0=11.0.221-1 \
        cuda-compat-11-0 \
        cuda-libraries-11-0=11.0.3-1 \
        libnpp-11-0=11.1.0.245-1 \
        cuda-nvtx-11-0=11.0.167-1 \
        libcublas-11-0=11.2.0.252-1 \
        cuda-nvml-dev-11-0=11.0.167-1 \
        cuda-command-line-tools-11-0=11.0.3-1 \
        cuda-nvprof-11-0=11.0.221-1 \
        libnpp-dev-11-0=11.1.0.245-1 \
        cuda-libraries-dev-11-0=11.0.3-1 \
        cuda-minimal-build-11-0=11.0.3-1 \
        libcublas-dev-11-0=11.2.0.252-1 \
        libcusparse-11-0=11.1.1.245-1 \
        libcusparse-dev-11-0=11.1.1.245-1 \
        libcudnn8=8.0.5.39-1+cuda11.0 \
        libcudnn8-dev=8.0.5.39-1+cuda11.0 \
        libnccl2=2.8.3-1+cuda11.0 \
        libnccl-dev=2.8.3-1+cuda11.0 \
        libnvinfer7=7.2.1-1+cuda11.0 \
        libnvinfer-dev=7.2.1-1+cuda11.0 \
        libnvinfer-plugin7=7.2.1-1+cuda11.0 \
        libnvinfer-plugin-dev=7.2.1-1+cuda11.0 \
        libxml-sax-expat-perl libexpat1 libexpat1-dev \
        && \
    ln -s /usr/local/cuda-11.0 /usr/local/cuda && \
    find /usr/local/cuda-11.0/lib64/ -type f -name '*.a' -not -name 'libcudart_static.a' -not -name 'libcudadevrt.a' -delete && \
    rm /etc/alternatives/libcudnn_stlib && \
    rm /usr/lib/x86_64-linux-gnu/libcudnn_static.a && \
    rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v8.a && \
    rm /usr/lib/x86_64-linux-gnu/libnccl_static.a && \
    rm /usr/lib/x86_64-linux-gnu/libnvinfer_static.a && \
    rm
/usr/lib/x86_64-linux-gnu/libnvinfer_plugin_static.a && \
    rm /usr/lib/x86_64-linux-gnu/libmyelin_compiler_static.a && \
    rm /usr/lib/x86_64-linux-gnu/libmyelin_executor_static.a && \
    rm /usr/lib/x86_64-linux-gnu/libmyelin_pattern_library_static.a && \
    rm /usr/lib/x86_64-linux-gnu/libmyelin_pattern_runtime_static.a && \
    rm /etc/apt/sources.list.d/cuda.list && \
    rm /etc/apt/sources.list.d/nvidia-ml.list

# Python runtime dependencies for the Python 3.8 interpreter (versions pinned).
RUN python3.8 -m pip install ifstat==1.0.3 absl-py==0.12.0 kazoo==2.8.0 Flask-API==2.0 \
    dataclasses-json==0.5.2 numpy==1.23.4 psutil==5.8.0 msgpack==1.0.2 \
    pyinotify==0.9.6 Jinja2==2.11.3 requests==2.25.1 PyYAML==3.13 redis==3.5.1 \
    protobuf==3.12.4 grpcio==1.26.0 sqlalchemy==1.3.24 tensorflow-gpu==2.4.0

# Horovod with NCCL GPU operations, TensorFlow support, no Gloo.
RUN HOROVOD_NCCL_LINK=SHARED HOROVOD_WITHOUT_GLOO=1 HOROVOD_WITH_TENSORFLOW=1 HOROVOD_GPU_OPERATIONS=NCCL python3.8 -m pip install --no-cache-dir horovod==0.21.3

# CUDA environment variables
ENV CUDA_HOME "/usr/local/cuda"
ENV PATH "${CUDA_HOME}/bin:${PATH}"
# ENV LD_LIBRARY_PATH "${CUDA_HOME}/compat:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}"
ENV LD_LIBRARY_PATH "${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}"

# nvidia-container-runtime
ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
ENV NVIDIA_REQUIRE_CUDA "cuda>=11.0 brand=tesla,driver>=418,driver<419 brand=tesla,driver>=440,driver<441 brand=tesla,driver>=450,driver<451"

# hadoop
RUN wget -q https://mirrors.aliyun.com/apache/hadoop/common/hadoop-3.3.2/hadoop-3.3.2.tar.gz && \
    tar -xzf hadoop-3.3.2.tar.gz && mv hadoop-3.3.2 /opt/tiger/hadoop && rm hadoop-3.3.2.tar.gz
ENV HADOOP_HOME /opt/tiger/hadoop/
ENV HDFS_JDK ${JAVA_HOME}
ENV LD_LIBRARY_PATH ${LD_LIBRARY_PATH}:${HADOOP_HOME}/lib/native:${JAVA_HOME}/jre/lib/amd64/server
# The Hadoop jars must be appended to CLASSPATH, but Docker's ENV only expands $VAR/${VAR};
# it never runs `...` or $(...) command substitution. The previous
#   ENV CLASSPATH ${CLASSPATH}:`${HADOOP_HOME}/bin/hadoop classpath --glob`
# therefore stored the backtick text literally instead of the jar list. The classpath is now
# resolved by a shell at login time through /etc/profile.d (both bashrc files in this image
# source /etc/profile; this assumes the stock Debian /etc/profile, which sources
# /etc/profile.d/*.sh). The guard variable keeps nested shells from appending it twice.
# NOTE(review): processes that never source /etc/profile still see only the base CLASSPATH,
# exactly as before; export the classpath in their unit/entrypoint if they need HDFS access.
RUN printf '%s\n' \
        'if [ -z "${MONOLITH_HADOOP_CLASSPATH_SET:-}" ]; then' \
        '  export CLASSPATH="${CLASSPATH}:$(${HADOOP_HOME}/bin/hadoop classpath --glob)"' \
        '  export MONOLITH_HADOOP_CLASSPATH_SET=1' \
        'fi' \
        > /etc/profile.d/hadoop_classpath.sh
ENV HADOOP_HDFS_HOME $HADOOP_HOME
ENV HADOOP_OPTS "-Djava.library.path=$HADOOP_HOME/lib/native"
ENV PATH $PATH:$HADOOP_HOME/sbin:$HADOOP_HOME/bin
RUN printf "\nexport
PATH=${JAVA_HOME}/bin:${CUDA_HOME}/bin:$HADOOP_HOME/sbin:$HADOOP_HOME/bin:$PATH\n" >> /root/.bashrc && \ printf "\nexport PATH=${JAVA_HOME}/bin:${CUDA_HOME}/bin:$HADOOP_HOME/sbin:$HADOOP_HOME/bin:$PATH\n" >> /home/tiger/.bashrc # ADD monolith_serving /opt/tiger/monolith_serving ================================================ FILE: deploy/serving/docker/assets/Python-3.8.6.tar.xz ================================================ [File too large to display: 17.4 MB] ================================================ FILE: deploy/serving/docker/assets/bashrc ================================================ # ~/.bashrc: executed by bash(1) for non-login shells. # see /usr/share/doc/bash/examples/startup-files (in the package bash-doc) # for examples # If not running interactively, don't do anything case $- in *i*) ;; *) return;; esac # don't put duplicate lines or lines starting with space in the history. # See bash(1) for more options HISTCONTROL=ignoreboth # append to the history file, don't overwrite it shopt -s histappend # for setting history length see HISTSIZE and HISTFILESIZE in bash(1) HISTSIZE=1000 HISTFILESIZE=2000 # check the window size after each command and, if necessary, # update the values of LINES and COLUMNS. shopt -s checkwinsize # If set, the pattern "**" used in a pathname expansion context will # match all files and zero or more directories and subdirectories. 
#shopt -s globstar # make less more friendly for non-text input files, see lesspipe(1) #[ -x /usr/bin/lesspipe ] && eval "$(SHELL=/bin/sh lesspipe)" # set variable identifying the chroot you work in (used in the prompt below) if [ -z "${debian_chroot:-}" ] && [ -r /etc/debian_chroot ]; then debian_chroot=$(cat /etc/debian_chroot) fi # set a fancy prompt (non-color, unless we know we "want" color) case "$TERM" in xterm-color) color_prompt=yes;; esac # uncomment for a colored prompt, if the terminal has the capability; turned # off by default to not distract the user: the focus in a terminal window # should be on the output of commands, not on the prompt #force_color_prompt=yes if [ -n "$force_color_prompt" ]; then if [ -x /usr/bin/tput ] && tput setaf 1 >&/dev/null; then # We have color support; assume it's compliant with Ecma-48 # (ISO/IEC-6429). (Lack of such support is extremely rare, and such # a case would tend to support setf rather than setaf.) color_prompt=yes else color_prompt= fi fi unset color_prompt force_color_prompt # If this is an xterm set the title to user@host:dir case "$TERM" in xterm*|rxvt*) PS1="\[\e]0;${debian_chroot:+($debian_chroot)}\u@\h: \W\a\]$PS1" ;; *) ;; esac # enable color support of ls and also add handy aliases if [ -x /usr/bin/dircolors ]; then test -r ~/.dircolors && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)" alias ls='ls --color=auto' #alias dir='dir --color=auto' #alias vdir='vdir --color=auto' #alias grep='grep --color=auto' #alias fgrep='fgrep --color=auto' #alias egrep='egrep --color=auto' fi # some more ls aliases #alias ll='ls -l' #alias la='ls -A' #alias l='ls -CF' # Alias definitions. # You may want to put all your additions into a separate file like # ~/.bash_aliases, instead of adding them here directly. # See /usr/share/doc/bash-doc/examples in the bash-doc package. if [ -f ~/.bash_aliases ]; then . 
~/.bash_aliases fi # enable programmable completion features (you don't need to enable # this, if it's already enabled in /etc/bash.bashrc and /etc/profile # sources /etc/bash.bashrc). if ! shopt -oq posix; then if [ -f /usr/share/bash-completion/bash_completion ]; then . /usr/share/bash-completion/bash_completion elif [ -f /etc/bash_completion ]; then . /etc/bash_completion fi fi alias l=ls alias ll='ls -l --color=auto' alias cp='cp -i' alias mv='mv -i' alias m='more' alias ll='ls -l' alias lsl='ls -lrt' alias lm='ls -al|more' alias l='ls -lrt' alias c='cat' alias v='vi' alias cl='clear' alias pg='ps -ef| grep ' export TERM=xterm ================================================ FILE: deploy/serving/docker/assets/build.sh ================================================ #!/usr/bin/env bash set -ex # still make python2 as default, but python 3.8 is installed export PYTHON_PIP_VERSION=20.1 curl -skSLf -o get-pip.py 'https://bootstrap.pypa.io/pip/2.7/get-pip.py' python get-pip.py \ --disable-pip-version-check \ --no-cache-dir \ -i http://mirrors.aliyun.com/pypi/simple \ --trusted-host mirrors.aliyun.com \ "pip==$PYTHON_PIP_VERSION" find /usr/local -depth \ \( \ \( -type d -a -name test -o -name tests \) \ -o \ \( -type f -a -name '*.pyc' -o -name '*.pyo' \) \ \) -exec rm -rf '{}' +; rm -f get-pip.py rm -rf /usr/src/python cd /root/ pip uninstall -y paramiko pycrypto pip install --no-cache-dir setuptools==44.1.0 virtualenv==15.1.0 cffi==1.12.3 paramiko==1.18.3 pip install --no-cache-dir -r /tmp/assets/requirements.txt # systemd [ -d /etc/systemd/system/user@1000.service.d ] || mkdir /etc/systemd/system/user@1000.service.d echo "[Service] Restart=always" > /etc/systemd/system/user@1000.service.d/always.conf echo "[Service] LimitNOFILE=1000000 LimitMEMLOCK=infinity" > /etc/systemd/system/user@1000.service.d/limits.conf ## for run systemd cd /lib/systemd/system/sysinit.target.wants/ && \ ls | grep -v systemd-tmpfiles-setup | xargs rm -f $1 && \ rm -f 
/lib/systemd/system/sockets.target.wants/*udev* systemctl mask -- \ apt-daily-upgrade.timer \ apt-daily.timer \ cgmanager.service \ cgproxy.service \ dev-mqueue.mount \ getty-static.service \ getty.target \ swap.target \ systemd-logind.service \ systemd-remount-fs.service \ systemd-timesyncd.service \ systemd-tmpfiles-setup-dev.service \ systemd-tmpfiles-setup.service \ systemd-update-utmp-runlevel.service; \ tmp.mount \ etc-hostname.mount \ etc-hosts.mount \ etc-resolv.conf.mount \ -.mount \ ================================================ FILE: deploy/serving/docker/assets/configurator_dumpenv.service ================================================ [Unit] Description=Dump the Docker environment variables [Service] Type=oneshot ExecStart=/root/.system/configurator_dumpenv.sh TimeoutSec=0 ================================================ FILE: deploy/serving/docker/assets/configurator_dumpenv.sh ================================================ #!/bin/bash - xargs -0 bash -c 'printf "%q\n" "$@" ; systemctl set-environment "$@"' -- \ < /proc/1/environ \ > /var/docker_environment chmod 700 /var/docker_environment ================================================ FILE: deploy/serving/docker/assets/pip.conf ================================================ [global] index-url=http://mirrors.aliyun.com/pypi/simple trusted-host=mirrors.aliyun.com timeout = 600 disable_pip_version_check = 1 [install] trusted-host=mirrors.aliyun.com ================================================ FILE: deploy/serving/docker/assets/requirements.txt ================================================ absl-py==0.7.1 ansible==2.2.1.0 ansicolors==1.0.2 APScheduler==3.5.1 asn1crypto==0.24.0 astor==0.8.0 backports.ssl-match-hostname==3.5.0.1 backports.weakref==1.0.post1 bcrypt==3.1.6 bpython==0.12 certifi==2017.7.27.1 cffi==1.12.3 chardet==2.3.0 cityhash==0.2.3.post9 click==6.7 colorama==0.3.2 coloredlogs==10.0 cryptography==2.6.1 DBUtils==1.3 decorator==3.4.0 defusedxml==0.4.1 Django==1.11.3 
docutils==0.12 ecdsa==0.11 enum34==1.1.6 executor==21.3 fasteners==0.14.1 filelock==2.0.11 Flask==0.12.2 funcsigs==1.0.2 futures==3.2.0 gast==0.2.2 Geohash==1.0 gevent==1.1.1 google-pasta==0.1.7 greenlet==0.4.10 grpcio==1.22.0 gunicorn==19.9.0 h5py==2.9.0 hash-ring==1.3.1 html5lib==0.999 httplib2==0.9 humanfriendly==4.18 idna==2.0 ipaddress==1.0.22 ipython==2.3.0 itsdangerous==0.24 jieba==0.42.1 kafka-python==1.3.5 lockfile==0.8 lxml==3.4.0 MarkupSafe==1.0 monotonic==1.5 msgpack-python==0.4.8 msgpack==0.6.1 mysqlclient==1.4.6 ndg-httpsclient==0.4.2 numpy==1.12.1 pandas==0.24.2 Pillow==2.6.1 proc==0.17 property-manager==2.3.1 protobuf==3.11.3 pssh==2.3.1 psutil==5.6.1 pyasn1==0.1.9 pycparser==2.19 Pygments==2.0.1 PyNaCl==1.3.0 pyOpenSSL==19.0.0 PySocks==1.7.0 python-consul==0.4.0 python-daemon==1.5.5 python-dateutil==2.8.0 python-decouple==3.0 python-memcached==1.58 python-utils==2.3.0 pytz==2019.1 PyYAML==3.12 redis==2.10.5 requests==2.11.1 roman==2.0.0 ruamel.ordereddict==0.4.13 ruamel.yaml==0.15.85 scipy==0.14.0 setuptools==44.0.0 simplejson==3.11.1 six==1.12.0 SOAPpy==0.12.22 sqlalchemy-migrate==0.12.0 sqlalchemy==1.2.12 sqlparse==0.3.0 tabulate==0.7.7 Tempita==0.5.2 termcolor==1.1.0 thrift==0.13.0 tornado==4.1 transitions==0.7.1 ujson==1.35 urllib3==1.16 verboselogs==1.7 Werkzeug==0.12.2 wstools==0.4.3 xcmd==0.0.3 xmltodict==0.9.0 zk-shell==1.1.3 zstandard==0.13.0 ================================================ FILE: deploy/serving/docker/run ================================================ #! 
/bin/bash docker build --force-rm -t bytedance.monolith_pro.release:v1.0.0 ./ ================================================ FILE: deploy/serving/open_source_serving.sh ================================================ #!/bin/bash set -eux export SHARD_ID=`expr $MLP_SHARD_ID - 1` echo "The SHARD_ID {$SHARD_ID}" export MY_POD_NAME=$MLP_POD_NAME echo "The MY_POD_NAME {$MY_POD_NAME}" export byterec_host_shard_n=$MLP_SHARD_NUM echo "The byterec_host_shard_n {$byterec_host_shard_n}" if [ $MLP_IDC ]; then export TCE_INTERNAL_IDC=$MLP_IDC else export TCE_INTERNAL_IDC="cn-beijing-b" fi echo "The TCE_INTERNAL_IDC {$TCE_INTERNAL_IDC}" #export TCE_CLUSTER=$ export TCE_CLUSTER=default echo "The TCE_CLUSTER {$TCE_CLUSTER}" # TCE_PSM for metrics PSM_PREFIX="data.tob.monolith_serving_" SHELL_FOLDER=/opt/tiger/monolith_serving export PATH=$SHELL_FOLDER:$PATH if [ $MLP_ROLE_NAME = 'PS' ]; then export SERVER_TYPE='ps' export ENABLE_BATCHING=false export TCE_PSM=$PSM_PREFIX"ps-"$MLP_SERVICE_NAME elif [ $MLP_ROLE_NAME == 'Entry' ]; then export SERVER_TYPE='entry' export ENABLE_BATCHING=false export TCE_PSM=$PSM_PREFIX"en-"$MLP_SERVICE_NAME elif [ $MLP_ROLE_NAME == 'DenseNN' ]; then export SERVER_TYPE='dense' export CUDA_MPS_PIPE_DIRECTORY=/dev/shm export ENABLE_BATCHING=true export TCE_PSM=$PSM_PREFIX"de-"$MLP_SERVICE_NAME nvidia-cuda-mps-control -d fi echo "THE SERVER_TYPE {$SERVER_TYPE}" echo "THE ENABLE_BATCHING {$ENABLE_BATCHING}" echo "The TCE_PSM {$TCE_PSM}" cd $SHELL_FOLDER echo "The shell folder is {$SHELL_FOLDER}" PYV=$(python -c "import sys; print('{}.{}'.format(sys.version_info.major, sys.version_info.minor))") echo "Using sparse_dense_serving: {$DENSE_ALONE}" cat agent.conf | sed -e "s/{{bzid}}/${BZID}/g" -e "s/{{base_name}}/${BASE_NAME}/g" -e "s?{{base_path}}?${BASE_PATH}?g" -e "s/{{num_ps}}/${NUM_PS}/g" -e "s/{{server_type}}/${SERVER_TYPE}/g" \ -e "s/{{zk_servers}}/${ZK_SERVERS}/g" -e "s/{{dense_alone}}/${DENSE_ALONE}/g" -e "s/{{enable_batching}}/${ENABLE_BATCHING}/g" > 
render_agent.conf # add other conf parameter here #echo -e "\ndense_service_num 3" >> render_agent.conf cd $SHELL_FOLDER/bin if [ $PYV = '3.8' ]; then python run --bin_name="agent" --conf /opt/tiger/monolith_serving/render_agent.conf else python3 run --bin_name="agent" --conf /opt/tiger/monolith_serving/render_agent.conf fi ================================================ FILE: deploy/serving/scripts/build_serving.sh ================================================ #!/bin/bash set -eux script_dir=`dirname $0` abs_script_dir=`realpath $script_dir` use_gpu="${1:-false}" rm -rf output mkdir -p output bazel --version if [ "$use_gpu" = "true" ]; then bazel build \ --output_filter=DONT_MATCH_ANYTHING \ --define=framework_shared_object=false \ --config=cuda \ @org_tensorflow_serving//tensorflow_serving/model_servers:tensorflow_model_server else bazel build \ --output_filter=DONT_MATCH_ANYTHING \ --define=framework_shared_object=false \ @org_tensorflow_serving//tensorflow_serving/model_servers:tensorflow_model_server fi # We can't compile archon in TensorFlow Py bazel build \ --output_filter=DONT_MATCH_ANYTHING \ --define=framework_shared_object=false \ //monolith/agent_service:agent bazel build \ --output_filter=DONT_MATCH_ANYTHING \ --define=framework_shared_object=false \ //monolith/agent_service:tfs_client # 1) prepare output mkdir -p output/bin/ mkdir -p output/lib/ function clear_external() { runfiles_dir="$1" echo "runfiles_dir: $runfiles_dir" pushd $runfiles_dir/__main__/external for external_name in `ls .`; do if [ -d "../../$external_name" ]; then rm -rf $external_name && ln -s ../../$external_name . 
fi done popd } cp -frL bazel-bin/monolith/agent_service/agent.runfiles/ output/ clear_external output/agent.runfiles cp -frL bazel-bin/monolith/agent_service/tfs_client.runfiles/ output/ clear_external output/tfs_client.runfiles cp -frL bazel-bin/external/org_tensorflow_serving/tensorflow_serving/model_servers/tensorflow_model_server.runfiles/ output/ rm output/tensorflow_model_server.runfiles/org_tensorflow_serving -rf cd output/bin ln -s ../agent.runfiles/__main__/monolith/agent_service/agent . ln -s ../tfs_client.runfiles/__main__/monolith/agent_service/tfs_client . ln -s ../tensorflow_model_server.runfiles/__main__/external/org_tensorflow_serving/tensorflow_serving/model_servers/tensorflow_model_server . cd - cp -rL $abs_script_dir/run_server output/bin/ cp -rL $abs_script_dir/conf output ================================================ FILE: deploy/serving/scripts/run_server ================================================ #!/bin/bash # need env: TCE_PSM, TCE_INTERNAL_IDC, TCE_CLUSTER, SHARD_ID # need arg: --conf agent.conf python3 agent $@ ================================================ FILE: idl/BUILD ================================================ load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") load("@rules_proto//proto:defs.bzl", "proto_library") proto_library( name = "proto_parser_proto", srcs = [ "matrix/proto/feature.proto", "matrix/proto/line_id.proto", "matrix/proto/proto_parser.proto", ], visibility = ["//visibility:public"], ) cc_proto_library( name = "line_id_cc_proto", srcs = [ "matrix/proto/line_id.proto", ], visibility = ["//visibility:public"], ) py_proto_library( name = "line_id_py_proto", srcs = [ "matrix/proto/line_id.proto", ], visibility = ["//visibility:public"], ) cc_proto_library( name = "proto_parser_cc_proto", srcs = [ "matrix/proto/feature.proto", "matrix/proto/proto_parser.proto", ], visibility = 
["//visibility:public"], deps = [ ":line_id_cc_proto", ], ) py_proto_library( name = "proto_parser_py_proto", srcs = [ "matrix/proto/feature.proto", "matrix/proto/proto_parser.proto", ], visibility = ["//visibility:public"], deps = [ ":line_id_py_proto", ], ) cc_library( name = "compression_cc_float16", hdrs = ["matrix/compression/float16.h"], visibility = ["//visibility:public"], deps = ["//third_party/half_sourceforge_net:half"], ) cc_library( name = "compression", srcs = ["matrix/compression/compression.cc"], hdrs = [ "matrix/compression/compression.h", "matrix/compression/compression_qtz8mm.h", ], visibility = ["//visibility:public"], deps = [ ":compression_cc_float16", "@com_google_glog//:glog", ], ) cc_library( name = "compression_qtz8mm", srcs = ["matrix/compression/compression_qtz8mm.cc"], hdrs = ["matrix/compression/compression_qtz8mm.h"], visibility = ["//visibility:public"], deps = [ ":compression", ":compression_cc_float16", "@com_google_glog//:glog", ], ) cc_proto_library( name = "example_cc_proto", srcs = [ "matrix/proto/example.proto", ], visibility = ["//visibility:public"], deps = [ ":line_id_cc_proto", ], ) py_proto_library( name = "example_py_proto", srcs = [ "matrix/proto/example.proto", ], visibility = ["//visibility:public"], deps = [ ":line_id_py_proto", ], ) ================================================ FILE: idl/matrix/compression/compression.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include "glog/logging.h" #include "idl/matrix/compression/compression.h" #include "idl/matrix/compression/float16.h" namespace matrix { namespace compression { using matrix::compression::Float16; bool compress_float_list_f16(const char* raw_data, const size_t raw_size, char* out_buffer, size_t* out_size) { if ((raw_size % sizeof(float)) != 0) { LOG(ERROR) << "compress_float_list_f16 got invalid input data"; return false; } size_t num = raw_size / sizeof(float); if (sizeof(Float16) * num > *out_size) { LOG(ERROR) << "compress_float_list_f16 out_buffer size not enough"; return false; } const float* raw_floats = reinterpret_cast(raw_data); Float16* f16_buffer = reinterpret_cast(out_buffer); *out_size = 0; for (size_t i = 0; i < num; ++i) { f16_buffer[i].set(raw_floats[i]); (*out_size) += sizeof(Float16); } return true; } bool compress_float_list_f16(const char* raw_data, const size_t raw_size, std::string* out) { if ((raw_size % sizeof(float)) != 0) { LOG(ERROR) << "compress_float_list_f16 got invalid input data"; return false; } size_t num = raw_size / sizeof(float); size_t out_size = num * sizeof(Float16); out->resize(out_size); const float* raw_floats = reinterpret_cast(raw_data); Float16* f16_buffer = reinterpret_cast(const_cast(out->data())); for (size_t i = 0; i < num; ++i) { f16_buffer[i].set(raw_floats[i]); } return true; } bool decompress_float_list_f16(const char* compressed_data, size_t compressed_size, char* out_buffer, size_t* out_size) { if ((compressed_size % sizeof(Float16)) != 0) { LOG(ERROR) << "decompress_float_list_f16 got invalid data"; return false; } size_t num = compressed_size / sizeof(Float16); if (sizeof(float) * num > *out_size) { LOG(ERROR) << "decompress_float_list_f16 got no enough out_buffer"; return false; } const Float16* f16_buffer = reinterpret_cast(compressed_data); float* out_floats = reinterpret_cast(out_buffer); 
*out_size = 0; for (size_t i = 0; i < num; ++i) { out_floats[i] = f16_buffer[i].get_m(); (*out_size) += sizeof(float); } return true; } bool decompress_float_list_f16(const char* compressed_data, size_t compressed_size, std::string* out) { if ((compressed_size % sizeof(Float16)) != 0) { LOG(ERROR) << "decompress_float_list_f16 got invalid data"; return false; } size_t num = compressed_size / sizeof(Float16); size_t out_size = num * sizeof(float); out->resize(out_size); const Float16* f16_buffer = reinterpret_cast(compressed_data); float* out_floats = reinterpret_cast(const_cast(out->data())); for (size_t i = 0; i < num; ++i) { out_floats[i] = f16_buffer[i].get_m(); } return true; } using bfloat16 = uint16_t; bool compress_float_list_f16b(const char* raw_data, const size_t raw_size, std::string* out) { if ((raw_size % sizeof(float)) != 0) { LOG(ERROR) << "compress_float_list_f16 got invalid input data"; return false; } size_t num = raw_size / sizeof(float); size_t out_size = num * sizeof(bfloat16); out->resize(out_size); const uint16_t* p = reinterpret_cast(raw_data); uint16_t* q = reinterpret_cast(const_cast(out->data())); for (; num != 0; p += 2, q++, num--) { *q = p[1]; } return true; } bool decompress_float_list_f16b(const char* compressed_data, size_t compressed_size, std::string* out) { if ((compressed_size % sizeof(bfloat16)) != 0) { LOG(ERROR) << "decompress_float_list_f16 got invalid data"; return false; } size_t num = compressed_size / sizeof(bfloat16); size_t out_size = num * sizeof(float); out->resize(out_size); const uint16_t* p = reinterpret_cast(compressed_data); uint16_t* q = reinterpret_cast(const_cast(out->data())); for (; num != 0; p++, q += 2, num--) { q[0] = 0; q[1] = *p; } return true; } } // end namespace compression } // end namespace matrix ================================================ FILE: idl/matrix/compression/compression.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef IDL_MATRIX_COMPRESSION_COMPRESSION_H_ #define IDL_MATRIX_COMPRESSION_COMPRESSION_H_ #include namespace matrix { namespace compression { bool compress_float_list_f16(const char* raw_data, const size_t raw_size, char* out_buffer, size_t* out_size); bool compress_float_list_f16(const char* raw_data, const size_t raw_size, std::string* out); bool decompress_float_list_f16(const char* compressed_data, size_t compressed_size, char* out_buffer, size_t* out_size); bool decompress_float_list_f16(const char* compressed_data, size_t compressed_size, std::string* out); bool compress_float_list_f16b(const char* raw_data, const size_t raw_size, std::string* out); bool decompress_float_list_f16b(const char* compressed_data, size_t compressed_size, std::string* out); // qtz8mm, with min/max qtz8, by liuyizhou // 注意raw_size不是float的数量,而是buffer长度,所以是float数量*4 // 前两个函数主要是给 wudi.yx 使用,对于一个vec的数据做操作 bool compress_float_list_qtz8mm(const char* raw_data, const size_t raw_size, std::string* out); bool decompress_float_list_qtz8mm(const char* compressed_data, size_t compressed_size, std::string* out); } // end namespace compression } // end namespace matrix #include "idl/matrix/compression/compression_qtz8mm.h" #endif // IDL_MATRIX_COMPRESSION_COMPRESSION_H_ ================================================ FILE: idl/matrix/compression/compression_qtz8mm.cc ================================================ // Copyright 2022 ByteDance and/or 
its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "glog/logging.h" #include "idl/matrix/compression/compression.h" #include "idl/matrix/compression/compression_qtz8mm.h" using matrix::compression::Float16; namespace matrix { namespace compression { bool compress_float_list_qtz8mm(const char* raw_data, const size_t raw_size, std::string* out) { if ((raw_size % sizeof(float)) != 0) { LOG(ERROR) << "compress_float_list_f16 got invalid input data"; return false; } if (raw_size <= 4 * sizeof(float)) { // 如果长度 <= // 4时,由于min/max需要各占用2B,会导致qtz8mm相对于f16不节省内存,此时直接用f16 return compress_float_list_f16(raw_data, raw_size, out); } size_t num = raw_size / sizeof(float); size_t out_size = 2 * sizeof(Float16) + num * sizeof(uint8_t); out->resize(out_size); char* qtz_buf_ptr = const_cast(out->data()); const float* raw_floats = reinterpret_cast(raw_data); Float16* w_ptr = reinterpret_cast(qtz_buf_ptr); uint8_t* v_ptr = reinterpret_cast(qtz_buf_ptr + 2 * sizeof(Float16)); set_to_qtz8mm(raw_floats, num, num, w_ptr, v_ptr); return true; } bool decompress_float_list_qtz8mm(const char* compressed_data, size_t compressed_size, std::string* out) { if (compressed_size <= 4 * sizeof(Float16)) { // 长度<=8时,原始的数据长度<=4,直接用f16反解 return decompress_float_list_f16(compressed_data, compressed_size, out); } size_t num = compressed_size - 2 * sizeof(Float16); size_t out_size = num * sizeof(float); out->resize(out_size); const char* qtz_buf_ptr = static_cast(compressed_data); const 
Float16* w_ptr = reinterpret_cast(qtz_buf_ptr); const uint8_t* v_ptr = reinterpret_cast(qtz_buf_ptr + 2 * sizeof(Float16)); float* out_floats = reinterpret_cast(const_cast(out->data())); get_from_qtz8mm(out_floats, num, w_ptr, v_ptr); return true; } } // end namespace compression } // end namespace matrix ================================================ FILE: idl/matrix/compression/compression_qtz8mm.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef IDL_MATRIX_COMPRESSION_COMPRESSION_QTZ8MM_H_ #define IDL_MATRIX_COMPRESSION_COMPRESSION_QTZ8MM_H_ #include #include #include "idl/matrix/compression/float16.h" namespace matrix { namespace compression { inline void get_from_qtz8mm(float* dest, const int& dim, const matrix::compression::Float16* w, const uint8_t* v) { float min = w[0].get(); float max = w[1].get(); float step = (max - min) / 255; for (int i = 0; i < dim; ++i) { dest[i] = min + step * v[i]; } } inline void set_to_qtz8mm(const float* src_list, const int& data_size, // data_size 可能 #include #include #include #include #include "third_party/half_sourceforge_net/half.hpp" namespace matrix { namespace compression { class Float16 { public: Float16() {} Float16(const Float16& other) : value(other.value) {} Float16(float vf) { set(vf); } void set(float vf) { value = vf; } float get() const { float val = value; return std::isinf(val) ? ((val < 0) ? 
-65504 : 65504) : val; } /* * get value with random rounding value * * explain: * we want to store 1.23456 in a 16 bit unit, but because of truncation * what we really store is 1.234 * * get_r() will return (1.234 + random(0, 1) * 0.001) to mitigate the * truncation error */ float get_r() const { return get() + random_rounding_value(); } /* * get value with median rounding value * * explain: * we want to store 1.23456 in a 16 bit unit, but because of truncation * what we really store is 1.234 * * get_m() will return (1.234 + 0.0005) to mitigate the truncation error */ float get_m() const { if ((value.get_data() & 0x7FFF) == 0) return 0; else return get() + median_rounding_value(); } unsigned short get_raw_data() const { return value.get_data(); } private: half_float::half value; /* * random make use of Marsaglia's xorshf generator to generator float * number in [0, 1] * * About Marsaglia's xorshf generator, see * [stackoverflow](http://stackoverflow.com/a/1640399)] */ static float random(void) { static unsigned long x = 123456789, y = 362436069, z = 521288629; static unsigned long cnt = 0; if (!(cnt = (cnt + 1) & 0xFFFFFFFF)) { x = 123456789; y = 362436069; z = 521288629; } unsigned long t; x ^= x << 16; x ^= x >> 5; x ^= x << 1; t = x; x = y; y = z; z = t ^ x ^ y; return (double)(z & 0xFFFFFFFF) / (unsigned long)(0xFFFFFFFF); } float random_rounding_value() const { static constexpr float v[64] = { std::pow(2, -25), std::pow(2, -24), std::pow(2, -23), std::pow(2, -22), std::pow(2, -21), std::pow(2, -20), std::pow(2, -19), std::pow(2, -18), std::pow(2, -17), std::pow(2, -16), std::pow(2, -15), std::pow(2, -14), std::pow(2, -13), std::pow(2, -12), std::pow(2, -11), std::pow(2, -10), std::pow(2, -9), std::pow(2, -8), std::pow(2, -7), std::pow(2, -6), std::pow(2, -5), std::pow(2, -4), std::pow(2, -3), std::pow(2, -2), std::pow(2, -1), std::pow(2, 0), std::pow(2, 1), std::pow(2, 2), std::pow(2, 3), std::pow(2, 4), std::pow(2, 5), std::pow(2, 6), -std::pow(2, -25), 
-std::pow(2, -24), -std::pow(2, -23), -std::pow(2, -22), -std::pow(2, -21), -std::pow(2, -20), -std::pow(2, -19), -std::pow(2, -18), -std::pow(2, -17), -std::pow(2, -16), -std::pow(2, -15), -std::pow(2, -14), -std::pow(2, -13), -std::pow(2, -12), -std::pow(2, -11), -std::pow(2, -10), -std::pow(2, -9), -std::pow(2, -8), -std::pow(2, -7), -std::pow(2, -6), -std::pow(2, -5), -std::pow(2, -4), -std::pow(2, -3), -std::pow(2, -2), -std::pow(2, -1), -std::pow(2, 0), -std::pow(2, 1), -std::pow(2, 2), -std::pow(2, 3), -std::pow(2, 4), -std::pow(2, 5), -std::pow(2, 6), }; return v[value.get_data() >> 10] * random(); } float median_rounding_value() const { static constexpr float v[64] = { std::pow(2, -26), std::pow(2, -25), std::pow(2, -24), std::pow(2, -23), std::pow(2, -22), std::pow(2, -21), std::pow(2, -20), std::pow(2, -19), std::pow(2, -18), std::pow(2, -17), std::pow(2, -16), std::pow(2, -15), std::pow(2, -14), std::pow(2, -13), std::pow(2, -12), std::pow(2, -11), std::pow(2, -10), std::pow(2, -9), std::pow(2, -8), std::pow(2, -7), std::pow(2, -6), std::pow(2, -5), std::pow(2, -4), std::pow(2, -3), std::pow(2, -2), std::pow(2, -1), std::pow(2, 0), std::pow(2, 1), std::pow(2, 2), std::pow(2, 3), std::pow(2, 4), std::pow(2, 5), -std::pow(2, -26), -std::pow(2, -25), -std::pow(2, -24), -std::pow(2, -23), -std::pow(2, -22), -std::pow(2, -21), -std::pow(2, -20), -std::pow(2, -19), -std::pow(2, -18), -std::pow(2, -17), -std::pow(2, -16), -std::pow(2, -15), -std::pow(2, -14), -std::pow(2, -13), -std::pow(2, -12), -std::pow(2, -11), -std::pow(2, -10), -std::pow(2, -9), -std::pow(2, -8), -std::pow(2, -7), -std::pow(2, -6), -std::pow(2, -5), -std::pow(2, -4), -std::pow(2, -3), -std::pow(2, -2), -std::pow(2, -1), -std::pow(2, 0), -std::pow(2, 1), -std::pow(2, 2), -std::pow(2, 3), -std::pow(2, 4), -std::pow(2, 5), }; return v[value.get_data() >> 10]; } }; } // end namespace compression } // end namespace matrix #endif /* FLOAT_16_H */ 
================================================ FILE: idl/matrix/proto/example.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto3"; package monolith.io.proto; option cc_enable_arenas = true; import "idl/matrix/proto/line_id.proto"; message FidList { repeated fixed64 value = 1; } message FidLists { repeated FidList list = 1; } message FloatList { repeated float value = 1; } message FloatLists { repeated FloatList list = 1; } message DoubleList { repeated double value = 1; } message DoubleLists { repeated DoubleList list = 1; } message Int64List { repeated int64 value = 1; } message Int64Lists { repeated Int64List list = 1; } message BytesList { repeated bytes value = 1; } message BytesLists { repeated BytesList list = 1; } // Basic extracted features message Feature { oneof type { FidList fid_v1_list = 1; FidList fid_v2_list = 2; FloatList float_list = 3; DoubleList double_list = 4; Int64List int64_list = 5; BytesList bytes_list = 6; FidLists fid_v2_lists = 7; FloatLists float_lists = 8; DoubleLists double_lists = 9; Int64Lists int64_lists = 10; BytesLists bytes_lists = 11; int64 int64_value = 12; float float_value = 13; double double_value = 14; bytes bytes_value = 15; FidLists fid_v1_lists = 16; } } // Feature map for easy retrieval message FeatureMap { map feature_map = 1; } // Raw features, or intermediate results during extraction message 
RawFeature { repeated Feature feature = 1; } // ---------ColumnMajor definitions---------- enum FeatureListType { INDIVIDUAL = 0; // each example has its own value SHARED = 1; // all examples share the same value } message NamedFeatureList { int32 id = 4; string name = 1; repeated Feature feature = 2; FeatureListType type = 3; } message NamedRawFeatureList { int32 id = 4; string name = 1; repeated RawFeature raw_feature = 2; FeatureListType type = 3; } // column major examples message ExampleBatch { repeated NamedFeatureList named_feature_list = 1; repeated NamedRawFeatureList named_raw_feature_list = 2; int32 batch_size = 3; uint32 data_source_key = 100; } // ---------RowMajor definitions---------- message NamedFeature { int32 id = 3; string name = 1; Feature feature = 2; int32 sorted_id = 6; } message NamedRawFeature { int32 id = 3; string name = 1; RawFeature raw_feature = 2; } // Example for both online and offline message Example { repeated NamedFeature named_feature = 1; repeated NamedRawFeature named_raw_feature = 2; idl.matrix.proto.LineId line_id = 100; repeated float label = 101; float instance_weight = 102; uint32 data_source_key = 103; } message ExampleBatchRowMajor { repeated NamedFeature shared_feature = 1; repeated NamedRawFeature shared_raw_feature = 2; repeated Example example = 3; } message FeatureData { int64 gid = 1; repeated int64 fids = 2; repeated NamedFeature feature_columns = 3; int64 origin_cnt = 4; int64 sample_cnt = 5; } message ChannelCache { int64 channel_id = 1; repeated FeatureData feature_datas = 2; } message FilterValues { oneof type { FloatList float_list = 1; Int64List int64_list = 2; BytesList bytes_list = 3; } } enum PoolingType { SUM = 0; MEAN = 1; FIRSTN = 3; } enum OutType { CONCAT = 0; STACK = 1; ADDN = 2; NONE = 3; } message SliceConfig { string feature_name = 1; int32 start = 2; int32 end = 3; int32 feature_idx = 4; int32 slice_idx = 5; PoolingType pooling_type = 10; int32 max_sequence_length = 11; } message TensorShape { 
repeated int32 dims = 1;
}

// Output configuration: a list of SliceConfig entries, an OutType and a list
// of TensorShape messages.
message OutConfig {
  repeated SliceConfig slice_configs = 1;
  OutType out_type = 2;
  repeated TensorShape shape = 3;
}

// Per-feature configuration: table name, pooling type, slice dimensions and
// maximum sequence length.
message FeatureConfig {
  string table = 1;
  PoolingType pooling_type = 2;
  repeated int32 slice_dims = 3;
  int32 max_sequence_length = 4;
}

// Collection of FeatureConfig and OutConfig entries.
// NOTE(review): this copy of the file had lost the map type arguments
// (`map feature_configs = 1;` is not valid proto). The value types follow from
// the messages above; the `string` key type is an assumption -- confirm
// against the upstream definition.
message FeatureConfigs {
  map<string, FeatureConfig> feature_configs = 1;
  map<string, OutConfig> out_configs = 2;
}

================================================
FILE: idl/matrix/proto/feature.proto
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";

package idl.matrix.proto;
option java_outer_classname = "FeatureProto";

// Sequence features; see the link below for documentation.

// Discrete (id) sequence feature.
message Fixed64List {
  repeated fixed64 value = 1 [packed = true];
}

// Floating-point continuous-value sequence feature.
message FloatList {
  repeated float value = 1 [packed = true];
}

// Integer continuous-value sequence feature.
message Int64List {
  repeated int64 value = 1 [packed = true];
}

// Raw-value sequence feature.
message BytesList {
  repeated bytes value = 1;
}

message Feature {
  // feature column name
  // A name is mandatory; a feature without one cannot be used.
  // Names are unique and start with "fc_".
  optional string name = 1;

  // Only one of the fields below is used. Both discrete and continuous
  // features are ordered.
  // To attach weights to fids, put them in two separate feature columns whose
  // elements correspond by position.
  // oneof {
  // Discrete, id-encoded feature.
  repeated fixed64 fid = 2 [packed = true];
  // Continuous-value features.
  repeated float float_value = 3 [packed = true];
  repeated int64 int64_value = 4 [packed = true];
  // Raw feature.
  repeated bytes bytes_value = 5;
  // The fields below are sequence features: the discrete or continuous
  // features that correspond to one sequence.
  repeated Fixed64List fid_list = 6;
  repeated FloatList float_list = 7;
  repeated Int64List int64_list = 8;
  repeated BytesList bytes_list = 9;
  // } // oneof
}

================================================
FILE: idl/matrix/proto/line_id.proto
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";

package idl.matrix.proto;
option java_outer_classname = "LineIdProto";

// A single (string key, float value) pair.
message MapStringFloatEntry {
  optional string key = 1;
  optional float value = 2;
}

// Record of optional/repeated metadata fields. The field numbers are sparse;
// they define the wire format, so existing numbers must not be reused.
// NOTE(review): the meaning of the individual fields is not documented in
// this file -- confirm with the producers of this message before relying on
// any of them.
message LineId {
  optional fixed64 uid = 2;
  optional int64 req_time = 3;
  optional fixed64 item_id = 4;
  repeated int32 actions = 6 [packed = true];
  optional int64 chnid = 19;
  repeated int32 pre_actions = 23 [packed = true];
  // Reads as 1.0 when the field is not set.
  optional float sample_rate = 27 [default = 1.0];
  repeated int32 special_strategies = 39;
  optional string device_type = 41;
  optional string cid = 48;
  optional string user_id = 49;
  optional bool is_draw = 87;
  optional int32 rank = 145;
  optional string data_source_name = 235;
}

================================================
FILE: idl/matrix/proto/proto_parser.proto
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2"; package parser.proto; import "idl/matrix/proto/feature.proto"; import "idl/matrix/proto/line_id.proto"; message Instance { repeated fixed64 fid = 1 [packed=true]; repeated float value = 2 [packed=true]; repeated float label = 3 [packed=true]; optional float instance_weight = 4; optional idl.matrix.proto.LineId line_id = 5; // deprecated, move to feature columns repeated float dense = 6 [packed=true, deprecated=true]; repeated LabelTag label_tag = 7; repeated fixed64 next_fid = 8 [packed=true]; // feature columns repeated idl.matrix.proto.Feature feature = 9; optional uint32 data_source_key = 100; } message InstanceWrapper { // Serialized `Instance` message optional bytes instance = 1; // Which data source this Instance comes from. optional string data_source = 2; } message LabelTag { optional int32 key = 1; optional float val = 2; } message Request { optional string req_id = 1; optional int32 ut = 2; optional fixed64 uid = 3; optional int64 req_time = 4; repeated Instance instances = 5; optional string user = 6; } ================================================ FILE: markdown/demo/AWS-EKS.md ================================================ # Distributed async training on EKS To scale to multiple machines and handle failure recovery, we can utilize container orchestration frameworks such as yarn and kubernetes. Regardless of what tool you use, as long as the `TF_CONFIG` environment variable is correctly set for each worker and ps, it will work just fine. In this tutorial, we will show how to set up distributed training using kubernetes, kubeflow, and AWS's elastic kubernetes service (EKS). Kubeflow is used as the middleware that injects the `TF_CONFIG` environment variable for each worker container. ## Prerequisite Set up kubeflow on AWS by following the official guide. It will also help you to set up other tools such as aws cli and eksctl.
Make sure to complete - Prerequisites - Create an EKS Cluster - Vanilla Installation https://awslabs.github.io/kubeflow-manifests/docs/deployment/ ## Prepare monolith docker TODO ## Write Spec and launch training If you have completed all the prerequisites, `kubectl` should be able to connect to your cluster on AWS. Now, create a spec file called `aws-tfjob.yaml`. ```yaml apiVersion: "kubeflow.org/v1" kind: "TFJob" metadata: name: "monolith-train" namespace: kubeflow spec: runPolicy: cleanPodPolicy: None tfReplicaSpecs: Worker: replicas: 4 restartPolicy: Never template: metadata: annotations: # solve RBAC permission problem sidecar.istio.io/inject: "false" spec: containers: - name: tensorflow image: YOUR_IMAGE args: - --model_dir=/tmp/model PS: replicas: 4 restartPolicy: Never template: metadata: annotations: sidecar.istio.io/inject: "false" spec: containers: - name: tensorflow image: YOUR_IMAGE args: - --model_dir=/tmp/model ``` Then, launch training: ```bash kubectl apply -f aws-tfjob.yaml ``` To view the status of workers, you can use ```bash # use this to list pods kubectl --namespace kubeflow get pods # use this get a log of a worker kubectl --namespace kubeflow logs monolith-train-worker-0 ``` Of course, there are other middlewares built on top of kubeflow to better help you to keep track of the training progress. Monolith's compatibility with tensorflow means that tools that are built for tensorflow will likely work with Monolith too. 
================================================ FILE: markdown/demo/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library") package(default_visibility = ["//visibility:public"]) py_binary( name = "ml_dataset", srcs = ["ml_dataset.py"], deps = [ "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_binary( name = "kafka_producer", srcs = ["kafka_producer.py"], deps = [ "@org_tensorflow//tensorflow:tensorflow_py", ":ml_dataset", ], ) py_binary( name = "kafka_receiver", srcs = ["kafka_receiver.py"], deps = [ "//monolith/native_training:native_model", ], ) py_binary( name = "demo_model", srcs = ["demo_model.py"], deps = [ ":kafka_producer", ":kafka_receiver", "//monolith/native_training:native_model", ], ) py_binary( name = "demo_local_runner", srcs = ["demo_local_runner.py"], deps = [ ":demo_model", ], ) ================================================ FILE: markdown/demo/Batch.md ================================================ # Movie Ranking Batch Training This tutorial demonstrates how to use Monolith to perform a movie ranking task. This tutorial is essentially the same as [Tensorflow's tutorial on movie ranking](https://www.tensorflow.org/recommenders/examples/basic_ranking), but with Monolith's API. Through this tutorial, you'll learn the similarities and differences between Monolith and native Tensorflow. Additionally, we'll showcase how batch training and stream training are done with Monolith. ## Building the Model Source code: [demo_model.py](./demo_model.py) ### Monolith Model API ```python class MovieRankingModel(MonolithModel): def __init__(self, params): super().__init__(params) self.p = params self.p.serving.export_when_saving = True def input_fn(self, mode): return dataset def model_fn(self, features, mode): # features = return EstimatorSpec(...)
def serving_input_receiver_fn(self): return tf.estimator.export.ServingInputReceiver({...}) ``` A monolith model follows the above template. `input_fn` returns an instance of tf.data.Dataset. `model_fn` builds the graph for the forward pass and returns an EstimatorSpec. The `features` argument is an item from the dataset returned by the `input_fn`. Finally, if you want to serve the model, you need to implement the `serving_input_receiver_fn`. ### Prepare the dataset We can use tfds to load the dataset. Then, we select the features that we're going to use from the dataset, and do some preprocessing. In our case, we need to convert user ids and movie titles from strings to unique integer ids. ```python def get_preprocessed_dataset(size='100k') -> tf.data.Dataset: ratings = tfds.load(f"movielens/{size}-ratings", split="train") # For simplicity, we map each movie_title and user_id to numbers # by hashing. You can use other ways to number them to avoid # collision and better leverage Monolith's collision-free hash tables. max_b = (1 << 63) - 1 return ratings.map(lambda x: { 'mov': tf.strings.to_hash_bucket_fast([x['movie_title']], max_b), 'uid': tf.strings.to_hash_bucket_fast([x['user_id']], max_b), 'label': tf.expand_dims(x['user_rating'], axis=0) }) ``` ### Write input_fn for batch training To enable distributed training, our `input_fn` first shards the dataset according to the total number of workers, then batches it. Note that Monolith requires sparse features to be ragged tensors, so a `.map(to_ragged)` call is required if this isn't the case.
```python def to_ragged(x): return { 'mov': tf.RaggedTensor.from_tensor(x['mov']), 'uid': tf.RaggedTensor.from_tensor(x['uid']), 'label': x['label'] } def input_fn(self, mode): env = json.loads(os.environ['TF_CONFIG']) cluster = env['cluster'] worker_count = len(cluster.get('worker', [])) + len(cluster.get('chief', [])) dataset = get_preprocessed_dataset('25m') dataset = dataset.shard(worker_count, env['task']['index']) return dataset.batch(512, drop_remainder=True)\ .map(to_ragged).prefetch(tf.data.AUTOTUNE) ``` ### Build the model ```python def model_fn(self, features, mode): # for sparse features, we declare an embedding table for each of them for s_name in ["mov", "uid"]: self.create_embedding_feature_column(s_name) mov_embedding, user_embedding = self.lookup_embedding_slice( features=['mov', 'uid'], slice_name='vec', slice_dim=32) ratings = tf.keras.Sequential([ # Learn multiple dense layers. tf.keras.layers.Dense(256, activation="relu"), tf.keras.layers.Dense(64, activation="relu"), # Make rating predictions in the final layer. tf.keras.layers.Dense(1) ]) rank = ratings(tf.concat((user_embedding, mov_embedding), axis=1)) label = features['label'] loss = tf.reduce_mean(tf.losses.mean_squared_error(rank, label)) optimizer = tf.compat.v1.train.AdagradOptimizer(0.05) return EstimatorSpec( label=label, pred=rank, head_name="rank", loss=loss, optimizer=optimizer, classification=False ) ``` In `model_fn`, we use `self.create_embedding_feature_column(feature_name)` to declare an embedding table for each feature name that requires an embedding. In our case, they are `mov` and `uid`. Note that these feature names must match what the `input_fn` provides. Then, we use `self.lookup_embedding_slice` to look up the embeddings all at once. If your features require different embedding lengths, then you can use multiple calls to `self.lookup_embedding_slice`. The rest is straightforward and is identical to how you do it in native tensorflow in graph mode.
Finally, we return an `EstimatorSpec`. This `EstimatorSpec` is a wrapped version of `tf.estimator.EstimatorSpec` and thus has more fields. ## Run distributed batch training locally There are multiple ways to set up distributed training. In this tutorial, we'll use the parameter server (PS) training strategy. In this strategy, model weights are partitioned across PS, and workers read data and pull weights from PS and do training. While we usually run distributed training on top of a job scheduler such as YARN and Kubernetes, it can be done locally too. To launch training, we start multiple processes, some of which are workers and some of which are PS. Tensorflow uses a `TF_CONFIG` variable to define a cluster and the role of the current process in the cluster. This environment variable also enables service discovery between workers and PS. Example of a `TF_CONFIG`: ```python os.environ["TF_CONFIG"] = json.dumps({ "cluster": { "worker": ["host1:port", "host2:port", "host3:port"], "ps": ["host4:port", "host5:port"] }, "task": {"type": "worker", "index": 1} }) ``` We provide a script for this: [demo_local_runner.py](./demo_local_runner.py). To run batch training, simply do ```bash bazel run //markdown/demo:demo_local_runner -- --training_type=batch ``` ================================================ FILE: markdown/demo/README.md ================================================ # Monolith demo model and tutorials This is a 3-part tutorial for building monolith models and launching training. ### [Part 1: building a model and launching distributed async batch training](./Batch.md) ### [Part 2: training with streaming input data](./Stream.md) ### [Part 3: launching distributed async training on the cloud](./AWS-EKS.md) ================================================ FILE: markdown/demo/Stream.md ================================================ # Stream training tutorial > This tutorial depends on the batch training tutorial. Please read it first if you haven't.
Monolith supports reading input data from a Kafka stream. To add stream training support to your model, simply change the `input_fn` and read data from a KafkaDataset. ## Kafka producer Source code: [kafka_producer.py](./kafka_producer.py) Let's create a kafka producer for our movie-lens dataset. Kafka requires serializing everything to bytes, so we convert each data item in the dataset to a string by putting it into the standard Tensorflow Example protobuf. ```python def serialize_one(data): # serialize an training instance to string return tf.train.Example(features=tf.train.Features( feature={ 'mov': tf.train.Feature(int64_list=tf.train.Int64List(value=data['mov'])), 'uid': tf.train.Feature(int64_list=tf.train.Int64List(value=data['uid'])), 'label': tf.train.Feature(float_list=tf.train.FloatList(value=data['label'])) } )).SerializeToString() ``` Then, we create a KafkaProducer, iterate over the dataset, serialize each item, and write it to the desired kafka topic. ```python if __name__ == "__main__": ds = get_preprocessed_dataset() producer = KafkaProducer(bootstrap_servers=['127.0.0.1:9092']) for count, val in tqdm(enumerate(ds), total=len(ds)): # note: we omit error callback here for performance producer.send( "movie-train", key=str(count).encode('utf-8'), value=serialize_one(val), headers=[]) producer.flush() ``` ## Kafka consumer in the input_fn Source code: [kafka_receiver.py](./kafka_receiver.py) and [demo_model.py](./demo_model.py) Since the kafka stream contains serialized `tf.train.Example`, we can use `tf.io.parse_example` to parse multiple of them at once. ```python def decode_example(v): x = tf.io.parse_example(v, raw_feature_desc) return to_ragged(x) ``` In the `input_fn`, we use Monolith's utility function to create a kafka dataset, and use the function above to decode the messages. The parameter `poll_batch_size` determines how many serialized `Example`s we batch before sending them to `decode_example`.
It effectively means the training batch size. ```python def input_fn(self, mode): dataset = create_plain_kafka_dataset(topics=["movie-train"], group_id="cgonline", servers="127.0.0.1:9092", stream_timeout=10000, poll_batch_size=16, configuration=[ "session.timeout.ms=7000", "max.poll.interval.ms=8000" ], ) return dataset.map(lambda x: decode_example(x.message)) ``` ================================================ FILE: markdown/demo/demo_local_runner.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import subprocess from typing import List import time from absl import app from absl import flags import os from monolith.native_training import yarn_runtime from socket import socket import json flags.DEFINE_enum('training_type', 'batch', ['batch', 'stream'], "type of training to launch") FLAGS = flags.FLAGS occupied_ports = set() def get_rand_port(): # this function returns a unique unused port while True: with socket() as s: s.bind(('',0)) port = s.getsockname()[1] if port not in occupied_ports: occupied_ports.add(port) return port def launch_workers(num_ps: int, num_workers: int): args = [ "markdown/demo/demo_model", f"--training_type={FLAGS.training_type}", "--model_dir=/tmp/movie_lens_tutorial", "--model_name=movie_lens_tutorial" ] assert num_workers > 1, "must have more than 1 workers" ip = yarn_runtime.get_local_host() ps_addrs = [f'{ip}:{get_rand_port()}' for i in range(num_ps)] worker_addrs = [f'{ip}:{get_rand_port()}' for i in range(num_workers)] env = os.environ.copy() tf_config = { "cluster": { "worker": worker_addrs, "ps": ps_addrs, } } processes = [] for i in range(num_ps): tf_config['task'] = {"type": "ps", "index": i} env['TF_CONFIG'] = json.dumps(tf_config) processes.append(subprocess.Popen(args, env=env)) for i in range(num_workers): tf_config['task'] = {"type": "worker", "index": i} env['TF_CONFIG'] = json.dumps(tf_config) processes.append(subprocess.Popen(args, env=env)) if i == 0: time.sleep(2) return processes def main(_): num_ps = 2 num_workers = 2 processes = launch_workers( num_ps, num_workers ) try: for p in processes: p.wait() finally: for p in processes: p.kill() if __name__ == "__main__": app.run(main) ================================================ FILE: markdown/demo/demo_model.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl import app
from absl import flags
from absl import logging

import json
import os
import sys
import tensorflow as tf

from kafka_receiver import decode_example, to_ragged
from ml_dataset import get_preprocessed_dataset
from monolith.native_training.estimator import EstimatorSpec, Estimator, RunnerConfig, ServiceDiscoveryType
from monolith.native_training.native_model import MonolithModel
from monolith.native_training.data.datasets import create_plain_kafka_dataset

flags.DEFINE_enum('training_type', 'batch', ['batch', 'stream', 'stdin'], "type of training to launch")
FLAGS = flags.FLAGS


def get_worker_count(env: dict):
  """Counts the 'worker' and 'chief' addresses in a parsed TF_CONFIG.

  Args:
    env: parsed TF_CONFIG dict; must contain a 'cluster' mapping.

  Returns:
    len(cluster['worker']) + len(cluster['chief']), with missing keys counted
    as empty. Asserts that the result is positive.
  """
  cluster = env['cluster']
  worker_count = len(cluster.get('worker', [])) + len(cluster.get('chief', []))
  assert worker_count > 0
  return worker_count


class MovieRankingModelBase(MonolithModel):
  """Shared model definition of the demo; subclasses provide `input_fn`."""

  def __init__(self, params):
    super().__init__(params)
    self.p = params

  def model_fn(self, features, mode):
    """Builds the rating model and returns its EstimatorSpec.

    Args:
      features: one element of the dataset returned by `input_fn`; only
        features['label'] is read directly here, the 'mov' and 'uid' features
        are referenced by name through lookup_embedding_slice.
      mode: not used by this method.

    Returns:
      An EstimatorSpec with head "rank", a mean-squared-error loss and an
      Adagrad optimizer (learning rate 0.05).
    """
    # for sparse features, we declare an embedding table for each of them
    for s_name in ["mov", "uid"]:
      self.create_embedding_feature_column(s_name)

    # One 32-wide slice named 'vec' is looked up for both features; results
    # come back in the order of the `features` list.
    mov_embedding, user_embedding = self.lookup_embedding_slice(
      features=['mov', 'uid'], slice_name='vec', slice_dim=32)
    ratings = tf.keras.Sequential([
      # Learn multiple dense layers.
      tf.keras.layers.Dense(256, activation="relu"),
      tf.keras.layers.Dense(64, activation="relu"),
      # Make rating predictions in the final layer.
      tf.keras.layers.Dense(1)
    ])
    concated = tf.concat((user_embedding, mov_embedding), axis=1)
    rank = ratings(concated)
    label = features['label']
    # NOTE(review): `rank` comes out of Dense(1); the input pipelines in this
    # file do not all build `label` the same way (expand_dims in the batch
    # dataset, scalar per example for stdin/Kafka) -- confirm the shapes
    # broadcast as intended in every training_type.
    loss = tf.reduce_mean(tf.losses.mean_squared_error(rank, label))

    optimizer = tf.compat.v1.train.AdagradOptimizer(0.05)

    return EstimatorSpec(
      label=label,
      pred=rank,
      head_name="rank",
      loss=loss,
      optimizer=optimizer,
      classification=False
    )

  def serving_input_receiver_fn(self):
    # a dummy serving input receiver
    # NOTE(review): tf.estimator.export.ServingInputReceiver is documented as
    # taking both `features` and `receiver_tensors`; only one argument is
    # passed here -- confirm this placeholder is never actually invoked.
    return tf.estimator.export.ServingInputReceiver({})


class MovieRankingBatchTraining(MovieRankingModelBase):
  """Batch training on the movielens '1m' dataset."""

  def input_fn(self, mode):
    """Returns this task's shard of the dataset, batched by 512 and ragged.

    Reads the cluster layout and the task index from the TF_CONFIG
    environment variable; `mode` is not used.
    """
    env = json.loads(os.environ['TF_CONFIG'])
    dataset = get_preprocessed_dataset('1m')
    dataset = dataset.shard(get_worker_count(env), env['task']['index'])
    return dataset.batch(512, drop_remainder=True)\
      .map(to_ragged).prefetch(tf.data.AUTOTUNE)


class MovieRankingStreamTraining(MovieRankingModelBase):
  """Stream training fed from the "movie-train" Kafka topic."""

  def input_fn(self, mode):
    """Returns a Kafka dataset whose messages are parsed by decode_example.

    `mode` is not used.
    """
    dataset = create_plain_kafka_dataset(topics=["movie-train"],
        group_id="cgonline",
        servers="127.0.0.1:9092",
        stream_timeout=10000,
        poll_batch_size=16,
        configuration=[
          "session.timeout.ms=7000",
          "max.poll.interval.ms=8000"
        ],
    )
    return dataset.map(lambda x: decode_example(x.message))


def read_stdin():
  """Yields one example dict per line of stdin until EOF.

  Each line is split on ',' into mov, uid and label (in that order); the first
  two are parsed as int, the third as float. An empty string from readline()
  (EOF) ends the generator.
  """
  while True:
    line = sys.stdin.readline()
    if line:
      tokens = line.strip().split(',')
      yield {
        'mov': [int(tokens[0])],
        'uid': [int(tokens[1])],
        'label': float(tokens[2])
      }
    else:
      return


class MovieRankingBatchTrainingStdin(MovieRankingModelBase):
  """Batch training fed from comma-separated lines on stdin."""

  def input_fn(self, mode):
    """Wraps read_stdin in a dataset, batched by 512 and converted to ragged.

    `mode` is not used.
    """
    return tf.data.Dataset.from_generator(
      read_stdin,
      output_signature={
        'mov': tf.TensorSpec(shape=(1,), dtype=tf.int64),
        'uid': tf.TensorSpec(shape=(1,), dtype=tf.int64),
        'label': tf.TensorSpec(shape=(), dtype=tf.float32),
      }
    ).batch(512, drop_remainder=True)\
      .map(to_ragged).prefetch(tf.data.AUTOTUNE)


# Re-binds the same flags.FLAGS object that is already bound near the imports.
FLAGS = flags.FLAGS


def main(_):
  """Builds the RunnerConfig from TF_CONFIG and trains the selected model."""
  tf.compat.v1.disable_eager_execution()
  # TF_CONFIG is mandatory: a missing variable raises KeyError here.
  raw_tf_conf = os.environ['TF_CONFIG']
  tf_conf = json.loads(raw_tf_conf)
  config = RunnerConfig(
    discovery_type=ServiceDiscoveryType.PRIMUS,
    tf_config=raw_tf_conf,
    save_checkpoints_steps=10000,
    enable_model_ckpt_info=True,
    num_ps=len(tf_conf['cluster']['ps']),
    num_workers=get_worker_count(tf_conf),
    server_type=tf_conf['task']['type'],
    index=tf_conf['task']['index']
  )
  # "batch" and "stdin" select their own model classes; the only remaining
  # flag value, "stream", falls through to the Kafka-backed model.
  if FLAGS.training_type == "batch":
    params = MovieRankingBatchTraining.params().instantiate()
  elif FLAGS.training_type == "stdin":
    params = MovieRankingBatchTrainingStdin.params().instantiate()
  else:
    params = MovieRankingStreamTraining.params().instantiate()
  estimator = Estimator(params, config)
  estimator.train(max_steps=1000000)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  app.run(main)

================================================
FILE: markdown/demo/kafka_producer.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ml_dataset import get_preprocessed_dataset, serialize_one
from tqdm import tqdm
from kafka import KafkaProducer

if __name__ == "__main__":
  # Sends every element of the preprocessed dataset to the "movie-train" topic
  # of the local broker, keyed by its running index.
  ds = get_preprocessed_dataset()
  producer = KafkaProducer(bootstrap_servers=['127.0.0.1:9092'])
  for count, val in tqdm(enumerate(ds), total=len(ds)):
    # note: we omit error callback here for performance
    producer.send(
      "movie-train", key=str(count).encode('utf-8'), value=serialize_one(val), headers=[])
  producer.flush()

================================================
FILE: markdown/demo/kafka_receiver.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from monolith.native_training.data.datasets import create_plain_kafka_dataset

# Parsing spec for the serialized tf.train.Example messages: one int64 each
# for 'mov' and 'uid', and a scalar float 'label'.
raw_feature_desc = {
  'mov': tf.io.FixedLenFeature([1], tf.int64),
  'uid': tf.io.FixedLenFeature([1], tf.int64),
  'label': tf.io.FixedLenFeature([], tf.float32)
}


def to_ragged(x):
  """Converts the 'mov' and 'uid' tensors of `x` to RaggedTensors.

  'label' is passed through unchanged.
  """
  return {
    'mov': tf.RaggedTensor.from_tensor(x['mov']),
    'uid': tf.RaggedTensor.from_tensor(x['uid']),
    'label': x['label']
  }


# Counterpart of serialize_one, which is defined in ml_dataset.py and used by
# kafka_producer.py.
def decode_example(v):
  """Parses serialized examples with raw_feature_desc, then applies to_ragged."""
  x = tf.io.parse_example(v, raw_feature_desc)
  return to_ragged(x)


if __name__ == "__main__":
  # Manual check: print decoded batches read from the "movie-train" topic.
  dataset = create_plain_kafka_dataset(topics=["movie-train"],
      group_id="cgonline",
      servers="127.0.0.1:9092",
      stream_timeout=10000, # in milliseconds, to block indefinitely, set it to -1.
      poll_batch_size=8,
      configuration=[
        "session.timeout.ms=7000",
        "max.poll.interval.ms=8000"
      ],
  )
  for x in dataset.map(lambda x: decode_example(x.message)):
    print(x)

================================================
FILE: markdown/demo/kafka_utils/add_data_topics.sh
================================================
#!/bin/bash
source ./kafka_base.sh
$KAFKA_PATH/bin/kafka-topics.sh --create --bootstrap-server 127.0.0.1:9092 --replication-factor 1 --partitions 6 --topic movie-train
$KAFKA_PATH/bin/kafka-topics.sh --describe --bootstrap-server 127.0.0.1:9092 --topic movie-train

================================================
FILE: markdown/demo/kafka_utils/delete_topics.sh
================================================
#!/bin/bash
source ./kafka_base.sh
$KAFKA_PATH/bin/kafka-topics.sh --delete --bootstrap-server 127.0.0.1:9092 --topic movie-train

================================================
FILE: markdown/demo/kafka_utils/kafka_base.sh
================================================
#!/bin/bash
export KAFKA_PATH=$HOME/kafka_2.13-2.8.1

================================================
FILE: markdown/demo/kafka_utils/start_broker.sh
================================================
#!/bin/bash
source ./kafka_base.sh
bash $KAFKA_PATH/bin/zookeeper-server-start.sh -daemon $KAFKA_PATH/config/zookeeper.properties
sleep 10
bash $KAFKA_PATH/bin/kafka-server-start.sh -daemon $KAFKA_PATH/config/server.properties

================================================
FILE: markdown/demo/ml_dataset.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf import tensorflow_datasets as tfds from tqdm import tqdm from multiprocessing import Process, cpu_count def get_preprocessed_dataset(size='1m') -> tf.data.Dataset: ratings = tfds.load(f"movielens/{size}-ratings", split="train") # For simplicity, we map each movie_title and user_id to numbers # by Hashing. You can use other ways to number them to avoid # collision and better leverage Monolith's collision-free hash tables. max_b = (1 << 63) - 1 return ratings.map(lambda x: { 'mov': tf.strings.to_hash_bucket_fast([x['movie_title']], max_b), 'uid': tf.strings.to_hash_bucket_fast([x['user_id']], max_b), 'label': tf.expand_dims(x['user_rating'], axis=0) }) def serialize_one(data): # serialize an training instance to string return tf.train.Example(features=tf.train.Features( feature={ 'mov': tf.train.Feature(int64_list=tf.train.Int64List(value=data['mov'])), 'uid': tf.train.Feature(int64_list=tf.train.Int64List(value=data['uid'])), 'label': tf.train.Feature(float_list=tf.train.FloatList(value=data['label'])) } )).SerializeToString() # serialize to human readable (csv) format def serialize_hr(data): return f"{data['mov']},{data['uid']},{data['label']}\n" def save_one_shard(total_shards, pid, start, end): ds = get_preprocessed_dataset('1m').map(lambda x: { 'mov': tf.squeeze(x['mov']), 'uid': tf.squeeze(x['uid']), 'label': tf.squeeze(x['label']) }) pbar = tqdm(position=pid, desc="[Serializing]") for i in range(start, end): ds_shard = ds.shard(total_shards, i).as_numpy_iterator() with open(f"data_1m/part_{i}.csv", "w") as f: for item in ds_shard: 
f.write(serialize_hr(item)) pbar.update() if __name__ == "__main__": # just let TF download this dataset if it doesn't exist ds = get_preprocessed_dataset('1m') for _ in ds.take(1): pass total_shards = 4 num_process = min(max(cpu_count() // 4, 1), total_shards) processes = [] shards_per_p = total_shards // num_process for i in range(num_process): # note: this multiprocessing is not very efficient because .shard needs to skip elements p = Process(target=save_one_shard, args=(total_shards, i, shards_per_p * i, shards_per_p * (i + 1))) p.start() processes.append(p) for p in processes: p.join() ================================================ FILE: markdown/input_and_model_fn.md ================================================ # Monolith `input_fn` and `model_fn` This is guide on how to setup `input_fn` and using Monolith's embedding hash table in `model_fn` ## How to create an `input_fn` An important part of `MonolithModel` is the input function. It has two requirements: 1. It needs to return an instance of anything that inherits from `tf.data.Dataset`. 2. When this instance is iterated over batch by batch, it yields a dict containing sparse ids and dense data. The keys should be feature names. 3. Sparse ids must be instance of `tf.RaggedTensor` with dtype `tf.int64`, and the remaining values in the dict are treated as dense features The reason sparse ids must be RaggedTensor is that they can vary in length bewteen different training instance. For example, consider a dataset like this ```python { 'user_id': 15, 'gender': 0, 'recently_liked_videos': [1, 2, 3] } ``` The feature `recently_liked_videos` may vary in length, so when we batch these training instances, the resulting tensor is a RaggedTensor of 2 dimensions. The first dimension is the batch dimension, and the second dimension is ragged. 
A constant dataset returning a single **batch** of data where batch_size=2 may look like this ```python def input_fn(self, mode): features = { "mov": tf.ragged.constant([[155], [13]], dtype=tf.int64), # sparse feature "uid": tf.ragged.constant([[324], [75]], dtype=tf.int64), # sparse feature "ratings": tf.constant([5.0, 2.0], dtype=tf.float32) # dense feature } return tf.data.Dataset.from_tensors(features) ``` ## `model_fn` The model function's argument `features` is exactly what the dataset returned by `input_fn` yields when iterated over. To look up the embeddings corresponding to the sparse features, we first define the configuration for each embedding table by using `self.create_embedding_feature_column(sparse_feature_name)`, where `sparse_feature_name` is one of the sparse features returned in the dataset. ```python def model_fn(self, features, mode): for feature_name in ["mov", "uid"]: self.create_embedding_feature_column(feature_name) ``` Then we look up the embeddings corresponding to each sparse feature with `self.lookup_embedding_slice`. We can look up embeddings from multiple tables at once by specifying the list of feature names. ```python mov_embedding, user_embedding = self.lookup_embedding_slice( features=['mov', 'uid'], slice_name='vec', slice_dim=32) ``` Note that we do not use `features` directly to obtain the sparse ids here, as it is handled internally through `self.lookup_embedding_slice`. To get dense features, simply use the `features` dictionary ```python ratings = features['ratings'] ``` ## TFRecordDataset It is a common practice to prepare the dataset in `tf.train.Example` format, and then store it as a `TFRecordDataset`. 
In this way, the dataset can be parsed as easily as ```python def input_fn(self, mode): raw_feature_desc = { 'mov': tf.io.VarLenFeature(tf.int64), 'uid': tf.io.VarLenFeature(tf.int64), 'label': tf.io.FixedLenFeature([], tf.float32) } def decode_example(v): return tf.io.parse_example(v, raw_feature_desc) return tf.data.TFRecordDataset([PATH_TO_YOUR_DATASET]).batch(BATCH_SIZE).map(decode_example) ``` Where `tf.io.parse_example` automatically parses batches of `tf.train.Example`, converting `VarLenFeature` to ragged tensors and the remaining features to regular tensors. ## Final note As long as your dataset adheres to the requirements above, it shouldn't be an issue. You can also leverage any kind of dataset that tensorflow provides. For more information, please refer to the official tensorflow documentation. ================================================ FILE: markdown/primus_demo/README.md ================================================ # Monolith x Primus Demo ## Setup Primus Follow the primus quickstart guide to set up the primus baseline virtual machine: https://github.com/bytedance/primus/blob/master/docs/primus-quickstart.md ## Setup Monolith In your virtual machine, clone the open source monolith ```bash cd git clone https://github.com/bytedance/monolith ``` ### Prepare monolith image ```bash cd monolith/markdown/primus_demo docker build -t monolith_ubuntu22_exec:1.0 -f monolith.Dockerfile . ``` Then, load this image into the k8s cluster ```bash kind load docker-image monolith_ubuntu22_exec:1.0 ``` ## Prepare the movie-lens dataset Now, we will convert the movie-lens dataset to CSV format, which is later fed to the model through Primus's input manager. This may take a while (a few hours) due to the size of the dataset, depending on the number of CPU cores you have. 
```bash pip3 install tensorflow==2.4.0 tensorflow-datasets cd monolith/markdown/demo mkdir -p data_1m python3 ml_dataset.py ``` When the conversion has finished, upload the data to HDFS ```bash /usr/lib/hadoop/bin/hdfs dfs -put data_1m /primus/ ``` ## Launch training with Primus on k8s First, make sure that the `files` entry of `monolith/markdown/primus_demo/primus_monolith.json` matches the actual place where you cloned monolith. Then, you can submit the training via ```bash /usr/lib/primus-kubernetes/sbin/primus-submit --primus_conf primus_monolith.json ``` ================================================ FILE: markdown/primus_demo/main.sh ================================================ #!/bin/bash set -ex # setup env export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$JAVA_HOME/jre/lib/amd64/server/ export HADOOP_HDFS_HOME=/usr/lib/hadoop export CLASSPATH="$(/usr/lib/hadoop/bin/hadoop classpath --glob)" python3 demo/demo_model.py \ --model_dir=hdfs:///primus/model-checkpoints/movie_lens_tutorial \ --model_name=movie_lens_tutorial \ --training_type=stdin ================================================ FILE: markdown/primus_demo/monolith.Dockerfile ================================================ FROM hanzhi713/monolith:ubuntu22.04 # Java will be mounted ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 ENV PATH=$JAVA_HOME/bin:$PATH # Hadoop will be mounted ENV HADOOP_HOME=/usr/lib/hadoop ENV HADOOP_CONF_DIR=$HADOOP_HOME/etc/hadoop ENV PATH=$HADOOP_HOME/bin:$PATH ENTRYPOINT ["sleep"] CMD ["43200"] ================================================ FILE: markdown/primus_demo/primus_monolith.json ================================================ { "name": "primus-monolith", "files": [ "/home/ubuntu/monolith/markdown/demo", "/home/ubuntu/monolith/markdown/primus_demo/main.sh" ], "role": [ { "roleName": "worker", "num": 2, "vcores": 1, "memoryMb": 4096, "jvmMemoryMb": 4096, "command": "env && bash main.sh", "successPercent": 100, 
"failover": { "commonFailoverPolicy": { "commonFailover": { "restartType": "ON_FAILURE", "maxFailureTimes": 1, "maxFailurePolicy": "FAIL_ATTEMPT" } } }, "inputPolicy": "STREAMING" }, { "roleName": "ps", "num": 2, "vcores": 1, "memoryMb": 4096, "jvmMemoryMb": 4096, "command": "env && bash main.sh", "successPercent": 100, "failover": { "commonFailoverPolicy": { "commonFailover": { "restartType": "ON_FAILURE", "maxFailureTimes": 1, "maxFailurePolicy": "FAIL_ATTEMPT" } } } } ], "inputManager": { "fileConfig": { "inputs": [ { "name": "data", "spec": { "pathPattern": "/primus/data_1m/", "namePattern": "part_*.csv", "textInput": {} } } ], "stopPolicy": { "taskSuccessPercent": 100 } }, "workPreserve": { "dumpIntervalSecs": 5, "hdfsConfig": {} }, "gracefulShutdown": "true" }, "runtimeConf": { "kubernetesNativeConf": { "executorPodConf": { "mainContainerConf": { "imageName": "monolith_ubuntu22_exec:1.0" } } } } } ================================================ FILE: markdown/serving.md ================================================ # Serving ## Understanding Hashtable Ckpt Format ### Export Hashtable Ckpt ```python import tensorflow as tf from monolith.native_training import hash_table_ops with tf.compat.v1.Session() as sess: table1 = hash_table_ops.test_hash_table(4) table1 = table1.assign(tf.convert_to_tensor([1], tf.int64), tf.convert_to_tensor([[0,1,2,3]], tf.float32)) table1 = table1.save("/tmp/save_restore") sess.run(table1.as_op()) ``` ### Parse Hashtable Ckpt ```python import tensorflow as tf from monolith.native_training.runtime.hash_table import \ embedding_hash_table_pb2 dataset = tf.data.TFRecordDataset("/tmp/save_restore-00000-of-00001") for raw_dump in dataset: entry_dump = embedding_hash_table_pb2.EntryDump() entry_dump.ParseFromString(raw_dump.numpy()) print(entry_dump) ``` ``` 2022-10-19 07:09:30.406350: I external/org_tensorflow/tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set 2022-10-19 
07:09:30.406563: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX512F To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 2022-10-19 07:09:30.426969: I external/org_tensorflow/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2) 2022-10-19 07:09:30.439333: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2300000000 Hz id: 1 num: 0.0 num: 1.0 num: 2.0 num: 3.0 opt { dump { sgd { } } } last_update_ts_sec: 0 ``` ## Model Serving Monolith uses tensorflow saved_model as servable format. There are two kinds of saved_model in monolith. One is entry, the other is PS. PS is a KV-Storage for embeddings. Entry accepts client calls, calls PS to fetch embeddings and runs the computation graph to get the target tensor value. PS is not callable from client directly, only entry should call PS. ### Config saved_model export Saved_model exporting happens during ckpt saving stage during training, in order to enable that we need to have two changes. First set `self.p.serving.export_when_saving = True`, then implement `serving_input_receiver_fn` to parse serving request. ```python class DemoModel(MonolithModel): def __init__(self, params): super().__init__(params) self.p = params self.p.serving.export_when_saving = True ...... 
def serving_input_receiver_fn(self): input_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) receiver_tensors = {'examples': input_placeholder} raw_feature_desc = { 'mov': tf.io.FixedLenFeature([1], tf.int64), 'uid': tf.io.FixedLenFeature([1], tf.int64), 'label': tf.io.FixedLenFeature([], tf.float32) } examples = tf.io.parse_example(input_placeholder, raw_feature_desc) parsed_features = { 'mov': tf.RaggedTensor.from_tensor(examples['mov']), 'uid': tf.RaggedTensor.from_tensor(examples['uid']), 'label': examples['label'] } return tf.estimator.export.ServingInputReceiver(receiver_tensors, parsed_features) ``` ### Exported File Structure Suppose a training job has `hdfs:///user/xxx/model_checkpoint` as its model_dir; the saved_models for serving will then reside in `hdfs:///user/xxx/model_checkpoint/exported_models` For example ``` ➜ hdfs dfs -ls hdfs:///user/xxx/model_checkpoint/exported_models drwxr-xr-x - nnproxy supergroup 0 2022-07-14 07:38 hdfs:///user/xxx/model_checkpoint/exported_models/entry drwxr-xr-x - nnproxy supergroup 0 2022-07-14 07:38 hdfs:///user/xxx/model_checkpoint/exported_models/ps_0 drwxr-xr-x - nnproxy supergroup 0 2022-07-14 07:38 hdfs:///user/xxx/model_checkpoint/exported_models/ps_1 drwxr-xr-x - nnproxy supergroup 0 2022-07-14 07:38 hdfs:///user/xxx/model_checkpoint/exported_models/ps_2 ``` ### Serving Configuration Using the above file structure as an example, saved_models are stored in hdfs:///user/xxx/model_checkpoint/exported_models #### Standalone serving For standalone serving, we serve all saved_models of the same model in the same tf serving instance. 
We can use the following configuration; save it as `demo.conf` ```conf bzid monolith_serving_test # namespace deploy_type unified # always unified zk_servers 10.*.91.73:2181,10.*.86.70:2181,10.*.126.131:2181,10.*.109.135:2181 base_path hdfs:///user/xxx/model_checkpoint/exported_models layout_filters entry; True layout_filters ps_{i}; True agent_version 3 # always 3 # tensorflow serving flags fetch_ps_timeout_ms 10000 enable_batching false tensorflow_session_parallelism 0 tensorflow_intra_op_parallelism 0 tensorflow_inter_op_parallelism 0 per_process_gpu_memory_fraction 0 num_load_threads 0 num_unload_threads 0 max_num_load_retries 5 load_retry_interval_micros 60 * 1000 * 1000 file_system_poll_wait_seconds 60 flush_filesystem_caches true saved_model_tags none grpc_channel_arguments none grpc_max_threads 0 enable_model_warmup true enable_signature_method_name_check false xla_cpu_compilation_enabled false enable_profiler true ``` #### Start TF Serving We use the following command to start the serving agent. It will start the tf serving process and register with the name service (ZooKeeper). ```bash bazel run monolith/agent_service:agent -- --conf=demo.conf --tfs_log=tfs.std.log ``` We can see the following log printed out, showing our saved_models are successfully loaded and registered to the name service. ``` I1101 05:55:59.951008 139897902933760 backends.py:222] available saved models updating, add: {test_ffm_model_2:ps_0, test_ffm_model_2:ps_1, test_ffm_model_2:ps_2, test_ffm_model_2:entry}, remove: set() I1101 05:55:59.973262 139897902933760 backends.py:230] available saved models updated: {test_ffm_model_2:ps_0, test_ffm_model_2:ps_1, test_ffm_model_2:ps_2, test_ffm_model_2:entry} ``` #### Distributed Serving There are cases when our models are so large that they cannot fit in one container. Still using `hdfs:///user/xxx/model_checkpoint/exported_models` as an example, now we want to have two machines to serve the models. 
For machine 1, we want to load `entry` and `ps_1`; for machine 2, we want to load `ps_0` and `ps_2`. ##### Conf for Machine 1 ``` ... base_path hdfs:///user/xxx/model_checkpoint/exported_models layout_filters entry; True layout_filters ps_{i}; i % 2 == 1 # ps_1 ... ``` ##### Conf for Machine 2 ``` ... base_path hdfs:///user/xxx/model_checkpoint/exported_models layout_filters ps_{i}; i % 2 == 0 # ps_0 and ps_2 ... ``` We use `layout_filters` for a container to pick the saved_models. The pattern is `{match}; {filter}`. For example, `ps_{i}` will match `ps_1` and assign `i = 1`; `i` can then be used in the filter clause. ================================================ FILE: monolith/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library") package(default_visibility = ["//visibility:public"]) py_library( name = "path_utils", srcs = ["path_utils.py"], ) py_library( name = "utils", srcs = ["utils.py"], deps = [ ":path_utils", ], ) py_test( name = "utils_test", srcs = ["utils_test.py"], deps = [ ":utils", ], ) py_library( name = "init", srcs = ["__init__.py"], deps = [ "//monolith/native_training:entry", "//monolith/native_training:estimator", "//monolith/native_training:native_model", "//monolith/native_training/data", "//monolith/native_training/layers", "//monolith/native_training/model_export", "@org_tensorflow//tensorflow:tensorflow_py", ], ) ================================================ FILE: monolith/__init__.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys from absl import logging import importlib from tensorflow.python.tools import module_util as _module_util from monolith.native_training import data from monolith.native_training import layers from monolith.native_training import model_export from monolith.native_training import entry from monolith.native_training import native_model as base_model from monolith.native_training import estimator from monolith.utils import enable_monkey_patch def add_module(module): try: if isinstance(module, str): name = module.split('.')[-1] module = importlib.import_module(module) else: name = module.__name__.split('.')[-1] if name == 'native_model': name = 'base_model' except ImportError as e: raise e sys.modules[f'{__name__}.{name}'] = module add_module(data) add_module(layers) add_module(model_export) add_module(entry) add_module(base_model) add_module(estimator) try: enable_monkey_patch() except: logging.error('enable_monkey_patch failed') ================================================ FILE: monolith/agent_service/BUILD ================================================ load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library") load("@pip_deps//:requirements.bzl", "requirement") load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") load("@rules_proto//proto:defs.bzl", "proto_library") package(default_visibility = [ "//monolith/agent_service:__subpackages__", 
"//monolith/integration_test:__subpackages__", "//monolith/native_training:__subpackages__", ]) py_library( name = "utils", srcs = ["utils.py"], srcs_version = "PY3", deps = [ "//idl:proto_parser_py_proto", "@org_tensorflow//tensorflow/core:protos_all_py", "@org_tensorflow//tensorflow/core/example:protos_all_py", "@org_tensorflow_serving//tensorflow_serving/apis:model_service_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/apis:session_service_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/config:platform_config_py_pb2", "@org_tensorflow_serving//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter_py_pb2", "@org_tensorflow_serving//tensorflow_serving/servables/tensorflow:session_bundle_config_py_pb2", requirement("dataclasses_json"), ":constants", "//monolith/native_training:env_utils", "//monolith/native_training:zk_utils", ], ) py_test( name = "utils_test", srcs = ["utils_test.py"], data = [ "agent.conf", "//monolith/agent_service/test_data", ], srcs_version = "PY3", deps = [ ":utils", ], ) py_library( name = "backends", srcs = ["backends.py"], srcs_version = "PY3", deps = [ requirement("kazoo"), requirement("dataclasses_json"), ":utils", ], ) py_test( name = "backends_test", srcs = ["backends_test.py"], srcs_version = "PY3", deps = [ ":backends", ":mocked_zkclient", ":utils", ], ) py_library( name = "constants", srcs = ["constants.py"], ) py_binary( name = "agent_controller", srcs = [ "agent_controller.py", ], srcs_version = "PY3", deps = [ ":backends", ":utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_library( name = "replica_manager", srcs = ["replica_manager.py"], srcs_version = "PY3", deps = [ ":agent_service_py_pb2", ":backends", ":data_def", ":resource_utils", ":tfs_monitor", ":utils", "//monolith/native_training/metric:cli", "//monolith/native_training/model_export:export_state_utils", 
"//monolith/native_training/runtime/parameter_sync:parameter_sync_py_proto", requirement("dataclasses_json"), ], ) py_library( name = "model_manager", srcs = ["model_manager.py"], srcs_version = "PY3", deps = [ "//monolith/native_training/metric:cli", ], ) py_test( name = "model_manager_test", srcs = ["model_manager_test.py"], srcs_version = "PY3", deps = [ ":model_manager", ], ) py_test( name = "replica_manager_test", srcs = ["replica_manager_test.py"], srcs_version = "PY3", deps = [ ":constants", ":mocked_tfserving", ":mocked_zkclient", ":replica_manager", ":tfs_monitor", ":utils", ], ) py_library( name = "agent_base", srcs = ["agent_base.py"], srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ ":utils", ], ) py_library( name = "agent_v1", srcs = ["agent_v1.py"], srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ ":agent_base", ":agent_service", ":replica_manager", ":tfs_monitor", ], ) py_library( name = "mocked_tfserving", srcs = [ "mocked_tfserving.py", ], srcs_version = "PY3", deps = [ ":utils", ], ) py_test( name = "mocked_tfserving_test", srcs = [ "mocked_tfserving_test.py", ], srcs_version = "PY3", deps = [ ":mocked_tfserving", ], ) py_library( name = "mocked_zkclient", srcs = [ "mocked_zkclient.py", ], srcs_version = "PY3", deps = [ requirement("kazoo"), ], ) py_test( name = "mocked_zkclient_test", srcs = [ "mocked_zkclient_test.py", ], srcs_version = "PY3", deps = [ ":mocked_zkclient", ], ) proto_library( name = "agent_service_proto", srcs = ["agent_service.proto"], ) py_proto_library( name = "agent_service_py_pb2", deps = [":agent_service_proto"], ) py_grpc_library( name = "agent_service_py_pb2_grpc", srcs = [":agent_service_proto"], deps = [":agent_service_py_pb2"], ) py_library( name = "agent_service", srcs = ["agent_service.py"], srcs_version = "PY3", deps = [ ":agent_service_py_pb2", ":agent_service_py_pb2_grpc", ":data_def", ":replica_manager", ":resource_utils", ":utils", ":zk_mirror", ], ) py_binary( name = 
"agent_client", srcs = ["agent_client.py"], srcs_version = "PY3", deps = [ ":agent_service_py_pb2", ":agent_service_py_pb2_grpc", ":utils", requirement("kazoo"), ":client", ":data_def", ":resource_utils", ], ) py_test( name = "agent_service_test", srcs = [ "agent_service_test.py", ], srcs_version = "PY3", deps = [ ":agent_service", ":mocked_zkclient", ":svr_client", ], ) py_binary( name = "tfs_client", srcs = ["tfs_client.py"], deps = [ ":client", ":utils", "//idl:example_py_proto", "//idl:line_id_py_proto", "//idl:proto_parser_py_proto", "//monolith/native_training/data:feature_list", "//monolith/native_training/model_export:data_gen_utils", ], ) cc_proto_library( name = "agent_service_cc_proto", srcs = [ "agent_service.proto", ], visibility = ["//visibility:public"], ) cc_grpc_library( name = "agent_service_cc_proto_grpc", srcs = [ ":agent_service_proto", ], generate_mocks = True, grpc_only = True, visibility = ["//visibility:public"], deps = [ ":agent_service_cc_proto", ], ) py_library( name = "data_def", srcs = ["data_def.py"], srcs_version = "PY3", deps = [ ":utils", "//monolith/native_training:net_utils", ], ) py_test( name = "data_def_test", srcs = ["data_def_test.py"], srcs_version = "PY3", deps = [ ":data_def", ], ) py_library( name = "tfs_monitor", srcs = ["tfs_monitor.py"], srcs_version = "PY3", deps = [ ":data_def", ":utils", ], ) py_test( name = "tfs_monitor_test", srcs = ["tfs_monitor_test.py"], srcs_version = "PY3", deps = [ ":constants", ":mocked_tfserving", ":tfs_monitor", ], ) py_library( name = "zk_mirror", srcs = ["zk_mirror.py"], deps = [ ":data_def", ":utils", requirement("kazoo"), ], ) py_test( name = "zk_mirror_test", srcs = ["zk_mirror_test.py"], deps = [ ":agent_service_py_pb2", ":constants", ":mocked_tfserving", ":mocked_zkclient", ":zk_mirror", ], ) py_library( name = "resource_utils", srcs = ["resource_utils.py"], srcs_version = "PY3", deps = [ ":data_def", ":utils", "//monolith/native_training/model_export:export_py_proto", 
"//monolith/native_training/model_export:export_state_utils", "@org_tensorflow//tensorflow:tensorflow_py", requirement("psutil"), ], ) py_test( name = "resource_utils_test", srcs = ["resource_utils_test.py"], deps = [ ":resource_utils", ], ) py_library( name = "tfs_wrapper", srcs = ["tfs_wrapper.py"], data = [ "//conf:serving", ], deps = [ ":utils", "//monolith:utils", ], ) py_library( name = "agent_v3", srcs = ["agent_v3.py"], srcs_version = "PY3", deps = [ ":agent_base", ":agent_service", ":backends", ":data_def", ":resource_utils", ":tfs_wrapper", ], ) py_test( name = "agent_v3_test", srcs = ["agent_v3_test.py"], srcs_version = "PY3", deps = [ ":agent_v3", ":mocked_zkclient", ], ) py_binary( name = "client", srcs = ["client.py"], srcs_version = "PY3", deps = [ ":data_def", ":utils", ":zk_mirror", requirement("kazoo"), ], ) filegroup( name = "agent_internal_data", ) py_binary( name = "agent", srcs = ["agent.py"], data = [ ":agent_internal_data", ], srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ ":agent_v1", ":agent_v3", ":mocked_zkclient", ":model_manager", ], ) py_binary( name = "run", srcs = ["run.py"], visibility = ["//visibility:public"], deps = [ ":tfs_client", ":agent_client", ":agent", ], ) py_binary( name = "svr_client", srcs = ["svr_client.py"], srcs_version = "PY3", deps = [ ":agent_service_py_pb2", ":agent_service_py_pb2_grpc", ":utils", ], ) filegroup( name = "agent_exported", ) ================================================ FILE: monolith/agent_service/__init__.py ================================================ ================================================ FILE: monolith/agent_service/agent.conf ================================================ bzid predict_ctr base_name predict_ctr base_path hdfs:///test/data num_ps 10 server_type entry zk_servers 127.0.0.1:12345 max_waiting_sec 600 layout_filters ps_0 layout_filters ps_1 agent_version 1 stand_alone_serving true update_model_status_interval 10 enable_batching false 
tensorflow_session_parallelism 0 tensorflow_intra_op_parallelism 0 tensorflow_inter_op_parallelism 0 per_process_gpu_memory_fraction 0 num_load_threads 0 num_unload_threads 0 max_num_load_retries 5 load_retry_interval_micros 60 * 1000 * 1000 file_system_poll_wait_seconds 1 file_system_poll_wait_seconds_ps 0 flush_filesystem_caches true saved_model_tags none grpc_channel_arguments none grpc_max_threads 0 enable_model_warmup true enable_signature_method_name_check false xla_cpu_compilation_enabled false enable_profiler true num_shard 1 dc_aware 1 ================================================ FILE: monolith/agent_service/agent.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import app, flags, logging from concurrent.futures import ThreadPoolExecutor from enum import Enum from kazoo.client import KazooClient import os import copy import subprocess import signal from subprocess import CalledProcessError import threading import time from typing import List from multiprocessing import Process from monolith.agent_service.replica_manager import ReplicaManager from monolith.agent_service.agent_service import AgentService from monolith.agent_service.utils import AgentConfig, DeployType, check_port_open from monolith.native_training.zk_utils import MonolithKazooClient from monolith.native_training import env_utils from monolith.agent_service.agent_v1 import AgentV1 from monolith.agent_service.agent_v3 import AgentV3 from monolith.agent_service.model_manager import ModelManager FLAGS = flags.FLAGS flags.DEFINE_string('tfs_log', '/var/log/tfs.std.log', 'The tfs log file path') def run_agent(agent_config_path: str, tfs_log: str, use_mps: bool, replica_id: int, dense_service_index: int): if use_mps: os.environ["REPLICA_ID"] = str(replica_id) logging.info(f"[INFO] the corresponding replica_id {replica_id}") os.environ["DENSE_SERVICE_IDX"] = str(dense_service_index) tfs_log = "{}.mps{}".format(tfs_log, dense_service_index) config = AgentConfig.from_file(agent_config_path) conf_path = os.path.dirname(agent_config_path) if config.agent_version == 1: agent = AgentV1(config, conf_path, tfs_log) elif config.agent_version == 2: raise Exception('agent_version v2 is not support') elif config.agent_version == 3: agent = AgentV3(config, conf_path, tfs_log) else: raise Exception(f"agent_version error {config.agent_version}") # start model manager for rough sort model model_manager = ModelManager(config.rough_sort_model_name, config.rough_sort_model_p2p_path, config.rough_sort_model_local_path, True) ret = model_manager.start() if not ret: logging.error('model_manager start failed, kill self') os.kill(os.getpid(), signal.SIGKILL) agent.start() 
agent.wait_for_termination() def main(_): try: env_utils.setup_hdfs_env() except Exception as e: logging.error('setup_hdfs_env fail {}!'.format(e)) logging.info(f'environ is : {os.environ!r}') if FLAGS.conf is None: print(FLAGS.get_help()) return config = AgentConfig.from_file(FLAGS.conf) if config.deploy_type == DeployType.DENSE and config.dense_service_num > 1: p_list = [] for i in range(config.dense_service_num): cur_rid = config.replica_id * config.dense_service_num + i p = Process(target=run_agent, args=(FLAGS.conf, FLAGS.tfs_log, True, cur_rid, i)) p.start() p_list.append(p) for p in p_list: p.join() else: run_agent(FLAGS.conf, FLAGS.tfs_log, False, config.replica_id, 0) if __name__ == '__main__': app.run(main) ================================================ FILE: monolith/agent_service/agent_base.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os from abc import ABCMeta, abstractmethod from monolith.agent_service.utils import AgentConfig, TFS_HOME, TFSServerType TFS_BINARY = f'{TFS_HOME}/bin/tensorflow_model_server' PROXY_BINARY = f'{TFS_HOME}/bin/server' def get_cmd_path(): path = os.path.abspath(__file__) return path def get_cmd_and_port(config: AgentConfig, conf_path: str = None, server_type: str = None, config_file: str = None, tfs_binary: str = TFS_BINARY, proxy_binary: str = PROXY_BINARY): if server_type == TFSServerType.PS: return config.get_cmd_and_port(tfs_binary, server_type=TFSServerType.PS, config_file=config_file) elif server_type == TFSServerType.ENTRY: return config.get_cmd_and_port(tfs_binary, server_type=TFSServerType.ENTRY, config_file=config_file) elif server_type == TFSServerType.DENSE: return config.get_cmd_and_port(tfs_binary, server_type=TFSServerType.DENSE, config_file=config_file) else: proxy_conf = os.path.join(conf_path, 'proxy.conf') if os.path.exists(proxy_conf): cmd = f'{proxy_binary} --port={config.proxy_port} ' \ f'--grpc_target=localhost:{config.tfs_entry_port} --conf_file={proxy_conf} &' else: cmd = f'{proxy_binary} --port={config.proxy_port} ' \ f'--grpc_target=localhost:{config.tfs_entry_port} &' return cmd, config.proxy_port class ServingLog(object): def __init__(self, log_prefix: str, tfs_log: str): self._log_prefix = log_prefix self._tfs_log = tfs_log self._cwd = None self._log = None def __enter__(self): dirname = os.path.dirname(self._tfs_log) basename = os.path.basename(self._tfs_log) log_filename = os.path.join(dirname, f"{self._log_prefix}_{basename}") self._cwd = os.getcwd() os.chdir(f'{TFS_HOME}/bin') return open(log_filename, 'a') def __exit__(self, exc_type, exc_val, exc_tb): os.chdir(self._cwd) class AgentBase(metaclass=ABCMeta): def __init__(self, conf: AgentConfig): self.config = conf @abstractmethod def start(self): raise NotImplementedError("start is not implemented") @abstractmethod def wait_for_termination(self): raise 
NotImplementedError("wait_for_termination is not implemented") ================================================ FILE: monolith/agent_service/agent_client.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import app, flags, logging import grpc import socket import os, re from monolith.agent_service import utils from monolith.agent_service.agent_service_pb2_grpc import AgentServiceStub from monolith.agent_service.agent_service_pb2 import HeartBeatRequest, ServerType, \ GetReplicasRequest from monolith.agent_service.data_def import ModelMeta, ReplicaMeta, ModelState from monolith.agent_service.resource_utils import cal_model_info_v2 from monolith.agent_service.client import FLAGS from monolith.native_training import env_utils from monolith.native_training.zk_utils import MonolithKazooClient from kazoo.exceptions import NoNodeError flags.DEFINE_integer("port", 0, "agent_port") flags.DEFINE_enum("args", "addr", ["addr", "portal", "pub", "res", "lock", "elect", "info"], "args: addr, portal, pub, res, lock, elect") flags.DEFINE_enum("server_type", "ps", ["ps", "entry", "dense"], "server_type, ps or entry or dense") flags.DEFINE_integer("task", 0, "task id of given server_type") flags.DEFINE_string('model_dir', None, 'saved model dir') flags.DEFINE_string('ckpt', None, 'ckpt name') flags.DEFINE_integer('num_shard', -1, 'number of shard will use of current model') def main(_): 
env_utils.setup_hdfs_env() agent_conf = utils.AgentConfig.from_file(FLAGS.conf) if FLAGS.port != 0: agent_conf.agent_port = FLAGS.port host = os.environ.get("MY_HOST_IP", socket.gethostbyname(socket.gethostname())) channel = grpc.insecure_channel(f"{host}:{agent_conf.agent_port}") stub = AgentServiceStub(channel) model_name = agent_conf.base_name or FLAGS.model_name if FLAGS.server_type == "ps": server_type = ServerType.PS elif FLAGS.server_type == "dense": server_type = ServerType.DENSE else: server_type = ServerType.ENTRY if FLAGS.cmd_type == 'hb': request = HeartBeatRequest(server_type=server_type) addresses = stub.HeartBeat(request).addresses for k, v in addresses.items(): addrs = f"{v}".strip().split("\n") print("{k} -> ({length}) \n\t{addrs}".format(k=k, length=len(addrs), addrs="\n\t".join(addrs))) elif FLAGS.cmd_type == 'gr': assert model_name is not None request = GetReplicasRequest(server_type=server_type, task=FLAGS.task, model_name=model_name) print(ServerType.Name(server_type), FLAGS.task, " => ", stub.GetReplicas(request).address_list.address) elif (FLAGS.cmd_type == 'get' and FLAGS.args == 'addr') or FLAGS.cmd_type == 'addr': assert model_name is not None zk = MonolithKazooClient(hosts=agent_conf.zk_servers) zk.start() # bzid/service/model_name/idc:cluster/server_type:task/replica_id path_prefix = f'/{agent_conf.bzid}/service/{model_name}' servers = [] TASK = re.compile(r'^(\w+):(\d+)$') try: ics_or_svrs = zk.get_children(path_prefix) for ic_svr in ics_or_svrs: matched = TASK.match(ic_svr) if matched: svr = ic_svr servers.append(svr) else: ic = ic_svr svrs = zk.get_children(f'{path_prefix}/{ic}') if svrs: servers.extend([f'{ic}/{svr}' for svr in svrs]) except NoNodeError as e: print(f'{model_name} has not load !') zk.stop() return entry_id = 0 for_print = [] if servers: for server in servers: replicas = zk.get_children(f"{path_prefix}/{server}") if replicas: for replica in replicas: data, _ = zk.get(f"{path_prefix}/{server}/{replica}") data = 
ReplicaMeta.deserialize(data) replica_id = replica for_print.append( f"{path_prefix}/{server}/{replica_id}\tarchon_address: {data.archon_address}\t" f"address: {data.address}\tstate: {ModelState.Name(data.stat)}") for_print.sort() print("\n".join(for_print)) zk.stop() elif FLAGS.cmd_type == 'get' and FLAGS.args == 'info': print(cal_model_info_v2(FLAGS.model_dir, FLAGS.ckpt)) elif FLAGS.cmd_type == 'get': zk = MonolithKazooClient(hosts=agent_conf.zk_servers) zk.start() # /{bzid}/resource/{shard_id}:{replica_id} -> ResourceSpec if FLAGS.args == 'res': path_prefix = f'/{agent_conf.bzid}/resource' elif FLAGS.args == 'pub': path_prefix = f'/{agent_conf.bzid}/publish' elif FLAGS.args == 'portal': path_prefix = f'/{agent_conf.bzid}/portal' elif FLAGS.args == 'lock': path_prefix = f'/{agent_conf.bzid}/lock' elif FLAGS.args == 'elect': path_prefix = f'/{agent_conf.bzid}/election' else: return try: servers = zk.get_children(path_prefix) except NoNodeError as e: print(f'no {FLAGS.args} found !') zk.stop() return resources = {} if servers: for server in servers: data, _ = zk.get(f"{path_prefix}/{server}") resources[server] = data if resources: keys = list(resources.keys()) keys.sort() for key in keys: print(key, resources[key]) else: print(resources) zk.stop() elif FLAGS.cmd_type == 'load': assert model_name is not None zk = MonolithKazooClient(hosts=agent_conf.zk_servers) zk.start() mm = ModelMeta(model_name=model_name, model_dir=FLAGS.model_dir, ckpt=FLAGS.ckpt, num_shard=FLAGS.num_shard) path = f'/{agent_conf.bzid}/portal/{model_name}' try: zk.create(path, value=mm.serialize(), include_data=True, makepath=True) except Exception as e: logging.info(e) zk.set(path, value=mm.serialize()) zk.stop() elif FLAGS.cmd_type == 'unload': zk = MonolithKazooClient(hosts=agent_conf.zk_servers) zk.start() path = f'/{agent_conf.bzid}/portal/{model_name}' try: zk.delete(path) except Exception as e: logging.info(e) zk.stop() elif FLAGS.cmd_type == 'clean': zk = 
MonolithKazooClient(hosts=agent_conf.zk_servers) zk.start() if FLAGS.args == 'portal': path = f'/{agent_conf.bzid}/portal' for node in zk.get_children(path): zk.delete(os.path.join(path, node)) elif FLAGS.args == 'pub': path = f'/{agent_conf.bzid}/publish' for node in zk.get_children(path): zk.delete(os.path.join(path, node)) elif FLAGS.args == 'addr': path = f'/{agent_conf.bzid}/service' for node in zk.get_children(path): zk.delete(os.path.join(path, node), recursive=True) elif FLAGS.args == 'res': path = f'/{agent_conf.bzid}/resource' for node in zk.get_children(path): zk.delete(os.path.join(path, node), recursive=True) else: raise RuntimeError(f"{FLAGS.args} is not support!") zk.stop() if __name__ == "__main__": app.run(main) ================================================ FILE: monolith/agent_service/agent_controller.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import os import fnmatch from absl import app, flags, logging import tensorflow as tf from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.util import compat from monolith.native_training import env_utils from monolith.agent_service.backends import CtrlBackend, ZKBackend, SavedModel, SavedModelDeployConfig SUPPORTED_CMDS = "decl|pub|unpub|bzid_info" flags.DEFINE_string( "zk_servers", "", "zk connection string") flags.DEFINE_string("bzid", "test", "namespace") flags.DEFINE_string("export_base", "", "exported model base path") flags.DEFINE_integer("overwrite", 0, "overwrite existing saved_model configs") flags.DEFINE_string("model_name", "", "model_name") flags.DEFINE_string("layout", "", "layout base") flags.DEFINE_string("arch", "entry_ps", "serving architecture") flags.DEFINE_string("cmd", "bzid_info", SUPPORTED_CMDS) FLAGS = flags.FLAGS def find_model_name(exported_models_path: str): # find model name used in remote predict op entry_path = os.path.join(exported_models_path, 'entry') latest_timestamp = sorted(tf.io.gfile.listdir(entry_path))[0] sm_file = os.path.join(entry_path, latest_timestamp, "saved_model.pb") logging.info(f"loading: {sm_file}") with tf.io.gfile.GFile(sm_file, 'rb') as f: sm = saved_model_pb2.SavedModel() sm.ParseFromString(compat.as_bytes(f.read())) remote_predict_model_names = [ node.attr['model_name'].s.decode('utf-8') for node in sm.meta_graphs[0].graph_def.node if node.op == 'TfServingRemotePredict' ] if not remote_predict_model_names: return None else: return remote_predict_model_names[0].split(":")[0] def declare_saved_model(bd: CtrlBackend, export_base: str, model_name: str = None, overwrite=False, arch="entry_ps"): assert arch == "entry_ps", "only entry + ps architecture supported" model_name_from_export = find_model_name(export_base) if not model_name: model_name = model_name_from_export if model_name != model_name_from_export: logging.error( f"user model_name: {model_name}, exported_model_name: 
{model_name_from_export}" ) assert model_name is not None, "Model name is None" assert not bd.list_saved_models( model_name) or overwrite, f"{model_name} exists and not in overwrite mode" sub_graphs = tf.io.gfile.listdir(export_base) for sub_graph in sub_graphs: deploy_config = SavedModelDeployConfig( model_base_path=os.path.join(export_base, sub_graph), version_policy='latest' if sub_graph == 'entry' else 'latest_once') bd.decl_saved_model(SavedModel(model_name, sub_graph), deploy_config) logging.info( f"declare saved_model for {model_name} on path {export_base} success") return model_name def map_model_to_layout(bd: CtrlBackend, model_pattern: str, layout_path: str, action: str): model_name, sub_graph_pattern = model_pattern.split(":", 1) sub_graphs = [ saved_model.sub_graph for saved_model in bd.list_saved_models(model_name) ] matched_sub_graphs = fnmatch.filter(sub_graphs, sub_graph_pattern) for sub_graph in matched_sub_graphs: saved_model = SavedModel(model_name, sub_graph) if action == 'pub': logging.info(f"publishing {saved_model} to {layout_path}") bd.add_to_layout(layout_path, saved_model) elif action == 'unpub': logging.info(f"deleting {saved_model} from {layout_path}") bd.remove_from_layout(layout_path, saved_model) def bzid_info(bd: CtrlBackend): print(json.dumps(bd.bzid_info(), indent=2)) def main(_): if FLAGS.cmd not in SUPPORTED_CMDS.split("|"): raise ValueError( f"unsupported cmd {FLAGS.cmd}, options are {SUPPORTED_CMDS}") print() bd = ZKBackend(FLAGS.bzid, FLAGS.zk_servers) try: bd.start() if FLAGS.cmd == 'decl': assert FLAGS.export_base is not None and len(FLAGS.export_base) > 0 declare_saved_model(bd, FLAGS.export_base, overwrite=FLAGS.overwrite, arch=FLAGS.arch) elif FLAGS.cmd == 'pub' or FLAGS.cmd == 'unpub': assert len(FLAGS.layout) > 0 and len(FLAGS.model_name) > 0 layout_path = f"/{FLAGS.bzid}/layouts/{FLAGS.layout}" map_model_to_layout(bd, FLAGS.model_name, layout_path, action=FLAGS.cmd) elif FLAGS.cmd == 'bzid_info': bzid_info(bd) else: 
raise ValueError( f"unsupported cmd {FLAGS.cmd}, options are {SUPPORTED_CMDS}") finally: bd.stop() if __name__ == "__main__": try: env_utils.setup_hdfs_env() except Exception as e: logging.error('setup_hdfs_env fail {}!'.format(e)) logging.set_verbosity(logging.INFO) app.run(main) ================================================ FILE: monolith/agent_service/agent_controller_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import os import unittest from monolith.agent_service import agent_controller from monolith.agent_service import backends from monolith.agent_service.mocked_zkclient import FakeKazooClient def saved_model(sub_graph): return backends.SavedModel('test_ffm_model', sub_graph) class AgentControllerTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.bzid = 'gip' cls.bd = backends.ZKBackend(cls.bzid, zk_servers='127.0.0.1:9999') cls.zk = FakeKazooClient() cls.bd._zk = cls.zk cls.bd.start() @classmethod def tearDownClass(cls) -> None: cls.bd.stop() print('tearDownClass finished!') def test_decl_saved_models(self): agent_controller.declare_saved_model( self.bd, os.path.join( os.environ['TEST_SRCDIR'], os.environ["TEST_WORKSPACE"], "monolith/native_training/model_export/testdata/saved_model"), 'test_ffm_model', overwrite=True) saved_models = self.bd.list_saved_models('test_ffm_model') self.assertEqual( set(saved_models), { saved_model(sub_graph) for sub_graph in ['ps_0', 'ps_1', 'ps_2', 'ps_3', 'ps_4', 'entry'] }) def test_pub(self): self.maxDiff = None agent_controller.declare_saved_model( self.bd, os.path.join( os.environ['TEST_SRCDIR'], os.environ["TEST_WORKSPACE"], "monolith/native_training/model_export/testdata/saved_model"), 'test_ffm_model', overwrite=True) agent_controller.map_model_to_layout(self.bd, "test_ffm_model:entry", "/gip/layouts/test_layout1", action="pub") self.assertEqual(self.bd.bzid_info()['layout_info']['test_layout1'], ['test_ffm_model:entry']) agent_controller.map_model_to_layout(self.bd, "test_ffm_model:ps_*", "/gip/layouts/test_layout1", action="pub") self.assertEqual(self.bd.bzid_info()['layout_info']['test_layout1'], [ 'test_ffm_model:entry', 'test_ffm_model:ps_0', 'test_ffm_model:ps_1', 'test_ffm_model:ps_2', 'test_ffm_model:ps_3', 'test_ffm_model:ps_4' ]) agent_controller.map_model_to_layout(self.bd, "test_ffm_model:ps_*", "/gip/layouts/test_layout1", action="unpub") 
self.assertEqual(self.bd.bzid_info()['layout_info']['test_layout1'], ['test_ffm_model:entry']) agent_controller.map_model_to_layout(self.bd, "test_ffm_model:entry", "/gip/layouts/test_layout1", action="unpub") self.assertEqual(self.bd.bzid_info()['layout_info']['test_layout1'], []) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/agent_service/agent_service.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
syntax = "proto3";
option cc_enable_arenas = true;

package monolith.serving.agent_service;

// Kind of serving process a request refers to.
enum ServerType {
  PS = 0;
  ENTRY = 1;
  DENSE = 2;
}

// A list of replica addresses.
message AddressList {
  repeated string address = 1;
}

// Asks for the replicas of one task of a server type.
message GetReplicasRequest {
  ServerType server_type = 1;
  int32 task = 2;
  string model_name = 3;
}

message GetReplicasResponse {
  AddressList address_list = 1;
}

message GetResourceRequest {
}

message GetResourceResponse {
  string address = 1;
  int32 shard_id = 2;
  int32 replica_id = 3;
  int64 memory = 4;
  float cpu = 5;
  float network = 6;
  float work_load = 7;
}

message HeartBeatRequest {
  ServerType server_type = 1;
}

message HeartBeatResponse {
  // Replica addresses grouped under a string key.
  map<string, AddressList> addresses = 1;
}

service AgentService {
  rpc GetReplicas(GetReplicasRequest) returns (GetReplicasResponse) {}
  rpc GetResource(GetResourceRequest) returns (GetResourceResponse) {}
  rpc HeartBeat(HeartBeatRequest) returns (HeartBeatResponse) {}
}


================================================
FILE: monolith/agent_service/agent_service.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent import futures import logging import grpc from typing import List, Dict, Callable from dataclasses import dataclass from functools import singledispatchmethod from monolith.agent_service.utils import AgentConfig, get_local_ip from monolith.agent_service.resource_utils import cal_available_memory_v2 from monolith.agent_service.agent_service_pb2 import AddressList, GetReplicasRequest, GetReplicasResponse, \ GetResourceRequest, GetResourceResponse, HeartBeatRequest, HeartBeatResponse from monolith.agent_service.agent_service_pb2_grpc import AgentServiceServicer from monolith.agent_service.agent_service_pb2_grpc import add_AgentServiceServicer_to_server from monolith.agent_service.replica_manager import ReplicaWatcher from monolith.agent_service.zk_mirror import ZKMirror from monolith.agent_service.data_def import ReplicaMeta class AgentDataProvider: def __init__(self, addrs_fn: Callable[[], Dict[str, List[str]]]): self._addrs_fn = addrs_fn class AgentServiceImpl(AgentServiceServicer): @singledispatchmethod def __init__(self, arg): raise NotImplementedError('__init__ is not implemented!') @__init__.register def _(self, watcher: ReplicaWatcher, conf: AgentConfig = None): self._watcher: ReplicaWatcher = watcher self.conf = conf @__init__.register def _(self, zk: ZKMirror, conf: AgentConfig): self._zk: ZKMirror = zk self.conf = conf @__init__.register def _(self, data_provider: AgentDataProvider, conf: AgentConfig): self._data_provider = data_provider self.conf = conf def GetReplicas(self, request: GetReplicasRequest, context) -> GetReplicasResponse: response = GetReplicasResponse() if self.conf is None or self.conf.agent_version == 1: idc, cluster = self._watcher._conf.idc, self._watcher._conf.cluster address = self._watcher.get_replicas(request.server_type, request.task, idc, cluster) response.address_list.address.extend(address) elif self.conf.agent_version == 2: rms: List[ReplicaMeta] = self._zk.get_task_replicas( request.model_name, 
request.server_type, request.task) response.address_list.address.extend([rm.address for rm in rms]) else: raise NotImplementedError("not implement for agent v3") return response def HeartBeat(self, request: HeartBeatRequest, context) -> HeartBeatResponse: response = HeartBeatResponse() addresses = response.addresses if self.conf is None or self.conf.agent_version == 1: dc_aware = self._watcher._conf.dc_aware idc, cluster = self._watcher._conf.idc, self._watcher._conf.cluster addrs = self._watcher.get_all_replicas(request.server_type, idc, cluster) for key, values in addrs.items(): key = key.split('/')[-1] if dc_aware else key addr_list = AddressList() addr_list.address.extend(values) addresses[key].CopyFrom(addr_list) elif self.conf.agent_version == 2: rm_dict: Dict[str, List[ReplicaMeta]] = self._zk.get_all_replicas( request.server_type) for key, rms in rm_dict.items(): addr_list = AddressList() addr_list.address.extend([rm.address for rm in rms]) addresses[key].CopyFrom(addr_list) else: addrs_map = self._data_provider._addrs_fn() if addrs_map: for saved_model_name, addrs in addrs_map.items(): addr_list = AddressList() addr_list.address.extend(addrs) addresses[saved_model_name].CopyFrom(addr_list) logging.info(f"heartbeat response({request.server_type}): {response}") return response def GetResource(self, request: GetResourceRequest, context) -> GetResourceResponse: if self.conf is None or self.conf.agent_version == 1: return GetResourceResponse() else: return GetResourceResponse( address=f'{get_local_ip()}:{self.conf.agent_port}', shard_id=int(self.conf.shard_id), replica_id=int(self.conf.replica_id), memory=cal_available_memory_v2(), cpu=-1.0, network=-1.0, work_load=-1.0) class AgentService: @singledispatchmethod def __init__(self, arg): raise NotImplementedError('__init__ is not implemented!') @__init__.register def _(self, watcher: ReplicaWatcher, port: int = None, max_workers: int = 10): self._server = grpc.server( 
futures.ThreadPoolExecutor(max_workers=max_workers)) add_AgentServiceServicer_to_server(AgentServiceImpl(watcher), self._server) self._server.add_insecure_port(f'[::]:{port or 0}') @__init__.register def _(self, zk: ZKMirror, conf: AgentConfig, max_workers: int = 10): self._server = grpc.server( futures.ThreadPoolExecutor(max_workers=max_workers)) add_AgentServiceServicer_to_server(AgentServiceImpl(zk, conf), self._server) self._server.add_insecure_port(f'[::]:{conf.agent_port or 0}') @__init__.register def _(self, data_provider: AgentDataProvider, conf: AgentConfig, max_workers: int = 10): self._server = grpc.server( futures.ThreadPoolExecutor(max_workers=max_workers)) add_AgentServiceServicer_to_server(AgentServiceImpl(data_provider, conf), self._server) self._server.add_insecure_port(f'[::]:{conf.agent_port or 0}') def start(self): self._server.start() def wait_for_termination(self): self._server.wait_for_termination() def stop(self, grace=None): self._server.stop(grace=grace) ================================================ FILE: monolith/agent_service/agent_service_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging import grpc from kazoo.exceptions import NoNodeError, NodeExistsError import os import socket import unittest from monolith.agent_service import utils from monolith.agent_service.agent_service import AgentService from monolith.agent_service.agent_service_pb2 import HeartBeatRequest, ServerType, \ GetReplicasRequest from monolith.agent_service.agent_service_pb2_grpc import AgentServiceStub from monolith.agent_service.mocked_zkclient import FakeKazooClient from monolith.agent_service.replica_manager import ReplicaWatcher, ReplicaMeta, ModelState from monolith.agent_service.svr_client import SvrClient MODEL_NAME = 'test_model_ctr' BASE_PATH = f'/test_model/{MODEL_NAME}/saved_models' NUM_PS_REPLICAS = 2 NUM_ENTRY_REPLICAS = 2 class AgentServiceTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.environ['TCE_INTERNAL_IDC'] = 'lf' os.environ['TCE_LOGICAL_CLUSTER'] = 'default' cls.zk = FakeKazooClient() cls.zk.start() cls.agent_conf: utils.AgentConfig = utils.AgentConfig(bzid='test_model', base_name=MODEL_NAME, deploy_type='ps', base_path=BASE_PATH, num_ps=20, dc_aware=True) cls.watcher = ReplicaWatcher(cls.zk, cls.agent_conf) cls.register(cls.zk) cls.watcher.watch_data() cls.agent = AgentService(cls.watcher, port=cls.agent_conf.agent_port) cls.agent.start() cls.client = SvrClient(cls.agent_conf) @classmethod def tearDownClass(cls) -> None: cls.agent.stop() cls.watcher.stop() @classmethod def register(cls, zk): path_prefix = cls.agent_conf.path_prefix path_to_meta, idx = {}, 2 for task_id in range(cls.agent_conf.num_ps): for replica_id in range(NUM_PS_REPLICAS): meta = ReplicaMeta(address=f'192.168.1.{idx}:{utils.find_free_port()}', stat=ModelState.AVAILABLE) replica_path = f'{path_prefix}/ps:{task_id}/{replica_id}' print(replica_path, flush=True) path_to_meta[replica_path] = meta idx += 1 for replica_id in range(NUM_ENTRY_REPLICAS): replica_path = f'{path_prefix}/entry:0/{replica_id}' meta = 
ReplicaMeta(address=f'192.168.1.{idx}:{utils.find_free_port()}', stat=ModelState.AVAILABLE) path_to_meta[replica_path] = meta idx += 1 for replica_path, meta in path_to_meta.items(): replica_meta_bytes = bytes(meta.to_json(), encoding='utf-8') try: zk.retry(zk.create, path=replica_path, value=replica_meta_bytes, ephemeral=True, makepath=True) except NodeExistsError: logging.info(f'{replica_path} has already exists') zk.retry(zk.set, path=replica_path, value=replica_meta_bytes) def test_heart_beat(self): resp = self.client.heart_beat(server_type=ServerType.PS) self.assertTrue(len(resp.addresses) == 20) def test_get_replicas(self): resp = self.client.get_replicas(server_type=ServerType.PS, task=NUM_PS_REPLICAS - 1) self.assertTrue(len(resp.address_list.address) == NUM_PS_REPLICAS) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/agent_service/agent_v1.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import app, flags, logging from concurrent.futures import ThreadPoolExecutor from enum import Enum from kazoo.client import KazooClient import os import subprocess import signal from subprocess import CalledProcessError import threading import time from typing import List from monolith.agent_service.replica_manager import ReplicaManager from monolith.agent_service.agent_service import AgentService from monolith.agent_service.utils import AgentConfig, TFSServerType, DeployType, check_port_open from monolith.agent_service.agent_base import AgentBase, ServingLog, TFS_HOME, get_cmd_and_port, TFS_BINARY from monolith.native_training.zk_utils import MonolithKazooClient class ProcessType(Enum): PS = 1 ENTRY = 2 PROXY = 3 UNKONWN = 4 DENSE = 5 class ProcessNode(object): def __init__(self, config: AgentConfig, replica_mgr: ReplicaManager, proc_type: ProcessType, is_tce_main: bool = False, conf_path: str = None, tfs_log: str = None, tfs_binary: str = TFS_BINARY): assert proc_type != ProcessType.UNKONWN self._config = config self._replica_mgr = replica_mgr self._shell = False self._stderr = subprocess.STDOUT self._env = os.environ self.proc_type = proc_type self.is_tce_main = is_tce_main self._tfs_log = tfs_log self._is_failover = False self._port = 0 if proc_type == ProcessType.PS: self._cmd, self._port = get_cmd_and_port(config, conf_path, TFSServerType.PS, tfs_binary=tfs_binary) elif proc_type == ProcessType.ENTRY: self._env = os.environ.copy() self._env["PORT2"] = str(self._config.agent_port) self._cmd, self._port = get_cmd_and_port(config, conf_path, TFSServerType.ENTRY, tfs_binary=tfs_binary) elif proc_type == ProcessType.DENSE: self._cmd, self._port = get_cmd_and_port(config, conf_path, TFSServerType.DENSE, tfs_binary=tfs_binary) else: self._cmd, self._port = get_cmd_and_port(config, conf_path, tfs_binary=tfs_binary) self._popen = None self._sub_procs = {} @property def sub_procs(self): return self._sub_procs def add_subproc(self, pn: 'ProcessNode'): if 
pn.proc_type in self._sub_procs: logging.warning(f'process {pn.proc_type} exists!') else: self._sub_procs[pn.proc_type] = pn @property def returncode(self): if self._is_failover: return None else: return None if self._popen is None else self._popen.returncode def poll(self): if self._is_failover: return None else: return None if self._popen is None else self._popen.poll() def kill(self): # kill subprocess for proc in self._sub_procs.values(): if proc is None: continue cnt, max_cnt = 0, 3 proc.poll() while proc.returncode is None and cnt < max_cnt: logging.info(f"kill proc {proc}") proc.kill() time.sleep(1) proc.poll() cnt += 1 # kill self if self._popen is not None: cnt, max_cnt = 0, 3 self._popen.poll() while self._popen.returncode is None and cnt < max_cnt: logging.info(f"kill proc {self._popen}") self._popen.kill() time.sleep(1) self._popen.poll() cnt += 1 def run(self): waiting_sec, max_waiting_sec = 0, 3600 if self.proc_type == ProcessType.ENTRY: # waiting for PS status change time.sleep(self._config.update_model_status_interval * 2) waiting_sec += self._config.update_model_status_interval * 2 # check at least one replica of all PSs are stared while not self._replica_mgr.is_ps_set_started( ) and waiting_sec < max_waiting_sec: time.sleep(self._config.update_model_status_interval * 2) waiting_sec += self._config.update_model_status_interval * 2 if waiting_sec >= max_waiting_sec: logging.error("found PS timeout") return False # check at least one replica of Dense are stared if self._config.dense_alone: while not self._replica_mgr.is_dense_set_started( ) and waiting_sec < max_waiting_sec: time.sleep(self._config.update_model_status_interval * 2) waiting_sec += self._config.update_model_status_interval * 2 if waiting_sec >= max_waiting_sec: logging.error("found Dense timeout") return False # strat self with ServingLog(self.proc_type.name.lower(), self._tfs_log) as log_stdout: # Popen will return self._popen = subprocess.Popen(self._cmd.split(), shell=self._shell, 
stderr=self._stderr, stdout=log_stdout if "MLP_POD_NAME" not in os.environ else None, env=self._env) logging.info(f'pid of <{self._cmd}> is {self._popen.pid}') if not self.wait_for_started(): logging.error(f"start {self.proc_type} failed") return False # start subprocess for proc in self._sub_procs.values(): if not proc.run(): logging.error(f"start {proc} failed") return False return True def failover(self): self._is_failover = True if not self.is_tce_main and (self.proc_type == ProcessType.PS or self.proc_type == ProcessType.DENSE): logging.info(f"failover {self.proc_type}, run") self.run() else: logging.info(f"failover {self.proc_type}, kill") self.kill() self._is_failover = False def wait_for_started(self): if self._port == 0: return True waiting_sec, max_waiting_sec = 0, 3600 while waiting_sec <= max_waiting_sec: ret = check_port_open(self._port) if ret: logging.info(f"proc {self.proc_type} opened!") return True logging.info(f"proc {self.proc_type} not open!") time.sleep(10) waiting_sec += 10 logging.info(f"proc {self.proc_type} start failed!") return False def get_proc(node: ProcessNode, res: List[ProcessNode]): res.append(node) for proc in node.sub_procs.values(): if proc is not None: get_proc(proc, res) class ProcessMgr(object): _is_killed = False _lock = threading.RLock() def __init__(self): self.pid = os.getpid() self._sub_procs: List[ProcessNode] = [] signal.signal(signal.SIGTERM, self.signal_handler) signal.signal(signal.SIGINT, self.signal_handler) self._thread_stopped = False self._thread = threading.Thread(target=self._poll) self._pool = ThreadPoolExecutor(max_workers=2) def add_subproc(self, proc: ProcessNode): self._sub_procs.append(proc) def signal_handler(self, signum, frame): def target(): with self._lock: if not ProcessMgr._is_killed: self._thread_stopped = True ProcessMgr._is_killed = True if signum in {signal.SIGINT, signal.SIGTERM}: logging.info(f"catch signal {signum}, kill all") self.kill_all() return True else: raise RuntimeError(f"unknown 
signal {signum} at {frame}") else: return True future = self._pool.submit(target) future.result() def _poll(self): procs = [] for proc in self._sub_procs: get_proc(proc, procs) logging.info(f"the number of procs is {len(procs)} ") while not self._thread_stopped and not ProcessMgr._is_killed: time.sleep(1) for proc in procs: proc.poll() if ProcessMgr._is_killed or self._thread_stopped: break if proc.returncode is not None: logging.info( f"{proc.proc_type} {proc.returncode} shutdown, kill all proc...") #proc.failover() #先简化管理, 有进程挂掉的话就整体挂了 self.kill_all() def start(self): for proc in self._sub_procs: if not proc.run(): logging.error(f"start {proc} failed, kill all proc") self.kill_all() logging.info('start poll thread') self._thread.start() def kill_all(self, include_self=True): with self._lock: ProcessMgr._is_killed = True for proc in self._sub_procs: logging.info(f"kill proc {proc.proc_type}") # [todo] add graceful shutdown later proc.kill() if include_self: logging.info("kill self") os.kill(os.getpid(), signal.SIGKILL) class AgentV1(AgentBase): def __init__(self, config: AgentConfig, conf_path: str, tfs_log: str, tfs_binary: str = TFS_BINARY): super(AgentV1, self).__init__(config) self.zk = MonolithKazooClient(hosts=config.zk_servers) self.replica_mgr = ReplicaManager(self.zk, config) self.agent_service = AgentService(self.replica_mgr.watcher, port=config.agent_port) pm = ProcessMgr() if config.deploy_type == DeployType.MIXED: ps_proc = ProcessNode(config, self.replica_mgr, proc_type=ProcessType.PS, conf_path=conf_path, tfs_log=tfs_log, tfs_binary=tfs_binary) pm.add_subproc(ps_proc) if config.dense_alone: dense_proc = ProcessNode(config, self.replica_mgr, proc_type=ProcessType.DENSE, conf_path=conf_path, tfs_log=tfs_log, tfs_binary=tfs_binary) pm.add_subproc(dense_proc) entry_proc = ProcessNode(config, self.replica_mgr, proc_type=ProcessType.ENTRY, is_tce_main=True, conf_path=conf_path, tfs_log=tfs_log, tfs_binary=tfs_binary) pm.add_subproc(entry_proc) elif 
config.deploy_type == DeployType.ENTRY: proxy_proc = ProcessNode(config, self.replica_mgr, proc_type=ProcessType.ENTRY, is_tce_main=True, conf_path=conf_path, tfs_log=tfs_log, tfs_binary=tfs_binary) pm.add_subproc(proxy_proc) elif config.deploy_type == DeployType.PS: ps_proc = ProcessNode(config, self.replica_mgr, proc_type=ProcessType.PS, is_tce_main=True, conf_path=conf_path, tfs_log=tfs_log, tfs_binary=tfs_binary) pm.add_subproc(ps_proc) else: dense_proc = ProcessNode(config, self.replica_mgr, proc_type=ProcessType.DENSE, is_tce_main=True, conf_path=conf_path, tfs_log=tfs_log, tfs_binary=tfs_binary) pm.add_subproc(dense_proc) self.process_mgr = pm def start(self): self.zk.start() logging.info(f'start kazoo finished!') self.replica_mgr.start() logging.info(f'start replica_mgr finished!') self.agent_service.start() logging.info(f'start agent service at localhost:{self.config.agent_port}') self.process_mgr.start() logging.info(f'start ProcessMgr finished!') def wait_for_termination(self): self.agent_service.wait_for_termination() def stop(self): self.process_mgr.kill_all(include_self=False) logging.info(f'close ProcessMgr finished!') self.agent_service.stop() logging.info(f'close agent service at localhost:{self.config.agent_port}') self.replica_mgr.stop() logging.info(f'close replica_mgr finished!') self.zk.stop() logging.info(f'close kazoo finished!') ================================================ FILE: monolith/agent_service/agent_v3.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict from functools import partial from absl import app, flags, logging import json import os import re import time import signal import threading import tempfile from threading import RLock from typing import Callable, Dict, List, Tuple from subprocess import Popen, STDOUT from google.protobuf import text_format, json_format from monolith.agent_service.utils import AgentConfig, DeployType, gen_model_config, get_local_ip, \ normalize_regex from monolith.agent_service.tfs_wrapper import TFSWrapper, State from monolith.agent_service.agent_service import AgentService, AgentDataProvider from monolith.agent_service.agent_base import AgentBase from monolith.agent_service.backends import ZKBackend, Container, SavedModel, \ ContainerServiceInfo, SavedModelDeployConfig from tensorflow_serving.config import model_server_config_pb2 def gen_empty_model_config_file(): tmp_file = tempfile.mktemp() with open(tmp_file, "w") as f: f.write("model_config_list {}") return tmp_file class AgentV3(AgentBase): _lock = RLock() def __init__(self, config: AgentConfig, conf_path: str, tfs_log: str): super(AgentV3, self).__init__(config) assert config.deploy_type == DeployType.UNIFIED, "agent v3 only supports unifed deploy_type" assert config.agent_version == 3, f"agent version {config.agent_version} unexpected" self._conf_path = conf_path self._tfs_log = tfs_log self._exit_event = threading.Event() signal.signal(signal.SIGTERM, self.signal_handler) signal.signal(signal.SIGINT, self.signal_handler) self._model_config_path = gen_empty_model_config_file() 
self._tfs_wrapper = TFSWrapper(config.tfs_port_archon, config.tfs_port_grpc, config.tfs_port_http, self._model_config_path, config, self._tfs_log) self._layout_filters = [] if config.layout_filters: shard_id = max(config.shard_id, 0) shard_num = max(config.num_shard, 1) for raw_filter in config.layout_filters: for k, v in [('${shard_id}', shard_id), ('${shard_num}', shard_num)]: raw_filter = raw_filter.replace(k, str(v)) match, cond = raw_filter.split(";", 1) self._layout_filters.append((normalize_regex(match), cond)) self._container = Container(self.config.container_cluster, self.config.container_id) local_ip = get_local_ip() self._service_info = ContainerServiceInfo( grpc=f"{local_ip}:{self.config.tfs_port_grpc}", http=f"{local_ip}:{self.config.tfs_port_http}", archon=f"{local_ip}:{self.config.tfs_port_archon}", agent=f"{local_ip}:{self.config.agent_port}", idc=self.config.idc, debug_info=json.dumps({ 'layout_path': config.layout_path, 'layout_filters': [ f"{match};{cond}" for match, cond in self._layout_filters ] })) self._backend = ZKBackend(bzid=config.bzid, zk_servers=config.zk_servers) self._threads = [] self._agent_service = AgentService( AgentDataProvider(addrs_fn=self._gen_addrs_map), conf=config) def _gen_addrs_map(self): service_map = self._backend.get_service_map() addrs_map = {} for model_name in service_map: for sub_graph, service_infos in service_map[model_name].items(): addrs_map[f"{model_name}:{sub_graph}"] = [ service_info.grpc if self._tfs_wrapper.is_grpc_remote_op else service_info.archon for service_info in service_infos ] return addrs_map def sync_available_saved_models(self): saved_model_status = self._tfs_wrapper.list_saved_models_status() available_saved_models = set() for saved_model_name, status in saved_model_status.items(): if status.state == State.AVAILABLE: model_name, sub_graph = saved_model_name.split(":")[:2] available_saved_models.add(SavedModel(model_name, sub_graph)) self._backend.sync_available_saved_models(self._container, 
available_saved_models) def layout_update_callback( self, saved_models: List[Tuple[SavedModel, SavedModelDeployConfig]]) -> bool: logging.info(f"layout callback with saved_models: {saved_models}") model_server_config = model_server_config_pb2.ModelServerConfig() model_server_config.model_config_list.SetInParent() model_config_list = model_server_config.model_config_list.config for saved_model, deploy_config in saved_models: accepted = len(self._layout_filters) == 0 for match, cond in self._layout_filters: m = re.match(match, saved_model.sub_graph) if m: if eval(cond, None, {k: int(v) for k, v in m.groupdict().items()}): accepted = True logging.info(f"loading {str(saved_model)} with rule {match}:{cond}") break if not accepted: continue tfs_model_config = gen_model_config( name=str(saved_model), base_path=deploy_config.model_base_path, version_policy=deploy_config.version_policy) model_config_list.add().CopyFrom(tfs_model_config) logging.info( f"writing model server_config: {text_format.MessageToString(model_server_config)}" ) with open(self._model_config_path, 'w') as f: f.write(text_format.MessageToString(model_server_config)) return True def signal_handler(self, signum, frame): logging.info(f"catch signal {signum}, frame {frame}") self._exit_event.set() def start_bg_thread(self, fn, interval=10): def target(): while not self._exit_event.is_set(): try: fn() except Exception as e: logging.error(f"error in bg thread: {e}") time.sleep(interval) bg_thread = threading.Thread(target=target) bg_thread.start() self._threads.append(bg_thread) def start(self): self._tfs_wrapper.start() self._backend.start() self._agent_service.start() self.start_bg_thread(partial(self._backend.report_service_info, self._container, self._service_info), interval=60) self.start_bg_thread(self.sync_available_saved_models, interval=30) self._backend.register_layout_callback(self.config.layout_path, self.layout_update_callback) def stop(self): self._exit_event.set() try: for t in self._threads: 
t.join() self._agent_service.stop() self._backend.stop() self._tfs_wrapper.stop() except Exception as e: logging.warning(e) def wait_for_termination(self): while not self._exit_event.is_set(): time.sleep(1) ret_code = self._tfs_wrapper.poll() if ret_code is not None: self._exit_event.set() self.stop() time.sleep(1) os.kill(os.getpid(), signal.SIGKILL) ================================================ FILE: monolith/agent_service/agent_v3_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import time import unittest import os from absl import logging from monolith.agent_service import utils from monolith.agent_service import backends from monolith.agent_service.agent_v3 import AgentV3 from monolith.agent_service.tfs_wrapper import FakeTFSWrapper from monolith.agent_service.mocked_zkclient import FakeKazooClient class AgentV3Test(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.bzid = 'gip' os.environ['MY_HOST_IP'] = '127.0.0.1' agent_conf = utils.AgentConfig(bzid='gip', deploy_type='unified', agent_version=3, layout_pattern="/gip/layout", zk_servers="127.0.0.1:8888") base_path = os.environ["TEST_TMPDIR"] cls.agent = AgentV3(config=agent_conf, conf_path=os.path.join(base_path, '/monolith_serving/conf'), tfs_log=os.path.join('monolith_serving/logs/log.log')) model_config_path = cls.agent._model_config_path # replace tfs wrapper and zk cls.tfs_wrapper = FakeTFSWrapper(model_config_path) cls.agent._tfs_wrapper = cls.tfs_wrapper cls.zk = FakeKazooClient() cls.backend = cls.agent._backend cls.backend._zk = cls.zk cls.agent.start() logging.info('setUpClass finished!') def base_path(sub_graph): return os.path.join(os.environ["TEST_TMPDIR"], "test_ffm_model/exported_models", sub_graph) for sub_graph in ['entry', 'ps_0', 'ps_1', 'ps_2']: config = { 'model_base_path': base_path(sub_graph), 'version_policy': 'latest' } path = f'/gip/saved_models/test_ffm_model/{sub_graph}' value = json.dumps(config).encode('utf-8') cls.zk.create(path, value=value, makepath=True) @classmethod def tearDownClass(cls) -> None: cls.agent.stop() logging.info('tearDownClass finished!') def test_service_info(self): self.assertEqual(self.agent._service_info, self.backend.get_service_info(self.agent._container)) def test_publish_models(self): self.assertEqual(self.tfs_wrapper.list_saved_models(), []) # publish self.zk.ensure_path("/gip/layout/test_ffm_model:entry") self.zk.ensure_path("/gip/layout/test_ffm_model:ps_0") # check tfs serving 
self.assertEqual(self.tfs_wrapper.list_saved_models(), ['test_ffm_model:entry', 'test_ffm_model:ps_0']) # force binding info to propagate self.agent.sync_available_saved_models() self.assertEqual( self.backend.get_service_map(), { 'test_ffm_model': { 'entry': [self.agent._service_info], 'ps_0': [self.agent._service_info] } }) # unload one model self.zk.delete("/gip/layout/test_ffm_model:ps_0") # check tfs serving self.assertEqual(self.tfs_wrapper.list_saved_models(), ['test_ffm_model:entry']) # force binding info to propagate self.agent.sync_available_saved_models() self.assertEqual(self.backend.get_service_map(), {'test_ffm_model': { 'entry': [self.agent._service_info] }}) if __name__ == "__main__": logging.use_absl_handler() logging.get_absl_handler().setFormatter(fmt=logging.PythonFormatter()) logging.set_verbosity(logging.INFO) unittest.main() ================================================ FILE: monolith/agent_service/backends.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import abc from collections import defaultdict from dataclasses import dataclass from functools import partial from threading import RLock, Event from absl import logging from dataclasses_json import dataclass_json from typing import Callable, Dict, List, Tuple, Any, Set, Union from kazoo.client import KazooState from kazoo.recipe.watchers import ChildrenWatch from kazoo.exceptions import NodeExistsError, ZookeeperError, \ NoNodeError, NotEmptyError, ConnectionClosedError from monolith.native_training.zk_utils import MonolithKazooClient @dataclass(frozen=True) class SavedModel: model_name: str = None sub_graph: str = None def __repr__(self): return str(self) def __str__(self): return f"{self.model_name}:{self.sub_graph}" @dataclass_json @dataclass(frozen=True) class SavedModelDeployConfig: model_base_path: str = None version_policy: str = None def serialize(self) -> bytes: return bytes(self.to_json(), encoding='utf-8') @classmethod def deserialize(cls, serialized: bytes) -> 'SavedModelDeployConfig': return cls.from_json(str(serialized, encoding='utf-8')) @dataclass(frozen=True) class Container: ctx_cluster: str = None ctx_id: str = None def __repr__(self): return str(self) def __str__(self): return f"{self.ctx_cluster}:{self.ctx_id}" @dataclass_json @dataclass(frozen=True) class ContainerServiceInfo: grpc: str = None # grpc ip:port http: str = None # http ip:port archon: str = None # archon ip:port agent: str = None # agent ip:port idc: str = None # dc name debug_info: str = None def serialize(self) -> bytes: return bytes(self.to_json(), encoding='utf-8') @classmethod def deserialize(cls, serialized: bytes) -> 'ContainerServiceInfo': return cls.from_json(str(serialized, encoding='utf-8')) class AgentBackend(abc.ABC): def __init__(self): pass @abc.abstractmethod def register_layout_callback( self, layout_path: str, callback: Callable[[List[Tuple[SavedModel, SavedModelDeployConfig]]], None] ) -> None: """ Invoke {callback} on layout updates(adding or removing 
saved_models) """ pass @abc.abstractmethod def sync_available_saved_models(self, saved_models: List[SavedModel]) -> None: """ Report available saved models serving in localhost """ pass @abc.abstractmethod def report_service_info(self, container: Container, service_info: ContainerServiceInfo) -> None: pass @abc.abstractmethod def get_service_map(self) -> Dict[str, Dict[str, List[ContainerServiceInfo]]]: """ Get service info map { "model_name": { "sub_graph: [ { "idc": "LQ", "archon": "10.xx.xx.1:9876", "grpc": "10.xx.xx.1:8765", "http": "10.xx.xx.1:6789", "agent": "10.xx.xx.1:6787" } ] } } """ pass @abc.abstractmethod def report_service_info(self, container: Container, service_info: ContainerServiceInfo) -> None: pass @abc.abstractmethod def get_service_info(self, container) -> ContainerServiceInfo: pass @abc.abstractmethod def start(self): pass @abc.abstractmethod def stop(self): pass class CtrlBackend(abc.ABC): def __init__(self): pass @abc.abstractmethod def list_saved_models(self, model_name: str) -> List[SavedModel]: pass @abc.abstractmethod def decl_saved_model(self, saved_model: SavedModel, deploy_config: SavedModelDeployConfig): pass @abc.abstractmethod def add_to_layout(self, layout: str, saved_model: SavedModel): pass @abc.abstractmethod def remove_from_layout(self, layout: str, saved_model: SavedModel): pass @abc.abstractmethod def bzid_info(self): pass @abc.abstractmethod def start(self): pass @abc.abstractmethod def stop(self): pass class SyncBackend(abc.ABC): def __init__(self): pass @abc.abstractmethod def subscribe_model(self, model_name: str): pass @abc.abstractmethod def get_sync_targets( self, sub_graph: str) -> Tuple[str, Union[List[str], Dict]]: pass @abc.abstractmethod def start(self): pass @abc.abstractmethod def stop(self): pass class ZKBackend(AgentBackend, CtrlBackend, SyncBackend): _lock = RLock() def __init__(self, bzid: str, zk_servers: str): super(ZKBackend, self).__init__() self._bzid = bzid self._zk = 
MonolithKazooClient(hosts=zk_servers) self._available_saved_model = set() self._service_info_map = {} self._children_watcher_map: Dict[str, ChildrenWatch] = {} self._sync_model_name = None self._is_lost = Event() def zk_listener(state): if state == KazooState.LOST: logging.error("zk state lost, set lost flag") self._is_lost.set() else: logging.warning(f"zk state changed to {state}, unset lost flag") self._is_lost.clear() return False self._zk.add_listener(zk_listener) def sync_available_saved_models(self, container: Container, saved_models: Set[SavedModel]) -> None: """ Report available saved models serving in tensorflow serving """ with self._lock: if self._is_lost.is_set(): self._available_saved_model.clear() logging.warning("zk is lost, try restarting") self._zk.restart() return add_saved_models = saved_models - self._available_saved_model remove_saved_models = self._available_saved_model - saved_models logging.info( f"available saved models updating, add: {add_saved_models}, remove: {remove_saved_models}" ) for saved_model in add_saved_models: bind_path = f"/{self._bzid}/binding/{saved_model.model_name}/{saved_model.sub_graph}:{container}" self.create_znode(bind_path, b"", ephemeral=True, makepath=True) for saved_model in remove_saved_models: bind_path = f"/{self._bzid}/binding/{saved_model.model_name}/{saved_model.sub_graph}:{container}" self.delete_znode(bind_path) logging.info(f"available saved models updated: {saved_models}") self._available_saved_model = saved_models def register_layout_callback( self, layout_path: str, callback: Callable[[List[Tuple[SavedModel, SavedModelDeployConfig]]], bool]): """ Invoke {callback} on layout updates(adding or removing saved_models) """ def callback_wrap(children: List[str]): with self._lock: logging.info(f"layout updated: {children}") model_names = set() saved_models = [] for child in children: model_name, sub_graph = child.split(":")[:2] saved_model = SavedModel(model_name, sub_graph) fetch_path = 
f"/{self._bzid}/saved_models/{model_name}/{sub_graph}" data = self.get_znode(fetch_path) if data is None: logging.error("missing deploy config for saved model") continue saved_models.append( (saved_model, SavedModelDeployConfig.deserialize(data))) model_names.add(model_name) self._service_info_map = { model_name: self._service_info_map.get(model_name, {}) for model_name in model_names } for model_name in model_names: binding_watch_path = f"/{self._bzid}/binding/{model_name}" self._children_watch(binding_watch_path, partial(self._bind_callback, model_name)) ret = callback(saved_models) return ret with self._lock: self._zk.ensure_path(layout_path) self._children_watch(layout_path, callback_wrap) def get_service_map(self) -> Dict[str, Dict[str, List[ContainerServiceInfo]]]: """ Get service info map { "model_name": { "sub_graph: [ { "idc": "LQ", "archon": "10.xx.xx.1:9876", "grpc": "10.xx.xx.1:8765", "http": "10.xx.xx.1:6789", "agent": "10.xx.xx.1:6787" } ] } } """ return self._service_info_map.copy() def _bind_callback(self, model_name, children): with self._lock: if model_name not in self._service_info_map: logging.info(f"model {model_name} no longer subscribed.") return False new_binding = defaultdict(list) for child in children: sub_graph, ctx_cluster, ctx_id = child.split(":")[:3] saved_model = SavedModel(model_name, sub_graph) container = Container(ctx_cluster, ctx_id) service_info = self.get_service_info(container) if service_info is None: logging.error(f"no serivice info of {child}") continue new_binding[sub_graph].append(service_info) self._service_info_map[model_name] = new_binding def report_service_info(self, container: Container, service_info: ContainerServiceInfo) -> None: service_info_path = f"/{self._bzid}/container_service/{container}" self.create_znode(service_info_path, service_info.serialize(), ephemeral=True, makepath=True) def get_service_info(self, container) -> ContainerServiceInfo: service_info_path = 
f"/{self._bzid}/container_service/{container}" data = self.get_znode(service_info_path) if data is None: return None else: return ContainerServiceInfo.deserialize(data) def _children_watch(self, path, callback): with self._lock: if path in self._children_watcher_map and not self._children_watcher_map[ path]._stopped: logging.info(f"active watcher exists on path {path}") else: self._zk.ensure_path( path ) # make sure the path exists otherwise the watcher may not be effective self._children_watcher_map[path] = self._zk.ChildrenWatch( path, callback) logging.info(f"registered new watcher on {path}") def list_saved_models(self, model_name: str) -> List[SavedModel]: model_path = f"/{self._bzid}/saved_models/{model_name}" try: sub_graphs = self._zk.get_children(model_path) return [SavedModel(model_name, sub_graph) for sub_graph in sub_graphs] except NoNodeError: return [] def decl_saved_model(self, saved_model: SavedModel, deploy_config: SavedModelDeployConfig): saved_model_path = f"/{self._bzid}/saved_models/{saved_model.model_name}/{saved_model.sub_graph}" logging.info(f"publishing {saved_model} -> {deploy_config}") self.create_znode(saved_model_path, deploy_config.serialize(), makepath=True) def add_to_layout(self, layout: str, saved_model: SavedModel): path = f"{layout}/{saved_model}" self._zk.ensure_path(path) def remove_from_layout(self, layout: str, saved_model: SavedModel): path = f"{layout}/{saved_model}" try: self._zk.delete(path) except NoNodeError: pass def bzid_info(self): # model deploy configs model_info = defaultdict(lambda: defaultdict(dict)) if self._zk.exists(f"/{self._bzid}/saved_models"): model_names = self._zk.get_children(f"/{self._bzid}/saved_models") for model_name in model_names: sub_graphs = self._zk.get_children( f"/{self._bzid}/saved_models/{model_name}") model_info[model_name]['sub_graphs_total'] = len(sub_graphs) for sub_graph in sub_graphs: model_info[model_name][sub_graph]['deploy_config'] = self.get_znode( 
f"/{self._bzid}/saved_models/{model_name}/{sub_graph}").decode( 'utf-8') container_info = defaultdict(lambda: defaultdict(dict)) # container service info if self._zk.exists(f"/{self._bzid}/container_service"): containers = self._zk.get_children(f"/{self._bzid}/container_service") for container in containers: cluster, container_id = container.split(":")[:2] container_info[cluster][container_id]['service_info'] = self.get_znode( f"/{self._bzid}/container_service/{container}").decode('utf-8') # layout info layout_info = defaultdict(lambda: defaultdict(dict)) if self._zk.exists(f"/{self._bzid}/layouts"): layouts = self._zk.get_children(f"/{self._bzid}/layouts") for layout in layouts: saved_models = self._zk.get_children(f"/{self._bzid}/layouts/{layout}") if saved_models: layout_info[layout] = sorted(saved_models) else: layout_info[layout] = [] # bindings if self._zk.exists(f"/{self._bzid}/binding"): model_names = self._zk.get_children(f"/{self._bzid}/binding") for model_name in model_names: bindings = self._zk.get_children(f"/{self._bzid}/binding/{model_name}") for binding in bindings: sub_graph, cluster, container_id = binding.split(":")[:3] if 'bindings' not in model_info[model_name][sub_graph]: model_info[model_name][sub_graph]['bindings'] = [] model_info[model_name][ 'sub_graphs_available'] = model_info[model_name].get( 'sub_graphs_available', 0) + 1 model_info[model_name][sub_graph]['bindings'].append( f"{cluster}:{container_id}") if 'saved_models' not in container_info[cluster][container_id]: container_info[cluster][container_id]['saved_models'] = [] container_info[cluster][container_id]['saved_models'].append( f"{model_name}:{sub_graph}") def sorted_dict(d): return dict(sorted(d.items())) return { 'model_info': { model_name: sorted_dict(model_info[model_name]) for model_name in model_info }, 'container_info': { cluster: sorted_dict(container_info[cluster]) for cluster in container_info }, 'layout_info': {layout: layout_info[layout] for layout in layout_info} } # 
def subscribe_model(self, model_name: str):
  """Start tracking the bindings of ``model_name`` for syncing.

  Subscribing again to the model that is already tracked is a no-op. The
  assertion below means a second, different model cannot be subscribed on
  the same backend instance.

  Args:
    model_name: name of the model whose ``/<bzid>/binding/<model_name>``
      children should be watched.
  """
  if self._sync_model_name == model_name:
    return
  assert self._sync_model_name is None

  self._sync_model_name = model_name
  # Keep an existing entry for this model if there is one; otherwise start
  # from an empty sub-graph map.
  current_entry = self._service_info_map.get(model_name, {})
  self._service_info_map[model_name] = current_entry

  on_binding_change = partial(self._bind_callback, model_name)
  self._children_watch(f"/{self._bzid}/binding/{model_name}",
                       on_binding_change)


def get_sync_targets(self, sub_graph: str) -> Tuple[str, List[str]]:
  """Return the sync target name and gRPC addresses for ``sub_graph``.

  If the ZooKeeper session is flagged as lost, the locally recorded
  available saved models are cleared and the client is restarted before the
  lookup is served from the local service-info map.

  Args:
    sub_graph: sub-graph name to look up under the subscribed model.

  Returns:
    A ``("<model_name>:<sub_graph>", [grpc_address, ...])`` tuple. The list
    is empty when nothing is recorded for the sub-graph.
  """
  with self._lock:
    if self._is_lost.is_set():
      self._available_saved_model.clear()
      logging.warning("zk is lost, try restarting")
      self._zk.restart()

    per_model = self._service_info_map.get(self._sync_model_name, {})
    grpc_targets = []
    for info in per_model.get(sub_graph, []):
      grpc_targets.append(info.grpc)
    target_name = f"{self._sync_model_name}:{sub_graph}"
    return target_name, grpc_targets


def create_znode(self, path, value, ephemeral=False, makepath=False) -> None:
  """Create ``path`` with ``value``; overwrite the value if it already exists.

  Any other failure is logged and swallowed (best effort).
  """
  with self._lock:
    try:
      self._zk.create(path,
                      value=value,
                      ephemeral=ephemeral,
                      makepath=makepath)
    except NodeExistsError:
      # The node is already there: fall back to updating its value.
      self._zk.retry(self._zk.set, path=path, value=value)
    except Exception as err:
      logging.error(f"exception in create_znode: {err}")


def delete_znode(self, path) -> None:
  """Delete ``path``; failures are logged and swallowed (best effort)."""
  with self._lock:
    try:
      self._zk.delete(path)
    except Exception as err:
      logging.error(f"exception in delete_znode: {err}")


def get_znode(self, path) -> bytes:
  """Return the data stored at ``path``.

  Returns ``None`` (despite the ``bytes`` annotation) when the node does
  not exist.
  """
  try:
    fetched = self._zk.get(path)
    return fetched[0]
  except NoNodeError:
    return None


def start(self):
  """Start the underlying ZooKeeper client."""
  self._zk.start()


def stop(self):
  """Stop the underlying ZooKeeper client."""
  self._zk.stop()
class ZKBackendTest(unittest.TestCase):
  """Tests for backends.ZKBackend running against a FakeKazooClient.

  All tests share one backend instance created in setUpClass, so state
  written by one test (layouts, bindings) is visible to the others.
  """

  @classmethod
  def setUpClass(cls) -> None:
    cls.bzid = 'gip'
    cls.container = backends.Container("default", "asdf")
    cls.service_info = backends.ContainerServiceInfo(grpc="localhost:8888",
                                                     http="localhost:8889",
                                                     archon="localhost:8890",
                                                     agent="localhost:8891",
                                                     idc="IDC")
    cls.backend = backends.ZKBackend(cls.bzid, zk_servers='127.0.0.1:9999')
    # Swap the real client for the in-memory fake before anything connects.
    cls.zk = FakeKazooClient()
    cls.backend._zk = cls.zk

    # Last value delivered to the layout callback.
    cls.layout_record = None

    def layout_callback(saved_models):
      cls.layout_record = saved_models

    cls.backend.start()
    cls.backend.report_service_info(cls.container, cls.service_info)
    cls.layout_path = "/gip/layouts/test_layout/mixed"
    cls.backend.register_layout_callback(cls.layout_path, layout_callback)
    print("setUpClass finished!")

  @classmethod
  def tearDownClass(cls) -> None:
    cls.backend.stop()
    print('tearDownClass finished!')

  def test_register_service(self):
    """The service info reported in setUpClass can be read back."""
    service_info = self.backend.get_service_info(self.container)
    self.assertEqual(service_info, self.service_info)

  def test_layout_callback(self):
    """Adding/removing saved models on a layout triggers the callback."""

    def base_path(sub_graph):
      return os.path.join(os.environ["TEST_TMPDIR"],
                          "test_ffm_model/exported_models", sub_graph)

    def expected_entries(sub_graphs):
      # (SavedModel, SavedModelDeployConfig) pairs in the given order.
      return [(backends.SavedModel("test_ffm_model", sub_graph),
               backends.SavedModelDeployConfig(base_path(sub_graph), 'latest'))
              for sub_graph in sub_graphs]

    for sub_graph in ['entry', 'ps_0', 'ps_1', 'ps_2']:
      saved_model = backends.SavedModel('test_ffm_model', sub_graph)
      self.backend.decl_saved_model(
          saved_model,
          backends.SavedModelDeployConfig(base_path(sub_graph), 'latest'))
      self.backend.add_to_layout(self.layout_path, saved_model)

    self.assertEqual(self.layout_record,
                     expected_entries(['entry', 'ps_0', 'ps_1', 'ps_2']))

    self.backend.remove_from_layout(
        self.layout_path, backends.SavedModel('test_ffm_model', 'entry'))
    self.assertEqual(self.layout_record,
                     expected_entries(['ps_0', 'ps_1', 'ps_2']))

  def test_sync_available_models(self):
    """Syncing available saved models creates one binding znode per model."""
    self.backend.sync_available_saved_models(
        self.container, {
            backends.SavedModel("test_ffm_model", "entry"),
            backends.SavedModel("test_ffm_model", "ps_0"),
            backends.SavedModel("test_ffm_model", "ps_1"),
        })
    for sub_graph in ("entry", "ps_0", "ps_1"):
      self.assertTrue(
          self.zk.exists(
              f"/gip/binding/test_ffm_model/{sub_graph}:{self.container}"))

  def test_service_map(self):
    """The service map lists this container for every synced sub-graph."""
    self.backend.sync_available_saved_models(
        self.container, {
            backends.SavedModel("test_ffm_model", "entry"),
            backends.SavedModel("test_ffm_model", "ps_0")
        })
    expected = {
        'test_ffm_model': {
            'ps_0': [self.service_info],
            'entry': [self.service_info]
        }
    }
    # assertTrue(actual, expected) would treat `expected` as the failure
    # message and pass for any non-empty map; compare the values instead.
    self.assertEqual(self.backend.get_service_map(), expected)

  def test_sync_backend(self):
    """A subscribed model resolves sync targets to the container's gRPC."""
    self.backend.subscribe_model("test_ffm_model")
    self.backend.sync_available_saved_models(
        self.container, {
            backends.SavedModel("test_ffm_model", "ps_0"),
            backends.SavedModel("test_ffm_model", "ps_1"),
            backends.SavedModel("test_ffm_model", "ps_2"),
        })
    model_name, targets = self.backend.get_sync_targets("ps_1")
    self.assertEqual(model_name, "test_ffm_model:ps_1")
    self.assertEqual(targets, [self.service_info.grpc])


if __name__ == "__main__":
  unittest.main()
ModelState = ModelVersionStatus.State

FLAGS = flags.FLAGS
flags.DEFINE_enum("cmd_type", "hb", [
    "hb", "gr", "addr", "get", "clean", "load", "unload", "meta", "status",
    "profile"
], "cmd_type: hb, gr, addr, res, status")
flags.DEFINE_string('zk_servers', None, 'zk servers')
flags.DEFINE_string('bzid', None, 'business id')
flags.DEFINE_string('model_name', None, 'model name')
flags.DEFINE_string("target", None, "host:port")
flags.DEFINE_enum("input_type", 'dump', [
    "json", "pbtext", "dump", "binary", "instance", "example_batch",
    "example_batch_to_instance"
], "inputs type for prediction")
flags.DEFINE_string("input_file", None, "The input file name")


@dataclass
class LoadSate:
  """Load status of one model as observed in ZooKeeper."""
  # True when the model has a node under /<bzid>/portal; None otherwise.
  portal: bool = None
  # True when some node under /<bzid>/publish names the model; None otherwise.
  publish: bool = None
  # '<node>:<replica>' -> ReplicaMeta.stat for every registered replica.
  service: dict = field(default_factory=dict)  # Dict[str, ModelState]


class ServingClient(object):
  """Command-line helper to load, unload and inspect models via ZooKeeper."""

  def __init__(self, zk_servers: str, bzid: str):
    self.kazoo = MonolithKazooClient(hosts=zk_servers)
    self.bzid = bzid
    self._zk = ZKMirror(self.kazoo, bzid)
    self._zk.start(is_client=True)

  def load(self,
           model_name: str,
           model_dir: str,
           ckpt: str = None,
           num_shard: int = -1):
    """Register ``model_name`` under the portal path.

    Raises:
      RuntimeError: if a portal node for ``model_name`` already exists.
    """
    mm = ModelMeta(model_name=model_name,
                   model_dir=model_dir,
                   ckpt=ckpt,
                   num_shard=num_shard)
    path = mm.get_path(self._zk.portal_base_path)
    if self._zk.exists(path):
      raise RuntimeError(f'{model_name} has exists')
    self._zk.create(path=path, value=mm.serialize(), include_data=True)

  def unload(self, model_name: str):
    """Remove the portal node of ``model_name``; warn if there is none."""
    mm = ModelMeta(model_name=model_name)
    path = mm.get_path(self._zk.portal_base_path)
    if self._zk.exists(path):
      self._zk.delete(path)
    else:
      logging.warning(f'{model_name} not exists, nothing to do!')

  def get_status(self, model_name: str) -> LoadSate:
    """Collect the portal/publish/service status of ``model_name``.

    Missing ``publish`` or ``service/<model_name>`` nodes are treated as
    "nothing there" instead of raising, so asking about a model that was
    never loaded returns an empty LoadSate.
    """
    state = LoadSate()
    if self.kazoo.exists(f'/{self.bzid}/portal/{model_name}'):
      state.portal = True

    publish_root = f'/{self.bzid}/publish'
    if self.kazoo.exists(publish_root):
      for node in self.kazoo.get_children(publish_root):
        # Node names are parsed as 'shard_id:replica_id:model_name'. Limit
        # the split so a ':' inside the model name does not break parsing,
        # and skip nodes that do not have all three parts.
        parts = node.split(':', 2)
        if len(parts) == 3 and parts[2] == model_name:
          state.publish = True
          break

    service = {}
    service_root = f'/{self.bzid}/service/{model_name}'
    if self.kazoo.exists(service_root):
      for node in self.kazoo.get_children(service_root):
        node_path = f'{service_root}/{node}'
        for replica in self.kazoo.get_children(node_path):
          value, _ = self.kazoo.get(f'{node_path}/{replica}')
          rm = ReplicaMeta.deserialize(value)
          service[f'{node}:{replica}'] = rm.stat
    state.service = service

    return state


def main(_):
  """Entry point: dispatch on --cmd_type (load / unload / status print)."""
  env_utils.setup_host_ip()
  if FLAGS.zk_servers is None:
    raise ValueError(f'zk_servers is {FLAGS.zk_servers}')
  if FLAGS.bzid is None:
    raise ValueError(f'bzid is {FLAGS.bzid}')

  client = ServingClient(FLAGS.zk_servers, FLAGS.bzid)

  assert FLAGS.model_name is not None
  if FLAGS.cmd_type == 'load':
    # NOTE(review): model_dir, ckpt and num_shard are not defined as flags in
    # this file; presumably another imported module defines them - confirm.
    assert FLAGS.model_dir is not None
    client.load(FLAGS.model_name, FLAGS.model_dir, FLAGS.ckpt, FLAGS.num_shard)
  elif FLAGS.cmd_type == 'unload':
    client.unload(FLAGS.model_name)
  else:
    # Every other cmd_type falls through to printing the status.
    print(client.get_status(FLAGS.model_name))


if __name__ == '__main__':
  app.run(main)
ModelState = ModelVersionStatus.State

ModelName = NewType('ModelName', str)  # simple model_name
SubModelName = NewType('SubModelName', str)  # entry, ps_{num}, dense
SubModelSize = NewType('SubModelSize', int)  # size in byte
TFSModelName = NewType('TFSModelName',
                       str)  # f'{model_name}:{sub_model_name}'
VersionPath = NewType('VersionPath',
                      str)  # .../exported_models/{sub_model_name}/{version}
EmptyStatus = StatusProto()


@dataclass_json
@dataclass
class ModelMeta(object):
  """Model description, stored as UTF-8 JSON at ``<base_path>/<model_name>``."""
  model_name: str = None
  model_dir: str = None
  ckpt: str = None
  num_shard: int = -1
  action: str = 'NONE'
  spec_replicas: List[int] = field(default_factory=list)

  def get_path(self, base_path: str) -> str:
    """Return ``base_path`` joined with the model name."""
    return os.path.join(base_path, self.model_name)

  def serialize(self) -> bytes:
    """Encode this object as UTF-8 JSON bytes."""
    return bytes(self.to_json(), encoding='utf-8')

  @classmethod
  def deserialize(cls, serialized: bytes) -> 'ModelMeta':
    """Rebuild an instance from bytes produced by ``serialize``."""
    return cls.from_json(str(serialized, encoding='utf-8'))


@dataclass_json
@dataclass
class ResourceSpec(object):
  """Resource report, stored at ``<base_path>/<shard_id>:<replica_id>``."""
  address: str = None  # host:port
  shard_id: int = None
  replica_id: int = None
  memory: int = None
  cpu: float = -1.0
  network: float = -1.0
  work_load: float = -1.0

  def get_path(self, base_path: str) -> str:
    """Return ``base_path`` joined with ``'<shard_id>:<replica_id>'``."""
    return os.path.join(base_path, f"{self.shard_id}:{self.replica_id}")

  def serialize(self) -> bytes:
    """Encode this object as UTF-8 JSON bytes."""
    return bytes(self.to_json(), encoding='utf-8')

  @classmethod
  def deserialize(cls, serialized: bytes) -> 'ResourceSpec':
    """Rebuild an instance from bytes produced by ``serialize``."""
    return cls.from_json(str(serialized, encoding='utf-8'))


class PublishType(Enum):
  LOAD = 1
  UNLOAD = 2


@dataclass_json
@dataclass
class PublishMeta(object):
  """Publish record, stored at ``<base_path>/<shard_id>:<replica_id>:<model_name>``."""
  shard_id: int = None
  replica_id: int = -1
  model_name: str = None
  num_ps: int = None
  total_publish_num: int = 1
  sub_models: Dict[SubModelName, VersionPath] = None
  ptype: PublishType = PublishType.LOAD
  is_spec: bool = False

  def get_path(self, base_path: str) -> str:
    """Return ``base_path`` joined with ``'<shard_id>:<replica_id>:<model_name>'``."""
    return os.path.join(base_path,
                        f'{self.shard_id}:{self.replica_id}:{self.model_name}')

  def serialize(self) -> bytes:
    """Encode this object as UTF-8 JSON bytes."""
    return bytes(self.to_json(), encoding='utf-8')

  @classmethod
  def deserialize(cls, serialized: bytes) -> 'PublishMeta':
    """Rebuild an instance from bytes produced by ``serialize``."""
    return cls.from_json(str(serialized, encoding='utf-8'))


@dataclass_json
@dataclass
class ReplicaMeta:
  """Addresses and state of one replica of a model sub-server."""
  address: str = None  # host:port
  address_ipv6: str = None  # [host]:port
  stat: int = ModelState.UNKNOWN
  model_name: Optional[str] = None
  server_type: Optional[str] = None
  task: int = -1
  replica: int = -1
  archon_address: str = None  # host:port
  archon_address_ipv6: str = None  # [host]:port

  def serialize(self) -> bytes:
    """Encode this object as UTF-8 JSON bytes."""
    return bytes(self.to_json(), encoding='utf-8')

  @classmethod
  def deserialize(cls, serialized: bytes) -> 'ReplicaMeta':
    """Rebuild an instance from bytes produced by ``serialize``."""
    return cls.from_json(str(serialized, encoding='utf-8'))

  def get_path(self, bzid: str, sep: str = '/') -> str:
    """Return ``/<bzid>/service/<model_name>/<server_type>:<task>/<replica>``.

    The leading empty component makes the joined path start with ``sep``.
    """
    paths = [
        '', bzid, 'service', self.model_name,
        f'{self.server_type}:{self.task}',
        str(self.replica)
    ]
    return sep.join(paths)

  def get_address(self,
                  use_archon: bool = False,
                  address_family: str = AddressFamily.IPV4) -> str:
    """Pick the address to use, preferring ``address_family``.

    Addresses starting with ``0.0.0.0`` / ``[::]`` are treated as unset. If
    the preferred family has no usable address the other family is returned;
    the result is ``None`` when neither is usable.

    Args:
      use_archon: choose the archon addresses instead of the plain ones.
      address_family: AddressFamily.IPV4 or AddressFamily.IPV6.
    """
    assert address_family in [AddressFamily.IPV4, AddressFamily.IPV6]
    ipv4_address = self.archon_address if use_archon else self.address
    if ipv4_address is not None and ipv4_address.startswith('0.0.0.0'):
      ipv4_address = None
    ipv6_address = self.archon_address_ipv6 if use_archon else self.address_ipv6
    if ipv6_address is not None and ipv6_address.startswith('[::]'):
      ipv6_address = None
    if address_family == AddressFamily.IPV4:
      address = ipv4_address or ipv6_address
    else:
      address = ipv6_address or ipv4_address
    return address


class EventType(Enum):
  PORTAL = 1  # Scheduler, ZK watch trigger
  SERVICE = 2  # StatusReportHandler, time trigger
  PUBLISH = 3  # ModelLoaderHandler, ZK watch trigger
  RESOURCE = 4  # ResourceReportHandler, time trigger
  # Must not reuse another member's value: Enum turns a duplicate value into
  # an alias, which previously made UNKNOWN the very same member as PORTAL
  # (and so made PORTAL the default type of Event).
  UNKNOWN = 0


@dataclass_json
@dataclass
class Event(object):
  """An event described by a path, a data payload and an event type."""
  path: str = None
  data: bytes = b''
  etype: EventType = EventType.UNKNOWN

  def serialize(self) -> bytes:
    """Encode this object as UTF-8 JSON bytes."""
    return bytes(self.to_json(), encoding='utf-8')

  @classmethod
  def deserialize(cls, serialized: bytes) -> 'Event':
    """Rebuild an instance from bytes produced by ``serialize``."""
    return cls.from_json(str(serialized, encoding='utf-8'))
class DataDefTest(unittest.TestCase):
  """Round-trip (serialize -> deserialize) checks for the data_def types."""

  def serde(self, item):
    """Assert that ``item`` is equal to itself after a serialize round trip."""
    payload = item.serialize()
    restored = type(item).deserialize(payload)
    self.assertEqual(item, restored)

  def test_model_info(self):
    model_meta = ModelMeta(model_name='monolith',
                           num_shard=3,
                           model_dir="/tmp/opt",
                           ckpt='model.ckpt-1234')
    self.serde(model_meta)

  def test_resource(self):
    resource_spec = ResourceSpec(address="localhost:123",
                                 shard_id=10,
                                 replica_id=2,
                                 memory=12345,
                                 cpu=3.5)
    self.serde(resource_spec)

  def test_replica_meta(self):
    replica_meta = ReplicaMeta(address="localhost:123",
                               model_name='monolith',
                               server_type='ps',
                               task=0,
                               replica=0)
    self.serde(replica_meta)


if __name__ == "__main__":
  unittest.main()
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_source_id" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_city" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test32_double" feature { } feature { } feature { fid_v1_list { value: 9322172202116516833 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature 
{ } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9322172202116516833 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_comment_cnt" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_pub_time_hour" feature { } feature { } feature { fid_v1_list { value: 4235580633157097946 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4242591947990187885 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test09_array_int32" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { fid_v1_list { value: 9987835853247866286 value: 9982825002555554975 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test15_array_float" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9736503367834602465 value: 9736503367834602465 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9736503367834602465 value: 9736503367834602465 } } feature { fid_v1_list { value: 9736503367834602465 value: 9736503367834602465 } } feature { fid_v1_list { value: 9736503367834602465 value: 9736503367834602465 } } feature { fid_v1_list { value: 9736503367834602465 value: 9736503367834602465 } } feature { fid_v1_list { value: 9736503367834602465 value: 9736503367834602465 } } feature { fid_v1_list { value: 9736503367834602465 value: 
9736503367834602465 } } feature { fid_v1_list { value: 9736503367834602465 value: 9736503367834602465 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_spm_1" feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 
5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } feature { fid_v1_list { value: 5431062124068408289 } } } named_feature_list { name: "f_goods_share_cnt_1000" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_spm_3" feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } 
feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { 
fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } feature { fid_v1_list { value: 5467090921087372257 } } } named_feature_list { name: "f_spm_2" feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441722959692120433 } } feature { fid_v1_list { value: 5441722959692120433 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5451333593801533170 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441722959692120433 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5447058606610205093 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 
5450268195437205113 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5441722959692120433 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5447058606610205093 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5441722959692120433 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5441980230172067264 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5447099530896665279 } } feature { fid_v1_list { value: 5451370321420258447 } } feature { fid_v1_list { value: 5447058606610205093 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5450268195437205113 } } feature { fid_v1_list { value: 5447099530896665279 } } feature { fid_v1_list { value: 5441722959692120433 } } feature { fid_v1_list { value: 5441722959692120433 } } } named_feature_list { name: "f_spm_4" feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { 
fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { 
value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } feature { fid_v1_list { value: 5485105319596854241 } } } named_feature_list { name: "f_goods_spu_id" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test45_map_string" feature { } feature { } feature { fid_v1_list { value: 9254796155896927364 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9244648908249888622 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_price_reduction_10" feature { } feature { } feature { fid_v1_list { value: 4367227356144754473 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4360224287068445047 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_id" feature { fid_v1_list { value: 34661039000183975 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 26375391288018365 } } feature { fid_v1_list { value: 35592691808838616 } } feature { fid_v1_list { value: 32480246379207459 } } feature { fid_v1_list { value: 26403633346786866 } } feature { fid_v1_list { value: 26403633346786866 } } feature { fid_v1_list { value: 26403633346786866 } } feature { fid_v1_list { value: 22739744141601140 } } feature { fid_v1_list { value: 22739744141601140 } } feature { fid_v1_list { value: 22739744141601140 } } feature { fid_v1_list { value: 31753659020026959 } } feature { fid_v1_list { value: 31753659020026959 } } feature { fid_v1_list { value: 33895127814546172 } } feature { fid_v1_list { value: 18976550735883040 } } feature { fid_v1_list { value: 18976550735883040 } } feature { fid_v1_list { value: 30810089642095820 } } feature { fid_v1_list { value: 30810089642095820 } } feature { fid_v1_list { value: 24555653156517823 } } 
feature { fid_v1_list { value: 28442428386816647 } } feature { fid_v1_list { value: 31371008067277372 } } feature { fid_v1_list { value: 31371008067277372 } } feature { fid_v1_list { value: 24213437492756285 } } feature { fid_v1_list { value: 34828384169520893 } } feature { fid_v1_list { value: 34828384169520893 } } feature { fid_v1_list { value: 34828384169520893 } } feature { fid_v1_list { value: 34828384169520893 } } feature { fid_v1_list { value: 34828384169520893 } } feature { fid_v1_list { value: 34828384169520893 } } feature { fid_v1_list { value: 34828384169520893 } } feature { fid_v1_list { value: 30621526964314858 } } feature { fid_v1_list { value: 26799910740734091 } } feature { fid_v1_list { value: 24017764260178887 } } feature { fid_v1_list { value: 24017764260178887 } } feature { fid_v1_list { value: 24017764260178887 } } feature { fid_v1_list { value: 32401856502699871 } } feature { fid_v1_list { value: 29165508905376910 } } feature { fid_v1_list { value: 25144527618414700 } } feature { fid_v1_list { value: 25144527618414700 } } feature { fid_v1_list { value: 25144527618414700 } } feature { fid_v1_list { value: 25144527618414700 } } feature { fid_v1_list { value: 30893925026832398 } } feature { fid_v1_list { value: 30893925026832398 } } } named_feature_list { name: "f_user_ctx_network" feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { 
fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1102855795182531059 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107344313440999751 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { 
value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } feature { fid_v1_list { value: 1107606481792732129 } } } named_feature_list { name: "f_goods_test38_array_bool" feature { } feature { } feature { fid_v1_list { value: 9466287390192372705 value: 9470390549237301823 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9470390549237301823 value: 9466287390192372705 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test05_string" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9724206661329807113 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9724206661329807113 } } feature { fid_v1_list { value: 9724206661329807113 } } feature { fid_v1_list { value: 9724206661329807113 } } feature { fid_v1_list { value: 9724206661329807113 } } feature { fid_v1_list { value: 9724206661329807113 } } feature { fid_v1_list { value: 9724206661329807113 } } feature { fid_v1_list { value: 9724206661329807113 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: 
"f_user_city" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 184466745546546060 } } feature { } feature { } feature { } feature { fid_v1_list { value: 188872157809150945 } } feature { fid_v1_list { value: 188872157809150945 } } feature { fid_v1_list { value: 188872157809150945 } } feature { fid_v1_list { value: 188872157809150945 } } feature { fid_v1_list { value: 188872157809150945 } } feature { fid_v1_list { value: 188872157809150945 } } feature { fid_v1_list { value: 188872157809150945 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test31_float" feature { } feature { } feature { fid_v1_list { value: 9178057014040660961 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9178057014040660961 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_device_id" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test25_int32" feature { } feature { } feature { fid_v1_list { value: 9302601369885095920 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9306223536401580458 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_country" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test07_float" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature 
{ } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9646431375287192545 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9646431375287192545 } } feature { fid_v1_list { value: 9646431375287192545 } } feature { fid_v1_list { value: 9646431375287192545 } } feature { fid_v1_list { value: 9646431375287192545 } } feature { fid_v1_list { value: 9646431375287192545 } } feature { fid_v1_list { value: 9646431375287192545 } } feature { fid_v1_list { value: 9646431375287192545 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_shipping_money" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test20_map_uint64" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9875330678134588286 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9875330678134588286 } } feature { 
fid_v1_list { value: 9875330678134588286 } } feature { fid_v1_list { value: 9875330678134588286 } } feature { fid_v1_list { value: 9875330678134588286 } } feature { fid_v1_list { value: 9875330678134588286 } } feature { fid_v1_list { value: 9875330678134588286 } } feature { fid_v1_list { value: 9875330678134588286 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test42_map_int64" feature { } feature { } feature { fid_v1_list { value: 9583186420523358701 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9569085903473394558 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test02_int64" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9687302302898843769 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9687302302898843769 } } feature { fid_v1_list { value: 9687302302898843769 } } feature { fid_v1_list { value: 9687302302898843769 } } feature { fid_v1_list { value: 9687302302898843769 } } feature { fid_v1_list { value: 9687302302898843769 } } feature { fid_v1_list { 
value: 9687302302898843769 } } feature { fid_v1_list { value: 9687302302898843769 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_district" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_current_price_1000" feature { } feature { } feature { fid_v1_list { value: 3864695453538224966 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3872186882756116749 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test35_array_uint32" feature { } feature { } feature { fid_v1_list { value: 9390373322373335920 value: 9388843097214556352 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9394938221468407955 value: 9390373322373335920 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_age" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test24_map_double" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9963468822863469749 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9963468822863469749 } } feature { fid_v1_list { value: 9963468822863469749 } } feature { fid_v1_list { value: 9963468822863469749 } } feature { fid_v1_list { value: 9963468822863469749 } } feature { fid_v1_list { value: 9963468822863469749 } } feature { fid_v1_list { value: 9963468822863469749 } } feature { fid_v1_list { value: 9963468822863469749 } } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_id-f_goods_current_price" feature { } feature { } feature { fid_v1_list { value: 9045399711331748120 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9050924295655723542 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_device_model" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_origin_price_1000" feature { } feature { } feature { fid_v1_list { value: 3915976234101104791 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3916132237881122179 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test28_uint64" feature { } feature { } feature { fid_v1_list { value: 9475573618135007223 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9475573618135007224 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_id-f_goods_sale_number" feature { } feature { } feature { fid_v1_list { value: 9065644441006867721 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9074052365851756434 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_share_cnt" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test08_double" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9952676149948386273 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9952676149948386273 } } feature { fid_v1_list { value: 9952676149948386273 } } feature { fid_v1_list { value: 9952676149948386273 } } feature { fid_v1_list { value: 9952676149948386273 } } feature { fid_v1_list { value: 9952676149948386273 } } feature { fid_v1_list { value: 9952676149948386273 } } feature { fid_v1_list { value: 9952676149948386273 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test06_bool" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9830678519426941503 } } feature { } feature { } feature { } feature { fid_v1_list 
{ value: 9830678519426941503 } } feature { fid_v1_list { value: 9830678519426941503 } } feature { fid_v1_list { value: 9830678519426941503 } } feature { fid_v1_list { value: 9830678519426941503 } } feature { fid_v1_list { value: 9830678519426941503 } } feature { fid_v1_list { value: 9830678519426941503 } } feature { fid_v1_list { value: 9830678519426941503 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test33_array_int32" feature { } feature { } feature { fid_v1_list { value: 9534996739393864528 value: 9542874036410030307 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9537475890510816686 value: 9532465039818505375 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_brand" feature { } feature { } feature { fid_v1_list { value: 3743551367396753389 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3733245351290651264 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_current_price" feature { } feature { } feature { fid_v1_list { value: 3826301672674962047 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3830953370607164169 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_tags_terms" feature { } feature { } feature { fid_v1_list { value: 3811379433387304802 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3806659414003006399 value: 3817841950973150016 value: 3808895514000229327 value: 3808399441778900137 value: 3808399441778900137 value: 3808895514000229327 value: 3811495559608005733 value: 3817841611590067540 value: 3805280058801983575 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_cate_1" feature { } feature { } feature { fid_v1_list { value: 3664100952242652919 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3671364035570364440 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_cate_2" feature { } feature { } feature { fid_v1_list { value: 3687893206423298716 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3690602924044901962 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_cate_3" feature { } feature { } feature { fid_v1_list { value: 3696969598358506601 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3698659331662272830 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } } named_feature_list { name: "f_goods_cate_4" feature { } feature { } feature { fid_v1_list { value: 3716991466540562362 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3711405396388666663 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test44_map_uint64" feature { } feature { } feature { fid_v1_list { value: 9367013638409574893 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9352913121359610750 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_id-f_page" feature { fid_v1_list { value: 9092789705426625231 } } feature { fid_v1_list { value: 9086397419364646740 } } feature { fid_v1_list { value: 9086397419364646740 } } feature { fid_v1_list { value: 9080929169466218336 } } feature { fid_v1_list { value: 9092490710899668234 } } feature { fid_v1_list { value: 9080929169466218336 } } feature { fid_v1_list { value: 9080929169466218336 } } feature { fid_v1_list { value: 
9080929169466218336 } } feature { fid_v1_list { value: 9080929169466218336 } } feature { fid_v1_list { value: 9080929169466218336 } } feature { fid_v1_list { value: 9080929169466218336 } } feature { fid_v1_list { value: 9084723152110966336 } } feature { fid_v1_list { value: 9090673728892834668 } } feature { fid_v1_list { value: 9087002549578805200 } } feature { fid_v1_list { value: 9086378705790910311 } } feature { fid_v1_list { value: 9091939060511529031 } } feature { fid_v1_list { value: 9094200603063820929 } } feature { fid_v1_list { value: 9094200603063820929 } } feature { fid_v1_list { value: 9096580265606473271 } } feature { fid_v1_list { value: 9090211897252558350 } } feature { fid_v1_list { value: 9097050149952374856 } } feature { fid_v1_list { value: 9086504574730266468 } } feature { fid_v1_list { value: 9092015533406643567 } } feature { fid_v1_list { value: 9086616598317095688 } } feature { fid_v1_list { value: 9085846876813685546 } } feature { fid_v1_list { value: 9085846876813685546 } } feature { fid_v1_list { value: 9087911285192683050 } } feature { fid_v1_list { value: 9082753167719844823 } } feature { fid_v1_list { value: 9082815111106798175 } } feature { fid_v1_list { value: 9082815111106798175 } } feature { fid_v1_list { value: 9094183977061329440 } } feature { fid_v1_list { value: 9082847486989399824 } } feature { fid_v1_list { value: 9085289623114222694 } } feature { fid_v1_list { value: 9085289623114222694 } } feature { fid_v1_list { value: 9085289623114222694 } } feature { fid_v1_list { value: 9085289623114222694 } } feature { fid_v1_list { value: 9085289623114222694 } } feature { fid_v1_list { value: 9095104172800023209 } } feature { fid_v1_list { value: 9089914688358822412 } } feature { fid_v1_list { value: 9088296710536755392 } } feature { fid_v1_list { value: 9092609888464068017 } } feature { fid_v1_list { value: 9092609888464068017 } } feature { fid_v1_list { value: 9080303384419389583 } } feature { fid_v1_list { value: 9089956094034039490 
} } feature { fid_v1_list { value: 9094391052982619763 } } feature { fid_v1_list { value: 9085646402573353544 } } feature { fid_v1_list { value: 9096908574899867209 } } feature { fid_v1_list { value: 9096908574899867209 } } feature { fid_v1_list { value: 9082413653420066872 } } feature { fid_v1_list { value: 9082691189500607224 } } feature { fid_v1_list { value: 9082691189500607224 } } } named_feature_list { name: "f_user_id_type" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_origin_price_10" feature { } feature { } feature { fid_v1_list { value: 3906092201329708288 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3905557326176545651 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test03_uint32" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9755226191658047635 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9755226191658047635 } } feature { fid_v1_list { value: 9755226191658047635 } } feature { fid_v1_list { value: 9755226191658047635 } } feature { fid_v1_list { value: 9755226191658047635 } } feature { fid_v1_list { value: 9755226191658047635 } } feature { fid_v1_list { value: 9755226191658047635 } } feature { fid_v1_list { value: 9755226191658047635 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_shop_id" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_pub_time_day" feature { } feature { } feature { fid_v1_list { value: 4220323736978779502 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4233021007139439725 } } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test14_array_bool" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { fid_v1_list { value: 9794649722407977535 value: 9790546563363048417 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test12_array_uint64" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9939503882031591545 value: 9940527614559994373 } } feature { } feature { } feature { } feature { fid_v1_list { value: 
9939503882031591545 value: 9940527614559994373 } } feature { fid_v1_list { value: 9939503882031591545 value: 9940527614559994373 } } feature { fid_v1_list { value: 9939503882031591545 value: 9940527614559994373 } } feature { fid_v1_list { value: 9939503882031591545 value: 9940527614559994373 } } feature { fid_v1_list { value: 9939503882031591545 value: 9940527614559994373 } } feature { fid_v1_list { value: 9939503882031591545 value: 9940527614559994373 } } feature { fid_v1_list { value: 9939503882031591545 value: 9940527614559994373 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test10_array_int64" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { fid_v1_list { value: 9903475085012627577 value: 9904498817541030405 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test48_map_double" feature { } 
feature { } feature { fid_v1_list { value: 9504076465062773640 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9495094461616938165 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test36_array_uint64" feature { } feature { } feature { fid_v1_list { value: 9418110057785016837 value: 9404144550176698117 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9417086325256614009 value: 9418110057785016837 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test11_array_uint32" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { } feature { } feature { } 
feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { fid_v1_list { value: 9593096605072709779 value: 9588531705977637744 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_os_version" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_comment_cnt_10" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_praise_cnt_10" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_free_shipping" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_os" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_client_version" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_sale_number" feature { } feature { } feature { fid_v1_list { value: 4030042199373742655 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4019703662382893803 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test27_uint32" feature { } feature { } feature { fid_v1_list { value: 9228243735787998064 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9232808634883070099 } } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_scm" feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 
5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5526585389177434390 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } feature { fid_v1_list { value: 5513886323412268515 } } } named_feature_list { name: "f_goods_share_cnt_10" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_os_version" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "__LINE_ID__" feature { bytes_list { value: "\030\305\331\221\203\006!\004\\.\347\231=\233H2\001\001Z\001a\240\001\305\331\221\203\006\212\003\0171c86a18cbb1060f\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\200\342\220\203\006!\346\010T\347\364\020\365!2\001\001Z\001a\240\001\200\342\220\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1001$##$1$##$1" } } feature { bytes_list { value: "\030\206\342\220\203\006!\224\331\334\036O\026\365!2\001\001Z\001a\240\001\206\342\220\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1001$##$1$##$1" } } feature { bytes_list { value: "\030\315\233\221\203\006!\213w\364\272.:\260G2\001\001Z\001a\240\001\315\233\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\203\333\221\203\006!\200\255\205n\330\364\017y2\001\001Z\001a\240\001\203\333\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\310\333\221\203\006!O4\025&\240\242\263x2\001\001Z\001a\240\001\310\333\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\323\334\221\203\006!\036j\017k\237\271K\0302\001\001Z\001a\240\001\323\334\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\230\337\221\203\006!\334DZN^\267\207`2\001\001Z\001a\240\001\230\337\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: 
"\030\317\340\221\203\006!\352P\331\203+\362W~2\001\001Z\001a\240\001\317\340\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\246\342\221\203\006!\234\311\264G\351`\367X2\001\001Z\001a\240\001\246\342\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\253\345\221\203\006!\217\312\036\022\213\3378\0022\001\001Z\001a\240\001\253\345\221\203\006\212\003\01721bacb96ae78a72\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\213\211\222\203\006!\205\246\302\320\343\366UG2\001\001Z\001a\240\001\213\211\222\203\006\212\003\01725bce83ed9020ac\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\311\233\220\203\006!\337\276\373\235*\237K\0302\001\001Z\001a\240\001\311\233\220\203\006\212\003\01728a0ca96b8b56c3\232\n\0231$##$1007$##$1$##$1" } } feature { bytes_list { value: "\030\315\202\217\203\006!e\353\371\t\247\206}\0372\001\001Z\001a\240\001\315\202\217\203\006\212\003\0172ed86ac692fdee1\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\335\203\217\203\006!~\374+\270\304?\017W2\001\001Z\001a\240\001\335\203\217\203\006\212\003\0172ed86ac692fdee1\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\222\224\217\203\006!\r\237<\261\001\201?\'2\001\001Z\001a\240\001\222\224\217\203\006\212\003\0172ed86ac692fdee1\232\n\0231$##$1001$##$1$##$1" } } feature { bytes_list { value: "\030\267\261\222\203\006!e\t\276$\025`a\0272\001\001Z\001b\240\001\267\261\222\203\006\212\003\01730985df86b21b67\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\325\270\222\203\006!J\274\017\323A\3509\0022\001\001Z\001b\240\001\325\270\222\203\006\212\003\01730985df86b21b67\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\360\270\222\203\006!\036GfC\366\220\'\r2\001\001Z\001b\240\001\360\270\222\203\006\212\003\01730985df86b21b67\232\n\0231$##$1000$##$1$##$1" } } feature 
{ bytes_list { value: "\030\207\372\217\203\006!\374=c+\323\331\025\0162\001\001Z\001a\240\001\207\372\217\203\006\212\003\017369da2888203e1e\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\226\372\217\203\006!\313\322\24794\341\304\0322\001\001Z\001a\240\001\226\372\217\203\006\212\003\017369da2888203e1e\232\n\0231$##$1004$##$1$##$1" } } feature { bytes_list { value: "\030\213\351\221\203\006!b\234r\324\210\256\025\0022\001\001Z\001a\240\001\213\351\221\203\006\212\003\0173aab3d8116557fa\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\305\335\217\203\006!\252\357\274\025\205\3353e2\001\001Z\001b\240\001\305\335\217\203\006\212\003\01747f669fd3349cc6\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\322\342\217\203\006!~\346\220l_\361\266.2\001\001Z\001b\240\001\322\342\217\203\006\212\003\01747f669fd3349cc6\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\366\263\222\203\006!\312\263*l\273r\302y2\001\001Z\001a\240\001\366\263\222\203\006\212\003\0176223504cb9aa333\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\216\264\222\203\006!\223\321\260\345\242`a\0272\001\001Z\001a\240\001\216\264\222\203\006\212\003\0176223504cb9aa333\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\376\350\216\203\006!\320\224?\232\357\347\22632\001\001Z\001a\240\001\376\350\216\203\006\212\003\01762a75ce41714473\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\362\353\220\203\006!\353\363\355\243\001\357\003#2\001\001Z\001a\240\001\362\353\220\203\006\212\003\017673702cccb12d46\232\n\0231$##$1001$##$1$##$1" } } feature { bytes_list { value: "\030\200\314\220\203\006!\026\345\223[$\023\244\\2\001\002Z\001a\240\001\200\314\220\203\006\212\003\0176a32e345a4503a8\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: 
"\030\246\313\220\203\006!\246\310\205\354U\341\333P2\001\001Z\001a\240\001\246\313\220\203\006\212\003\0176a32e345a4503a8\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\313\276\215\203\006!\030\030\233%\255\252\261A2\001\001Z\001a\240\001\313\276\215\203\006\212\003\01771bd60bc5418391\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\324\241\220\203\006!K\357\275\346%\232uC2\001\001Z\001b\240\001\324\241\220\203\006\212\003\017872ad0c97f0b04d\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\336\242\220\203\006!\303,CQ\302\214\222)2\001\001Z\001b\240\001\336\242\220\203\006\212\003\017872ad0c97f0b04d\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\227\243\220\203\006!\361\272\277\246K\254K\0302\001\001Z\001b\240\001\227\243\220\203\006\212\003\017872ad0c97f0b04d\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\270\243\220\203\006!\3767\254\023\265o\n,2\001\001Z\001b\240\001\270\243\220\203\006\212\003\017872ad0c97f0b04d\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\362\255\220\203\006!\321\372\321\245T\r\325R2\001\001Z\001b\240\001\362\255\220\203\006\212\003\017872ad0c97f0b04d\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\243\256\220\203\006!\252\207l&\024\321\270^2\001\001Z\001b\240\001\243\256\220\203\006\212\003\017872ad0c97f0b04d\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\267\265\220\203\006!\025\241U,\212H\214 2\001\001Z\001b\240\001\267\265\220\203\006\212\003\017872ad0c97f0b04d\232\n\0231$##$1004$##$1$##$1" } } feature { bytes_list { value: "\030\337\250\217\203\006!\355SRF\327\315\323U2\001\001Z\001b\240\001\337\250\217\203\006\212\003\017b5b5e702d95bdfb\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\353\211\217\203\006!\337\301q\264\313W\34222\001\001Z\001b\240\001\353\211\217\203\006\212\003\017b5edbf560d2caa9\232\n\0231$##$1001$##$1$##$1" } } feature { 
bytes_list { value: "\030\272\270\220\203\006!\271k\225\372\001\003\232\0132\001\001Z\001a\240\001\272\270\220\203\006\212\003\017bd8be8b84a911ad\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\366\270\220\203\006!;\216\350ba\273\034z2\001\001Z\001a\240\001\366\270\220\203\006\212\003\017bd8be8b84a911ad\232\n\0231$##$1005$##$1$##$1" } } feature { bytes_list { value: "\030\206\224\222\203\006!\322\277\374\271$}\213c2\001\001Z\001a\240\001\206\224\222\203\006\212\003\017bd8be8b84a911ad\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\317\227\215\203\006!,\273Sc\263*\210Y2\001\001Z\001a\240\001\317\227\215\203\006\212\003\017ca6b013bfc776c3\232\n\0231$##$1006$##$1$##$1" } } feature { bytes_list { value: "\030\210\332\220\203\006!\034 \252e\314\251K\0302\001\001Z\001b\240\001\210\332\220\203\006\212\003\017f4ae638e6d1400e\232\n\0231$##$1002$##$1$##$1" } } feature { bytes_list { value: "\030\371\211\221\203\006!\343$L+\374\265\262i2\001\001Z\001a\240\001\371\211\221\203\006\212\003\017f821dbfa85c83ed\232\n\0231$##$1004$##$1$##$1" } } feature { bytes_list { value: "\030\365\225\221\203\006!\202\353\224\226\253V0&2\001\001Z\001a\240\001\365\225\221\203\006\212\003\017f821dbfa85c83ed\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\214\226\221\203\006!\215\017\r\216\322eO\"2\001\001Z\001a\240\001\214\226\221\203\006\212\003\017f821dbfa85c83ed\232\n\0231$##$1000$##$1$##$1" } } feature { bytes_list { value: "\030\252\235\221\203\006!T\350\307\222\227\271f\1772\001\001Z\001a\240\001\252\235\221\203\006\212\003\017f821dbfa85c83ed\232\n\0231$##$1006$##$1$##$1" } } feature { bytes_list { value: "\030\241\230\217\203\006!\320\237-\325UD/,2\001\001Z\001a\240\001\241\230\217\203\006\212\003\017fc6289b84feeb5f\232\n\0231$##$1001$##$1$##$1" } } feature { bytes_list { value: "\030\210\260\217\203\006!\323\3127i@\033\315^2\001\001Z\001a\240\001\210\260\217\203\006\212\003\017fc6289b84feeb5f\232\n\0231$##$1001$##$1$##$1" } 
} } named_feature_list { name: "f_goods_test26_int64" feature { } feature { } feature { fid_v1_list { value: 9201937275671233029 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9200913543142830201 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "__LABEL__" feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { 
value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } feature { float_list { value: 0.0 } } } named_feature_list { name: "f_user_test23_map_float" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9801309575956515417 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9801309575956515417 } } feature { fid_v1_list { value: 9801309575956515417 } } feature { fid_v1_list { value: 9801309575956515417 } } feature { fid_v1_list { value: 9801309575956515417 } } feature { fid_v1_list { value: 9801309575956515417 } } feature { fid_v1_list { value: 9801309575956515417 } } feature { fid_v1_list { value: 9801309575956515417 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_price_reduction" feature { } feature { } feature { fid_v1_list { value: 3934388063731853951 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3930250596545429415 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_membership_level" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_platform" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test21_map_string" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9622951276949010286 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9622951276949010286 } } feature { fid_v1_list { value: 9622951276949010286 } } feature { fid_v1_list { value: 9622951276949010286 } } feature { fid_v1_list { value: 9622951276949010286 } } feature { fid_v1_list { value: 9622951276949010286 } } feature { fid_v1_list { value: 9622951276949010286 } } feature { fid_v1_list { value: 9622951276949010286 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_network" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_praise_cnt" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_origin_price" feature { } feature { } feature { fid_v1_list { value: 3875588474307037931 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3884996566135610121 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_rating" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_detail_pic_num" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test43_map_uint32" feature { } feature { } feature { fid_v1_list { value: 9560876433953954425 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9564567803863763281 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_country" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test16_array_double" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { fid_v1_list { value: 9916647352929422305 value: 9916647352929422305 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_id" feature { fid_v1_list { value: 3610133626909631996 } } feature { fid_v1_list { value: 3604145606281754473 } } feature { fid_v1_list { value: 3616415016801295334 } } feature { fid_v1_list { value: 3619863857700495324 } } feature { fid_v1_list { value: 3617246315459696229 } } feature { fid_v1_list { value: 3614776421570243930 } } feature { fid_v1_list { value: 3617016026378354835 } } feature { fid_v1_list { value: 3619262297046717444 } } feature { fid_v1_list { value: 3619249165061197326 } } feature { fid_v1_list { value: 3618989404930912336 } } feature { fid_v1_list { value: 3613697671503572881 } } feature { fid_v1_list { value: 3603154114569699616 } } feature { fid_v1_list { value: 3611397315614213708 } } feature { fid_v1_list { value: 3614763883536608253 } } feature { fid_v1_list { value: 3606701039490254176 } } feature { fid_v1_list { value: 3620064348517401255 } } feature { fid_v1_list { value: 3617870400400875595 } } feature { fid_v1_list { value: 3604189663600640666 } } feature { fid_v1_list { value: 3615867999634237576 } } feature { fid_v1_list { value: 
3620527882979129921 } } feature { fid_v1_list { value: 3612210403309400165 } } feature { fid_v1_list { value: 3619712610574498604 } } feature { fid_v1_list { value: 3616268254331474538 } } feature { fid_v1_list { value: 3610856047563518050 } } feature { fid_v1_list { value: 3606887699004289777 } } feature { fid_v1_list { value: 3610970158357720999 } } feature { fid_v1_list { value: 3607339321121329188 } } feature { fid_v1_list { value: 3603297871905146955 } } feature { fid_v1_list { value: 3605886361555974207 } } feature { fid_v1_list { value: 3611846377146202666 } } feature { fid_v1_list { value: 3603259629256285404 } } feature { fid_v1_list { value: 3607259625265071215 } } feature { fid_v1_list { value: 3616775414741526157 } } feature { fid_v1_list { value: 3603039257404856223 } } feature { fid_v1_list { value: 3620296482659719726 } } feature { fid_v1_list { value: 3611735379914375886 } } feature { fid_v1_list { value: 3619698866543905695 } } feature { fid_v1_list { value: 3611138543871412195 } } feature { fid_v1_list { value: 3608086200946945886 } } feature { fid_v1_list { value: 3608534566476548317 } } feature { fid_v1_list { value: 3607157927903496194 } } feature { fid_v1_list { value: 3603776236828038739 } } feature { fid_v1_list { value: 3607177734847050046 } } feature { fid_v1_list { value: 3610806719725808236 } } feature { fid_v1_list { value: 3606611260239343721 } } feature { fid_v1_list { value: 3604130491108528090 } } feature { fid_v1_list { value: 3605348026796130101 } } feature { fid_v1_list { value: 3616118806628872259 } } feature { fid_v1_list { value: 3615956039523072078 } } feature { fid_v1_list { value: 3616815934704003769 } } feature { fid_v1_list { value: 3611149856407841749 } } } named_feature_list { name: "f_goods_title" feature { } feature { } feature { fid_v1_list { value: 3750933102327502546 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3760993937582903701 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_os" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_sale_number_1000" feature { } feature { } feature { fid_v1_list { value: 4066070996392706623 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4054836228468787784 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_area" feature { } feature { } feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test41_map_int32" feature { } feature { } feature { fid_v1_list { value: 9519402512114369756 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9523202332830430791 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test39_array_float" feature { } feature { } feature { fid_v1_list { value: 9286143405097552865 value: 9286143405097552865 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9286143405097552865 value: 9286143405097552865 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test01_int32" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 10008785078271377834 } } feature { } feature { } feature { } feature { fid_v1_list { value: 10008785078271377834 } } feature { fid_v1_list { value: 10008785078271377834 } } feature { fid_v1_list { value: 10008785078271377834 } } feature { fid_v1_list { value: 10008785078271377834 } } feature { fid_v1_list { value: 10008785078271377834 } } feature { fid_v1_list { value: 10008785078271377834 } } feature { fid_v1_list { value: 10008785078271377834 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_pub_time_month" feature { } feature { } feature { fid_v1_list { value: 4214234648173981377 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4211846706978290518 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_comment_cnt_1000" feature { } feature { } feature { } feature { } feature { } feature 
{ } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test46_map_bool" feature { } feature { } feature { fid_v1_list { value: 9432690607210547955 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9438915382847444385 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_spm" feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5406374454400203474 } } feature { fid_v1_list { value: 5406374454400203474 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5409103793660641491 } } 
feature { fid_v1_list { value: 5417348509665164251 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5406374454400203474 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5419024638130229596 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5406374454400203474 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5419024638130229596 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5406374454400203474 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5416942885167263790 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5404484209003362014 } } feature { fid_v1_list { value: 5409103793660641491 } } feature { fid_v1_list { value: 5419024638130229596 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { fid_v1_list { value: 5414531140030155305 } } feature { 
fid_v1_list { value: 5404484209003362014 } } feature { fid_v1_list { value: 5406374454400203474 } } feature { fid_v1_list { value: 5406374454400203474 } } } named_feature_list { name: "f_user_district" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test40_array_double" feature { } feature { } feature { fid_v1_list { value: 9160042615531178977 value: 9160042615531178977 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9160042615531178977 value: 9160042615531178977 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_status" feature { } feature { } feature { fid_v1_list { value: 3647636671629691873 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3647636671629691873 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_register_time_year" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 74831614955585623 } } feature { } feature { } feature { } feature { fid_v1_list { value: 79882321954458273 } } feature { fid_v1_list { value: 79882321954458273 } } feature { fid_v1_list { value: 79882321954458273 } } feature { fid_v1_list { value: 79882321954458273 } } feature { fid_v1_list { value: 79882321954458273 } } feature { fid_v1_list { value: 79882321954458273 } } feature { fid_v1_list { value: 79882321954458273 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_title_terms" feature { } feature { } feature { fid_v1_list { value: 3772834016398264993 value: 3765913701858879735 value: 3773000691200033589 value: 3781801790151204962 value: 3766856870087500518 value: 3775195877281589288 value: 3774546363060474958 value: 3767514185271515909 value: 3768887054169500879 value: 3769995601855538213 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature 
{ } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3782553589976358690 value: 3774011077586830980 value: 3776247000003769261 value: 3781324026518362914 value: 3775919041219263551 value: 3765703767519208555 value: 3769675736649258697 value: 3776247000003769261 value: 3781324026518362914 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test19_map_uint32" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9780740585977547089 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9780740585977547089 } } feature { fid_v1_list { value: 9780740585977547089 } } feature { fid_v1_list { value: 9780740585977547089 } } feature { fid_v1_list { value: 9780740585977547089 } } feature { fid_v1_list { value: 9780740585977547089 } } feature { fid_v1_list { value: 9780740585977547089 } } feature { fid_v1_list { value: 9780740585977547089 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_tags_terms" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_device_model" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_page" feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5495766155220566385 } } feature { fid_v1_list { value: 5495766155220566385 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5505376789329979122 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5495766155220566385 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5504311390965651065 } } 
feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5501101802138651045 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5495766155220566385 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5501101802138651045 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5495766155220566385 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5496023425700513216 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5501142726425111231 } } feature { fid_v1_list { value: 5505413516948704399 } } feature { fid_v1_list { value: 5501101802138651045 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5504311390965651065 } } feature { fid_v1_list { value: 5501142726425111231 } } feature { fid_v1_list { value: 5495766155220566385 } } feature { fid_v1_list { value: 5495766155220566385 } } } named_feature_list { name: "f_is_dup" feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 
5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 
} } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } feature { fid_v1_list { value: 5598068063666195856 } } } named_feature_list { name: "f_user_test04_uint64" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9601674407701381112 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9601674407701381112 } } feature { fid_v1_list { value: 9601674407701381112 } } feature { fid_v1_list { value: 9601674407701381112 } } feature { fid_v1_list { value: 9601674407701381112 } } feature { fid_v1_list { value: 9601674407701381112 } } feature { fid_v1_list { value: 9601674407701381112 } } feature { fid_v1_list { value: 9601674407701381112 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test29_string" feature { } feature { } feature { fid_v1_list { value: 9454213769273926101 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9453990683687577353 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_sale_number_10" feature { } feature { } feature { fid_v1_list { value: 4048056597883224639 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4050207389405564160 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_gender" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_current_price_10" feature { } feature { } feature { fid_v1_list { value: 3843523687113293978 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3847744272348507369 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test13_array_string" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { fid_v1_list { value: 9670163465801361161 value: 9670386551387709909 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_id-f_goods_cate_1" feature { } feature { } feature { fid_v1_list { value: 9117240382873998683 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9119401595761469634 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_id-f_goods_cate_2" feature { } feature { } feature { fid_v1_list { value: 9145018947826707613 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9145794885964584340 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_tags" feature { } feature { } feature { fid_v1_list { value: 3793365034877822818 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3793205953451737140 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } } named_feature_list { name: "f_user_test22_map_bool" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9853246548565530017 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9853246548565530017 } } feature { fid_v1_list { value: 9853246548565530017 } } feature { fid_v1_list { value: 9853246548565530017 } } feature { fid_v1_list { value: 9853246548565530017 } } feature { fid_v1_list { value: 9853246548565530017 } } feature { fid_v1_list { value: 9853246548565530017 } } feature { fid_v1_list { value: 9853246548565530017 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_register_time_month" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 98148548915816378 } } feature { } feature { } feature { } feature { fid_v1_list { value: 100705584251763957 } } feature { fid_v1_list { value: 100705584251763957 } } feature { fid_v1_list { value: 100705584251763957 } } feature { fid_v1_list { value: 100705584251763957 } } feature { fid_v1_list { value: 100705584251763957 } } feature { fid_v1_list { value: 100705584251763957 } } feature { fid_v1_list { value: 100705584251763957 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature 
{ } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_province" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_id-f_goods_brand" feature { } feature { } feature { fid_v1_list { value: 9101553090662715167 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9111910659614948266 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_area" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test34_array_int64" feature { } feature { } feature { fid_v1_list { value: 9219951674180715013 value: 9205986166572396293 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9218927941652312185 value: 9219951674180715013 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_current_price_ratio" feature { } feature { } feature { fid_v1_list { value: 3961537472356691931 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 3954672822003818443 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_praise_cnt_1000" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_ctx_platform" feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 
938505698453898612 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 946231517668508309 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 938505698453898612 } } feature { fid_v1_list { value: 946817063302724633 } } feature { fid_v1_list { value: 946817063302724633 } } } named_feature_list { name: "f_goods_test47_map_float" feature { } feature { } feature { fid_v1_list { value: 9261723814926832396 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9260877620672055897 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_tags" feature { } feature { } feature { } feature { 
} feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_province" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_test37_array_string" feature { } feature { } feature { fid_v1_list { value: 9343519995220875194 value: 9333547367854110383 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9345904292630685449 value: 9346127378217034197 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } } named_feature_list { name: "f_user_test18_map_int64" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9695186693039768446 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9695186693039768446 } } feature { fid_v1_list { value: 9695186693039768446 } } feature { fid_v1_list { value: 9695186693039768446 } } feature { fid_v1_list { value: 9695186693039768446 } } feature { fid_v1_list { value: 9695186693039768446 } } feature { fid_v1_list { value: 9695186693039768446 } } feature { fid_v1_list { value: 9695186693039768446 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_user_test17_map_int32" feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 9865475904510588487 } } feature { } feature { } feature { } feature { fid_v1_list { value: 9865475904510588487 } } feature { fid_v1_list { value: 9865475904510588487 } } feature { fid_v1_list { value: 9865475904510588487 } } feature { fid_v1_list { value: 9865475904510588487 } } feature { fid_v1_list { value: 9865475904510588487 } } feature { fid_v1_list { value: 9865475904510588487 } } feature { fid_v1_list { value: 9865475904510588487 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } 
feature { } feature { } feature { } feature { } feature { } } named_feature_list { name: "f_goods_pub_time_year" feature { } feature { } feature { fid_v1_list { value: 4187165182116350625 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { fid_v1_list { value: 4187165182116350625 } } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } feature { } } batch_size: 51 ================================================ FILE: monolith/agent_service/mocked_tfserving.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging from concurrent import futures from dataclasses import dataclass import grpc from google.protobuf.any_pb2 import Any from queue import Queue import random import threading import time from typing import List, Tuple, Union, Optional from google.protobuf import text_format from tensorflow_serving.apis.get_model_status_pb2 import GetModelStatusRequest, \ GetModelStatusResponse from tensorflow_serving.apis.get_model_metadata_pb2 import GetModelMetadataRequest, \ GetModelMetadataResponse from tensorflow_serving.apis.model_management_pb2 import ReloadConfigRequest, \ ReloadConfigResponse from tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceServicer from tensorflow_serving.apis.model_service_pb2_grpc import add_ModelServiceServicer_to_server from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceServicer from tensorflow_serving.apis.prediction_service_pb2_grpc import add_PredictionServiceServicer_to_server from tensorflow_serving.config import model_server_config_pb2 from monolith.agent_service.utils import gen_model_version_status, gen_status_proto, gen_model_config, \ ModelState, ErrorCode @dataclass class ModelConf: model_name: str = None base_path: str = None version_policy: str = 'latest' version_data: Union[int, List[int]] = None model_platform: str = 'tensorflow' signature_name: Tuple = ('update', 'predict') @dataclass class ModelVersion: version: int = 0 version_label: str = None state: int = ModelState.UNKNOWN class ModelMeta: def __init__(self, conf: ModelConf, versions: List[ModelVersion] = None): self.conf = conf self.versions = versions or [ModelVersion()] self._unloading = False def is_unloading(self): return self._unloading def set_unloading(self): self._unloading = True @dataclass class Event: model_name: str = None version: int = 0 state: int = ModelState.UNKNOWN class ModelMgr: def __init__(self, model_config_list=None): self._models = {} self._lock = threading.Lock() self._queue = 
Queue() self._has_stopped = False self._thread: Optional[threading.Thread] = None if model_config_list is not None: self.load(model_config_list) def load(self, model_config_list): for config in model_config_list: if config.model_version_policy.HasField('latest'): version_policy = 'latest' version_data = config.model_version_policy.latest.num_versions versions = [ModelVersion(i + 1) for i in range(version_data)] elif config.model_version_policy.HasField('all'): version_policy = 'latest' version_data = None versions = [ModelVersion(1)] else: version_policy = 'specific' version_data = config.model_version_policy.specific.versions version_data.sort() versions = [ModelVersion(i) for i in version_data] model_conf = ModelConf(config.name, config.base_path, version_policy=version_policy, version_data=version_data) self._models[config.name] = ModelMeta(model_conf, versions) logging.info('start load a new model {}'.format(config.name)) for v in versions: self._queue.put( Event(model_name=config.name, version=v.version, state=ModelState.START)) def remove(self, model_name_list): for model_name in model_name_list: model: ModelMeta = self._models[model_name] model.set_unloading() logging.info('start remove the model {}'.format(model_name)) for version in model.versions: self._queue.put(Event(model_name, version.version, ModelState.UNLOADING)) def get_status(self, model_spec): model_version_status = [] with self._lock: if model_spec.name in self._models: model_meta: ModelMeta = self._models[model_spec.name] if model_spec.WhichOneof('version_choice') is None: for version in model_meta.versions: mvs = gen_model_version_status(version.version, version.state) model_version_status.append(mvs) else: if model_spec.HasField('version'): value = model_spec.version.value for version in model_meta.versions: if version.version == value: mvs = gen_model_version_status(version.version, version.state) model_version_status.append(mvs) break else: value = model_spec.version_label for version in 
model_meta.versions: if version.version_label == value: mvs = gen_model_version_status(version.version, version.state) model_version_status.append(mvs) break if len(model_version_status) == 0: mvs = gen_model_version_status( -1, error_code=ErrorCode.NOT_FOUND, error_message=f'{model_spec.name} is not found') model_version_status.append(mvs) return model_version_status def get_metadata(self, model_spec, metadata_field): metadata = {} if metadata_field is not None and len(metadata_field) > 0: with self._lock: model_meta: ModelMeta = self._models[model_spec.name] conf = model_meta.conf for field in metadata_field: if hasattr(conf, field): metadata[field] = getattr(conf, field) if model_spec.HasField('version'): version = model_spec.version.value for v in model_meta.versions: if v.version == version: for field in metadata_field: if hasattr(v, field): metadata[field] = getattr(v, field) break return metadata def get_alive_model_names(self): with self._lock: return {k for k, v in self._models.items() if not v.is_unloading()} def start(self): self._thread = threading.Thread(target=self._poll,) self._thread.start() def stop(self): self._has_stopped = True if self._thread is not None: self._thread.join() self._thread = None def _poll(self): start_time = time.time() while not self._has_stopped: if not self._queue.empty(): event = self._queue.get() self._event_handler(event) end_time = time.time() if end_time - start_time > 30: start_time = end_time model_names = list(self._models.keys()) if len(model_names) == 0: continue model_name = random.choice(list(self._models.keys())) model_conf: ModelConf = self._models[model_name].conf if model_conf.version_policy != 'specific': versions = self._models[model_name].versions version = versions[-1].version + 1 versions.append(ModelVersion(version)) logging.info( 'start load a new version of model {}'.format(model_name)) self._queue.put(Event(model_name, version, ModelState.START)) # time.sleep(random.uniform(0, 0.1)) def 
_event_handler(self, event: Event): with self._lock: model: ModelMeta = self._models.get(event.model_name, None) if model is None: logging.error(f'{event.model_name} has removed!') return log_flag = False if event.state == ModelState.START: for version in model.versions: if version.version == event.version: if version.state == ModelState.UNKNOWN: version.state = event.state self._queue.put( Event(event.model_name, event.version, ModelState.LOADING)) log_flag = True break elif event.state == ModelState.LOADING: for version in model.versions: if version.version == event.version: if version.state == ModelState.START: version.state = event.state self._queue.put( Event(event.model_name, event.version, ModelState.AVAILABLE)) log_flag = True break elif event.state == ModelState.AVAILABLE: for version in model.versions: if version.version == event.version: if version.state == ModelState.LOADING: version.state = event.state log_flag = True if model.conf.version_policy == 'latest': if len(model.versions) > model.conf.version_data: self._queue.put( Event(event.model_name, model.versions[0].version, ModelState.UNLOADING)) break elif event.state == ModelState.UNLOADING: for version in model.versions: if version.version == event.version: # in case unloading in unloading if version.state not in {ModelState.UNLOADING, ModelState.END}: version.state = event.state self._queue.put( Event(event.model_name, event.version, ModelState.END)) log_flag = True break elif event.state == ModelState.END: index = -1 for i, version in enumerate(model.versions): if version.version == event.version: if version.state == ModelState.UNLOADING: version.state = event.state logging.info( f'{event.model_name}-{event.version}: state is {ModelState.Name(event.state)}' ) index = i break if index >= 0: logging.info( f'{event.model_name}-{model.versions[index].version} is removed!') del model.versions[index] if len(model.versions) == 0: logging.info(f'{event.model_name} is removed!') del 
self._models[event.model_name] else: logging.error('unknown event') if log_flag: logging.info( f'{event.model_name}-{event.version}: state is {event.state}') class ModelServiceImpl(ModelServiceServicer): def __init__(self, model_mgr: ModelMgr): self._model_mgr = model_mgr def GetModelStatus(self, request: GetModelStatusRequest, context): response = GetModelStatusResponse() model_version_status = self._model_mgr.get_status(request.model_spec) response.model_version_status.extend(model_version_status) return response def HandleReloadConfigRequest(self, request: ReloadConfigRequest, context): model_config_list = request.config.model_config_list.config old_names = self._model_mgr.get_alive_model_names() new_names = {config.name for config in model_config_list} to_remove = old_names - new_names self._model_mgr.remove(to_remove) to_load = new_names - old_names self._model_mgr.load( [config for config in model_config_list if config.name in to_load]) response = ReloadConfigResponse() response.status.CopyFrom(gen_status_proto()) return response class PredictionServiceImpl(PredictionServiceServicer): def __init__(self, model_mgr: ModelMgr): self._model_mgr = model_mgr def Predict(self, request, context): pass def GetModelMetadata(self, request: GetModelMetadataRequest, context): model_spec = request.model_spec metadata_field = set(request.metadata_field) response = GetModelMetadataResponse() response.model_spec.CopyFrom(model_spec) metadata = self._model_mgr.get_metadata(model_spec, metadata_field) for k, v in metadata.items(): value = bytes(repr(v), encoding='utf-8') response.metadata[k].CopyFrom(Any(value=value)) return response class FakeTFServing: def __init__(self, model_name: str = None, base_path: str = None, num_versions: int = 1, port: int = 8500, max_workers: int = 10, model_config_file=None): if model_config_file is None: self._model_mgr = ModelMgr( [gen_model_config(model_name, base_path, version_data=num_versions)]) elif isinstance(model_config_file, str): msc = 
model_server_config_pb2.ModelServerConfig() with open(model_config_file, 'r') as fp: text = ''.join(fp.readlines()) text_format.Parse(text, msc) self._model_mgr = ModelMgr(msc.model_config_list.config) else: assert isinstance(model_config_file, model_server_config_pb2.ModelServerConfig) self._model_mgr = ModelMgr(model_config_file.model_config_list.config) self._server = grpc.server( futures.ThreadPoolExecutor(max_workers=max_workers)) add_ModelServiceServicer_to_server(ModelServiceImpl(self._model_mgr), self._server) add_PredictionServiceServicer_to_server( PredictionServiceImpl(self._model_mgr), self._server) self._server.add_insecure_port(f'[::]:{port}') def start(self): self._model_mgr.start() self._server.start() self._server.wait_for_termination() def stop(self, grace=None): self._server.stop(grace=grace) self._model_mgr.stop() if __name__ == '__main__': tfs = FakeTFServing('model_test', '/tmp/model/monolith', num_versions=1) tfs.start() ================================================ FILE: monolith/agent_service/mocked_tfserving_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import grpc import socket import time import threading import unittest from tensorflow_serving.apis.get_model_metadata_pb2 import GetModelMetadataRequest, \ GetModelMetadataResponse from tensorflow_serving.apis.get_model_status_pb2 import GetModelStatusRequest, \ GetModelStatusResponse from tensorflow_serving.apis.model_management_pb2 import ReloadConfigRequest, \ ReloadConfigResponse from tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceStub from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub from monolith.agent_service import utils from monolith.agent_service.mocked_tfserving import FakeTFServing MODEL_NAME = 'test_model_test' BASE_PATH = '/tmp/test_model/monolith' PORT = utils.find_free_port() Address = f'{socket.gethostbyname(socket.gethostname())}:{PORT}' class MockedTFSTest(unittest.TestCase): tfs: FakeTFServing = None @classmethod def setUpClass(cls) -> None: cls.tfs = FakeTFServing(MODEL_NAME, BASE_PATH, num_versions=2, port=PORT) # cls.tfs.start() thread = threading.Thread(target=lambda: cls.tfs.start()) thread.start() time.sleep(5) @classmethod def tearDownClass(cls) -> None: cls.tfs.stop() def test_get_model_metadata(self): request = GetModelMetadataRequest() request.model_spec.CopyFrom( utils.gen_model_spec(MODEL_NAME, 2, signature_name='predict')) request.metadata_field.extend( ['base_path', 'num_versions', 'signature_name']) stub = PredictionServiceStub(grpc.insecure_channel(Address)) self.assertTrue( isinstance(stub.GetModelMetadata(request), GetModelMetadataResponse)) def test_get_model_status(self): stub = ModelServiceStub(grpc.insecure_channel(Address)) request = GetModelStatusRequest() request.model_spec.CopyFrom( utils.gen_model_spec(MODEL_NAME, 1, signature_name='predict')) self.assertTrue( isinstance(stub.GetModelStatus(request), GetModelStatusResponse)) def test_handle_reload_config_request(self): stub = ModelServiceStub(grpc.insecure_channel(Address)) request = ReloadConfigRequest() 
model_config_list = request.config.model_config_list.config model_config_list.extend([ utils.gen_model_config(name='test_model', base_path='/tmp/test_model/ctr/saved_model', version_data=2), utils.gen_model_config(name='test_model', base_path='/tmp/test_model/cvr/saved_model', version_data=1), ]) self.assertTrue( isinstance(stub.HandleReloadConfigRequest(request), ReloadConfigResponse)) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/agent_service/mocked_zkclient.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging from functools import partial from kazoo.protocol.states import ZnodeStat, WatchedEvent, EventType, KeeperState from kazoo.exceptions import NoNodeError, NodeExistsError, NotEmptyError, CancelledError import os import time from threading import Lock from typing import List, Dict, Union, Callable, Optional class ChildrenWatch: def __init__(self, client, path: str, func: Union[Callable[[List[str]], None], Callable[[List[str], WatchedEvent], None]], send_event=False): self.path = path self.send_event = send_event self._stopped = False self._func = func catalog = client._catalog catalog.add_children_watch(self) def __call__(self, children: List[str], event: WatchedEvent): if self.send_event: self._func(children, event) else: self._func(children) class DataWatch: def __init__(self, client, path: str, func: Union[Callable[[bytes, ZnodeStat, WatchedEvent], None], Callable[[bytes, ZnodeStat], None]]): self.path = path self._func = func catalog = client._catalog catalog.add_data_watch(self) def __call__(self, data: bytes, state: ZnodeStat, event: WatchedEvent): try: self._func(data, state, event) except TypeError: self._func(data, state) class Election(object): def __init__(self, client, path, identifier=None): self.lock = Lock() def run(self, func, *args, **kwargs): if not callable(func): raise ValueError("leader function is not callable") try: with self.lock: func(*args, **kwargs) except CancelledError: pass def cancel(self): self.lock.cancel() class Node: def __init__(self, path: str, value: bytes = b'', ephemeral: bool = False, data_watch: DataWatch = None, children_watch: ChildrenWatch = None): self.path: str = path self.value: bytes = value self.ephemeral: bool = ephemeral self.children: Dict[str, Node] = {} self._ctime = int(time.time()) self._mtime = int(time.time()) self._version = 0 self._data_watch = data_watch self._children_watch = children_watch event = WatchedEvent(type=EventType.CREATED, state=KeeperState.CONNECTED, path=self.path) 
if self._data_watch is not None: self._data_watch(self.value, self.state, event) if self._children_watch is not None: self._children_watch([], event) @property def state(self): return ZnodeStat(czxid=0, mzxid=0, ctime=self._ctime, mtime=self._mtime, version=self._version, cversion=0, aversion=0, ephemeralOwner=0, dataLength=len(self.value), numChildren=len(self.children), pzxid=0) @property def basename(self): return os.path.basename(self.path) def set(self, value: bytes): self._mtime = int(time.time()) self._version += 1 self.value = value if self._data_watch is not None: event = WatchedEvent(type=EventType.CHANGED, state=KeeperState.CONNECTED, path=self.path) self._data_watch(self.value, self.state, event) def get(self): return self.value def set_data_watch(self, watch: DataWatch): self._data_watch = watch self._data_watch(self.value, self.state, None) def set_children_watch(self, watch: ChildrenWatch): self._children_watch = watch self._children_watch(list(self.children.keys()), None) def create_child(self, path: str, value: bytes = b'', ephemeral: bool = False, data_watch=None, children_watch=None): basename = os.path.basename(path) self._mtime = int(time.time()) if self.path == os.path.sep: # root child_path = f'{os.path.sep}{basename}' else: child_path = f'{self.path}{os.path.sep}{basename}' node = Node(child_path, value, ephemeral, data_watch, children_watch) self.children[basename] = node if self._children_watch is not None: event = WatchedEvent(type=EventType.CHILD, state=KeeperState.CONNECTED, path=self.path) self._children_watch(list(self.children.keys()), event) return node def get_or_create_child(self, path): name = os.path.basename(path) if name in self.children: return self.children[name] else: return self.create_child(path) def get_child(self, path): return self.children.get(os.path.basename(path), None) def has_child(self, path=None): if path is None: return len(self.children) > 0 else: return os.path.basename(path) in self.children def 
remove_child(self, path, recursive: bool = False): if self.has_child(path): self._mtime = int(time.time()) node = self.children[os.path.basename(path)] if not recursive and node.has_child(): raise NotEmptyError(f'{path} is not empty!') del self.children[os.path.basename(path)] if self._children_watch is not None: event = WatchedEvent(type=EventType.CHILD, state=KeeperState.CONNECTED, path=self.path) self._children_watch(list(self.children.keys()), event=event) else: raise NoNodeError(f'{path} is not exists!') def __del__(self): event = WatchedEvent(type=EventType.DELETED, state=KeeperState.CONNECTED, path=self.path) if self._data_watch is not None: self._data_watch(self.value, self.state, event) self._data_watch = None for child in list(self.children.keys()): del self.children[child] if self._children_watch is not None: self._children_watch(list(self.children.keys()), event=event) self._children_watch = None del self.path, self.value, self.ephemeral, self._ctime, self._mtime, self._version del self._data_watch, self._children_watch, self.children class Catalog: def __init__(self): self.root = Node(os.path.sep) self._data_watches = {} self._children_watches = {} self._sequence_paths = {} def add_data_watch(self, watch: DataWatch): self._data_watches[watch.path] = watch try: node = self.get(watch.path) node.set_data_watch(watch) except Exception: pass def add_children_watch(self, watch: ChildrenWatch): self._children_watches[watch.path] = watch try: node = self.get(watch.path) node.set_children_watch(watch) except Exception as e: pass def ensure_path(self, path: str) -> Node: items = [item for item in path.split(os.path.sep) if len(item) > 0] cpath, cnode = '', self.root for item in items: cpath = f'{cpath}{os.path.sep}{item}' cnode = cnode.get_or_create_child(cpath) if cnode.path in self._data_watches and cnode._data_watch is None: cnode._data_watch = self._data_watches[cnode.path] if cnode.path in self._children_watches and cnode._children_watch is None: 
cnode._children_watch = self._children_watches[cnode.path] return cnode def create(self, path: str, value: bytes = b'', ephemeral: bool = False, makepath: bool = False, sequence: bool = False): if sequence: if path in self._sequence_paths: self._sequence_paths[path] += 1 path = f'{path}{self._sequence_paths[path]:010d}' else: self._sequence_paths[path] = 0 path = f'{path}{0:010d}' dirname = os.path.dirname(path) if makepath: pnode = self.ensure_path(dirname) else: pnode = self.get(dirname) if pnode.has_child(path): raise NodeExistsError(f'{path} Exists!') else: data_watch = self._data_watches.get(path, None) children_watch = self._children_watches.get(path, None) return pnode.create_child(path, value, ephemeral, data_watch, children_watch) def delete(self, path: str, recursive: bool = False): dirname = os.path.dirname(path) pnode = self.get(dirname) pnode.remove_child(path, recursive) def set(self, path: str, value: bytes): self.get(path).set(value) def get(self, path: str) -> Node: items = [item for item in path.split(os.path.sep) if len(item) > 0] cpath, cnode = '', self.root for item in items: cpath = f'{cpath}{os.path.sep}{item}' cnode = cnode.get_child(cpath) if cnode is None: raise NoNodeError(f'{path} is not exists!') return cnode class FakeKazooClient: def __init__(self, zk_server: str = None): self._zk_server = zk_server self._catalog: Optional[Catalog] = None self.DataWatch = partial(DataWatch, self) self.ChildrenWatch = partial(ChildrenWatch, self) self.Election = partial(Election, self) def ensure_path(self, path: str): self._catalog.ensure_path(path) def start(self): self._catalog = Catalog() def create(self, path: str, value: bytes = b'', acl=None, ephemeral: bool = False, makepath: bool = False, include_data: bool = False, sequence: bool = False): node = self._catalog.create(path, value, ephemeral, makepath, sequence) if include_data: return node.path, node.state else: return node.path def delete(self, path: str, recursive: bool = True): 
self._catalog.delete(path, recursive) def set(self, path: str, value: bytes): self._catalog.set(path, value) def get(self, path: str): node = self._catalog.get(path) return node.value, node.state def exists(self, path: str): try: node = self._catalog.get(path) return True except NoNodeError as e: return False def get_children(self, path: str, include_data=False): node = self._catalog.get(path) children = list(node.children.keys()) if include_data: return children, node.state else: return children def retry(self, func, *args, **kwargs): return func(*args, **kwargs) def stop(self): self._catalog = None def close(self): if self._catalog is not None: self.stop() def add_listener(self, listener): pass ================================================ FILE: monolith/agent_service/mocked_zkclient_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging from kazoo.exceptions import NoNodeError, NodeExistsError import unittest from monolith.agent_service.mocked_zkclient import FakeKazooClient class MockedZKClientTest(unittest.TestCase): client = None @classmethod def setUpClass(cls) -> None: cls.client = FakeKazooClient() cls.client.start() @classmethod def tearDownClass(cls) -> None: cls.client.stop() def test_create(self): path = '/monolith/zk/data' try: real_path = self.client.create(path, makepath=True) self.assertEqual(real_path, path) except NoNodeError as e: logging.info(f'{e}') except NodeExistsError as e: logging.info(f'{e}') def test_set_get(self): path = '/monolith/zk/data' data = b'hi, I am Fitz!' try: real_path, state = self.client.create(path, makepath=True, include_data=True) self.assertEqual(real_path, path) except NoNodeError as e: logging.info(f'{e}') except NodeExistsError as e: logging.info(f'{e}') not_exists_path = f"{path}/error" try: self.client.set(not_exists_path, data) except NoNodeError as e: logging.error(f'{e}') self.client.set(path, b'hi, I am Fitz!') try: gdata, _ = self.client.get(not_exists_path) self.assertEqual(gdata, data) except NoNodeError as e: logging.error(f'{e}') gdata, state = self.client.get(path) self.assertEqual(gdata, data) def test_delete(self): path = '/monolith/zk/data' try: real_path = self.client.create(path, makepath=True) self.assertEqual(real_path, path) except NoNodeError as e: logging.info(f'{e}') except NodeExistsError as e: logging.info(f'{e}') self.client.delete(path) self.client.delete('/monolith') def test_data_watch(self): path = '/monolith/zk/data' try: real_path = self.client.create(path, makepath=True) self.assertEqual(real_path, path) except NoNodeError as e: logging.info(f'{e}') except NodeExistsError as e: logging.info(f'{e}') def data_watch(data, state, event): print('data_watch', data, state, event) self.client.DataWatch(path=path, func=data_watch) def test_children_watch(self): path = '/monolith/zk/data' def 
children_watch(children, event): print('children_watch', children, event) self.client.ChildrenWatch(path='/monolith/zk', func=children_watch, send_event=True) try: real_path = self.client.create(path, makepath=True) self.assertEqual(real_path, path) except NoNodeError as e: logging.info(f'{e}') except NodeExistsError as e: logging.info(f'{e}') def data_watch(data, state, event): print('data_watch', data, state, event) self.client.DataWatch(path=path, func=data_watch) self.client.create('/monolith/zk/test', b'123') if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/agent_service/model_manager.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys import os import threading import time import shutil import logging from monolith.native_training.metric import cli #from absl import logging # copy latest model from source path(p2p path) to receive path(model path) class ModelManager(object): WRITE_DONE = '.write.done' READ_LOCK = '.read.lock' def __init__(self, model_name, source_path, receive_path, use_metrics): self._worker = None self._source_path = source_path self._receive_path = receive_path self._model_name = model_name self._models = {} # model_name to version list self._latest_models = {} self._wait_timeout = 1200 self._loop_thread = None self._loop_interval = 30 self._exist = False self._remain_version_num = 5 self._lock_files = set() self._use_metrics = use_metrics self._metrics = None if self._model_name and self._use_metrics: self.init_metrics() def init_metrics(self): self._metrics = cli.get_cli('data.monolith_serving.online') def stop(self): self._exist = True if self._loop_thread: self._loop_thread.join() def start(self): ret = False try: ret = self._start() except Exception as e: logging.error('model manager start failed: %s', str(e)) ret = False return ret def _start(self): if self._model_name is None: logging.info('ModelManager is not needed') return True # delete receive path first if not self.delete(self._receive_path): return False # wait for the source path if not self.wait_for_download(): return False # do loop once to copy model while True: try: if self.loop_once(): break except BaseException as err: logging.error('model manager loop once failed: %s', str(err)) logging.info('loop once failed, wait for ready model') time.sleep(10) self.remove_read_lock() self._loop_thread = threading.Thread(target=self.run, name="thread-model_manager") self._loop_thread.start() return True def touch(self, file): try: f = open(file, 'w+') f.close() return True except BaseException: pass return False def run(self): while not self._exist: try: ret = self.loop_once() self.remove_read_lock() if not 
ret: logging.error('model manager loop once failed') except BaseException as err: logging.error('model manager loop once failed: %s', str(err)) if self._use_metrics: self.check_model_update_time() time.sleep(self._loop_interval) self.remove_old_file() def check_model_update_time(self): if not self._metrics: return if self._model_name not in self._latest_models: logging.error('model %s not in _latest_models: %s', self._model_name, str(self._latest_models)) self._metrics.emit_counter('loop_once_failed', 1, tagkv={'model': self._model_name}) return version, update_time = self._latest_models[self._model_name] cur_time = int(time.time()) self._metrics.emit_store('version.delay', cur_time - int(version), tagkv={'model': self._model_name}) self._metrics.emit_store('update.delay', cur_time - update_time, tagkv={'model': self._model_name}) def remove_old_file(self): for model_name in self._models: model_files_list = self._models[model_name] if len(model_files_list) > self._remain_version_num: old_files = model_files_list.pop(0) for old_file in old_files[1]: self.delete(old_file) def create_read_lock(self, name): lock_name = name + self.READ_LOCK if self.touch(lock_name): return lock_name else: logging.error("create lock %s failed", lock_name) return lock_name def remove_read_lock(self): for lock_file in self._lock_files: self.delete(lock_file) self._lock_files.clear() # remove other lock ret = list(os.walk(self._source_path)) if len(ret) == 0: return root, dirs, files = ret[0] for file in files: if file.endswith(self.READ_LOCK): completed_name = os.join(root, file) logging.info('delete lock file: %s', completed_name) self.delete(completed_name) def loop_once(self): source_data = {} result = True try: source_data = self.get_source_data() except BaseException as err: logging.error('get download data failed: %s', str(err)) return False for model_name in source_data: new_version = source_data[model_name][0] if model_name in self._models and len(self._models[model_name]) > 0: 
old_version = self._models[model_name][-1][0] if old_version >= new_version: continue ret, file_list = self.copy_model(model_name, new_version, source_data[model_name][1]) if ret: if model_name not in self._models: self._models[model_name] = [] self._models[model_name].append((new_version, file_list)) cur_time = int(time.time()) self._latest_models[model_name] = (new_version, cur_time) logging.info(f'{model_name} update to {new_version}') else: logging.error(f'copy {model_name} failed') result = False return result def copy_model(self, model_name, version, model_data): sub_model_num = len(model_data) ready_data = [] result = [] ready_num = 0 for sub_model_name, sub_model_data in model_data: # sub_model_name: ps_0/version # sub_model_data: /xxx/model_name@version/ps_0/version try: src_file = sub_model_data dst_file = os.path.join(self._receive_path, model_name, sub_model_name) temp_dst_file = dst_file + '-temp' result.append(dst_file) if os.path.exists(dst_file): logging.error(f'{dst_file} exist') ready_num += 1 continue if os.path.exists(temp_dst_file): logging.error(f'{temp_dst_file} exist') ready_num += 1 ready_data.append((temp_dst_file, dst_file)) continue shutil.copytree(src_file, temp_dst_file) ready_data.append((temp_dst_file, dst_file)) ready_num += 1 except BaseException as err: logging.error('copy model %s -> %s faild: %s', src_file, temp_dst_file, str(err)) self.delete(temp_dst_file) break if ready_num != sub_model_num: logging.error( f'copy model faild, ready_num={ready_num}, expect_num={sub_model_num}' ) for data in ready_data: self.delete(data[0]) return False, [] for data in ready_data: os.rename(data[0], data[1]) return True, result def wait_for_download(self): duartion = 0 download_path_ready = os.path.exists(self._source_path) while not download_path_ready and duartion < self._wait_timeout: logging.info(f'wait {self._source_path} created') time.sleep(10) duartion += 10 download_path_ready = os.path.exists(self._source_path) if not 
download_path_ready: logging.error(f'{self._source_path} is not ready') return False while duartion < self._wait_timeout: ret = list(os.walk(self._source_path)) if len(ret) > 0: root, dirs, files = ret[0] for file in files: if file.endswith(self.WRITE_DONE) and file.startswith( self._model_name): logging.info(f'{file} is ready') return True logging.info('no ready model found') time.sleep(10) duartion += 10 logging.error('no ready model found') return False def get_source_data(self): source_data = {} ret = list(os.walk(self._source_path)) if len(ret) == 0: logging.error(f'{self._source_path} is empty') return source_data root, dirs, files = ret[0] done_file_set = set() for file in files: if file.endswith(self.WRITE_DONE) and file.startswith(self._model_name): done_file_set.add(file) for model_data in dirs: lock_file = self.create_read_lock(os.path.join(root, model_data)) self._lock_files.add(lock_file) if self.get_done_file(model_data) in done_file_set: data = model_data.split('@') if len(data) != 2: logging.error(f'{model_data} is not valid') continue model_name, version = data # real_path: /xxx/model_name@version/model_name real_path = os.path.join(root, model_data, model_name) # version_data: [(ps_0/version,/xxx/model_name@version/ps_0/version), (..,..)] version_data = self.get_version_data(real_path, version) if len(version_data) == 0: continue if model_name not in source_data: source_data[model_name] = (version, version_data, real_path) else: old_data = source_data[model_name] if old_data[0] < version: source_data[model_name] = (version, version_data, real_path) return source_data def get_version_data(self, path, version): ret = list(os.walk(path)) if len(ret) == 0: logging.error(f'get version data [{path}] failed') return [] sub_root, sub_dirs, sub_files = ret[0] if len(sub_dirs) == 0: return [] res = [] for sub_dir in sub_dirs: # sub_dir: ps_0 # version_dir: /xxx/model_name@version/ps_0/version version_dir = os.path.join(sub_root, sub_dir, version) if not 
os.path.exists(version_dir): logging.error(f'{version_dir} not exist') return [] else: res.append((os.path.join(sub_dir, version), version_dir)) return res def get_done_file(self, file): return file + self.WRITE_DONE def delete(self, file): try: if not os.path.exists(file): return True if os.path.isfile(file): os.remove(file) else: shutil.rmtree(file) return True except BaseException as err: logging.error('delete [%s] failed: %s', file, str(err)) return False ================================================ FILE: monolith/agent_service/model_manager_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging, app, flags import json import os import shutil import unittest import time from monolith.agent_service.model_manager import ModelManager FLAGS = flags.FLAGS class ModelManagerTest(unittest.TestCase): def create_file(self, model_name, timestamp, p2p_data_path): # model_data/test_model/ps_item_embedding_0/1234567 # p2p/test_model@1234567/test_model/ps_item_embedding_0/1234567 os.makedirs( os.path.join(p2p_data_path, model_name + '@' + timestamp, model_name, 'ps_item_embedding_0', timestamp)) os.makedirs( os.path.join(p2p_data_path, model_name + '@' + timestamp, model_name, 'ps_item_embedding_1', timestamp)) f = open( os.path.join(p2p_data_path, model_name + '@' + timestamp + '.write.done'), 'w+') f.close() def test_start(self): base_path = os.path.join(os.environ["TEST_TMPDIR"], "test_model_manager") p2p_data_path = os.path.join(base_path, 'p2p') model_data_path = os.path.join(base_path, 'model_data') model_name = "test_model" timestamp = "1234567" self.create_file(model_name, timestamp, p2p_data_path) model_manager = ModelManager(model_name, p2p_data_path, model_data_path, False) model_manager._wait_timeout = 5 model_manager._loop_interval = 5 ret = model_manager.start() self.assertTrue(ret) ready_path1 = os.path.join(model_data_path, model_name, 'ps_item_embedding_0', timestamp) ready_path2 = os.path.join(model_data_path, model_name, 'ps_item_embedding_1', timestamp) self.assertTrue(os.path.exists(ready_path1)) self.assertTrue(os.path.exists(ready_path2)) model_manager.stop() shutil.rmtree(p2p_data_path) shutil.rmtree(model_data_path) def test_ignore_old(self): base_path = os.path.join(os.environ["TEST_TMPDIR"], "test_model_manager") p2p_data_path = os.path.join(base_path, 'p2p') model_data_path = os.path.join(base_path, 'model_data') model_name = "test_model" timestamp = "1234567" timestamp_old = "1234566" self.create_file(model_name, timestamp, p2p_data_path) model_manager = ModelManager(model_name, p2p_data_path, model_data_path, False) 
model_manager._wait_timeout = 5 model_manager._loop_interval = 5 ret = model_manager.start() self.assertTrue(ret) self.create_file(model_name, timestamp_old, p2p_data_path) time.sleep(11) ready_path1 = os.path.join(model_data_path, model_name, 'ps_item_embedding_0', timestamp_old) ready_path2 = os.path.join(model_data_path, model_name, 'ps_item_embedding_1', timestamp_old) self.assertFalse(os.path.exists(ready_path1)) self.assertFalse(os.path.exists(ready_path2)) model_manager.stop() shutil.rmtree(p2p_data_path) shutil.rmtree(model_data_path) def main(_): unittest.main() if __name__ == "__main__": app.run(main) ================================================ FILE: monolith/agent_service/profile.sh ================================================ #! /bin/bash # set -x # grpc port: PORT3 gpu_server_target="10.209.87.151:9469" # multi "10.210.92.156:9361,10.198.98.198:9433" cpu_server_target="10.211.69.228:9388" tool_dir=`dirname $0` abs_tool_dir=`realpath $tool_dir` entry_agent_path="$abs_tool_dir/agent.conf" profile_data_dir="$abs_tool_dir/profile_data" bin_name="tfs_client" bin_path="/home/lilintong.22222/.cache/bazel/_bazel_lilintong.22222/5282ccf6d1eb9e524c65d4bb4a5b4207/execroot/__main__/bazel-out/k8-opt/bin/monolith/agent_service/${bin_name}" function run_pro() { target="$1" conf_path="$2" batch_size="$3" parallel_num="$4" profile_duration="$5" profile_data_dir="$6" $bin_path \ --target=$target \ --conf=$conf_path \ --cmd_type="profile" \ --input_type="example_batch" \ --batch_size=$batch_size \ --parallel_num=$parallel_num \ --profile_duration=$profile_duration \ --profile_data_dir=$profile_data_dir \ --has_sort_id } function run_pro_async() { target="$1" conf_path="$2" batch_size="$3" parallel_num="$4" profile_duration="$5" profile_data_dir="$6" $bin_path \ --target=$target \ --conf=$conf_path \ --cmd_type="profile" \ --input_type="example_batch" \ --batch_size=$batch_size \ --parallel_num=$parallel_num \ --profile_duration=$profile_duration \ 
--profile_data_dir=$profile_data_dir \ --has_sort_id & } function run_alg() { target="$1" conf_path="$2" batch_size="$3" input_path="$4" output_path="$5" $bin_path \ --target=$target \ --conf=$conf_path \ --cmd_type="get" \ --input_type="example_batch" \ --batch_size=$batch_size \ --input_file=$input_path \ --has_sort_id > $output_path } function compare_alg() { a_server_target="$1" b_server_target="$1" rm -f input_alg.pb run_alg $a_server_target $entry_agent_path 1 input_alg.pb output_alg_gpu.txt run_alg $b_server_target $entry_agent_path 1 input_alg.pb output_alg_cpu.txt diff -urN output_alg_gpu.txt output_alg_cpu.txt > compare_alg.diff cat compare_alg.diff } function warmup() { server_target="$1" for ((i=0; i<3; i++)); do run_alg $server_target $entry_agent_path 1 input_alg.pb output_alg_gpu.txt cat output_alg_gpu.txt done } bazel build :${bin_name} warmup $gpu_server_target # compare_alg $gpu_server_target $cpu_server_target # sync profile run_pro $gpu_server_target $entry_agent_path 128 1 300 $profile_data_dir # run_pro $cpu_server_target $entry_agent_path 128 1 600 $profile_data_dir # async profile for ((i=1; i<=6; i++)); do run_pro_async $gpu_server_target $entry_agent_path 128 11 300 $profile_data_dir done ================================================ FILE: monolith/agent_service/replica_manager.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging from dataclasses import dataclass from dataclasses_json import dataclass_json from kazoo.client import KazooState from kazoo.protocol.states import WatchedEvent, EventType, ZnodeStat from kazoo.exceptions import NodeExistsError, NoNodeError import os, re import socket import sys import threading import time import traceback from typing import List, Dict, Union, Optional, Tuple from tensorflow.core.protobuf.error_codes_pb2 import Code as ErrorCode from monolith.agent_service.tfs_monitor import TFSMonitor from monolith.agent_service.utils import AgentConfig, ModelState, TFSServerType, DeployType, ZKPath from monolith.agent_service.agent_service_pb2 import ServerType from monolith.agent_service.data_def import ReplicaMeta from monolith.agent_service.backends import SyncBackend from monolith.native_training.model_export import export_state_utils from monolith.native_training.net_utils import AddressFamily from monolith.native_training.zk_utils import MonolithKazooClient, is_ipv6_only from monolith.native_training.metric import cli from monolith.native_training.runtime.parameter_sync.parameter_sync_pb2 import ClientConfig DEFAULT_USE_ARCHON = False class ReplicaWatcher(object): def __init__(self, zk_client: MonolithKazooClient, config: AgentConfig, use_archon: bool = False, zk_watch_address_family: str = AddressFamily.IPV4): self._zk = zk_client # the info of this replica self._conf: AgentConfig = config self._use_archon = use_archon if zk_watch_address_family == AddressFamily.IPV4 and is_ipv6_only(): logging.warning("zk_watch_address_family change to IPV6") self._zk_watch_address_family = AddressFamily.IPV6 else: self._zk_watch_address_family = zk_watch_address_family self.path_prefix = os.path.join('/', config.bzid, 'service', config.base_name) # /bzid/service/model_name/server_type:task -> replica -> (addr, stat) self._lock = threading.Lock() self.replicas: Dict[str, Dict[str, ReplicaMeta]] = {} self._has_stop = False self._should_poll = True 
self._thread = None @property def zk(self): return self._zk def watch_data(self): if self._conf.dc_aware: self.zk.ChildrenWatch(path=self.path_prefix, func=self._get_idc_cluster_children_watch( self.path_prefix)) else: self.zk.ChildrenWatch(path=self.path_prefix, func=self._get_task_children_watch( self.path_prefix)) self._thread = threading.Thread(target=self._poll, daemon=True) self._has_stop = False self._thread.start() def stop(self): try: self._has_stop = True if self._thread is not None: try: self._thread.join() self._thread = None except: self._thread = None finally: with self._lock: self.replicas.clear() def _get_idc_cluster_children_watch(self, path_prefix: str): _idc_cluster = set() def idc_cluster_children_watch(children: List[str]): if children is not None: for idc_cluster in children: if idc_cluster not in _idc_cluster: # idc_cluster -> idc:cluster _idc_cluster.add(idc_cluster) ic_path = os.path.join(path_prefix, idc_cluster) self.zk.ChildrenWatch(path=ic_path, func=self._get_task_children_watch(ic_path)) return idc_cluster_children_watch def _get_task_children_watch(self, path_prefix: str): _tasks = set() def task_children_watch(children: List[str]): if children is not None: for task in children: if task not in _tasks: # task -> entry/ps/dense:idx _tasks.add(task) task_path = os.path.join(path_prefix, task) self.zk.ChildrenWatch( path=task_path, func=self._get_replica_children_watch(task_path)) return task_children_watch def _get_replica_children_watch(self, task_path: str): _replicas = set() def replica_children_watch(children: List[str]): if children is not None: for replica in children: if replica not in _replicas: _replicas.add(replica) replica_path = os.path.join(task_path, replica) self.zk.DataWatch(path=replica_path, func=self._get_data_watch(replica_path)) return replica_children_watch def _get_data_watch(self, path): def data_watch(data: bytes, state: ZnodeStat, event: WatchedEvent): task_path = os.path.dirname(path) rnode = 
str(int(os.path.basename(path))) if data is None or len(data) == 0: with self._lock: if task_path in self.replicas: if rnode in self.replicas[task_path]: meta = self.replicas[task_path][rnode] meta.stat = ModelState.UNKNOWN else: return else: return else: meta = ReplicaMeta.deserialize(data) with self._lock: if event is None or event.type == EventType.CREATED: # in the first call, event is None if task_path in self.replicas: self.replicas[task_path][rnode] = meta else: self.replicas[task_path] = {rnode: meta} elif event.type == EventType.DELETED: if task_path in self.replicas.keys( ) and rnode in self.replicas[task_path].keys(): del self.replicas[task_path][rnode] if task_path in self.replicas.keys() and len( self.replicas[task_path]) == 0: del self.replicas[task_path] elif event.type == EventType.CHANGED: self.replicas[task_path][rnode] = meta elif event.type == EventType.NONE: meta.stat = ModelState.UNKNOWN self.replicas[task_path][rnode] = meta else: assert event.type == EventType.CHILD return data_watch def _poll(self): while not self._has_stop: time.sleep(60) if not self._should_poll: continue try: tasks = [] if self._conf.dc_aware: idc_clusters = self.zk.get_children(self.path_prefix) if idc_clusters: for ic in idc_clusters: ic_path = os.path.join(self.path_prefix, ic) ts = self.zk.get_children(ic_path) if ts: tasks.extend([f'{ic}/{t}' for t in ts]) else: ts = self.zk.get_children(self.path_prefix) if ts: tasks.extend(ts) replicas_tmp = {} for task in tasks: task_path = os.path.join(self.path_prefix, task) replicas = self.zk.get_children(task_path) replicas_tmp[task_path] = {} if replicas: for replica in replicas: replica_path = os.path.join(task_path, replica) value, _ = self.zk.get(replica_path) if value is not None: meta = ReplicaMeta.from_json(str(value, encoding='utf-8')) replicas_tmp[task_path][str(int(replica))] = meta with self._lock: old_paths, new_paths = {}, {} for task, replicas in self.replicas.items(): for replica in replicas: key = 
os.path.join(task, replica) old_paths[key] = self.replicas[task][replica] for task, replicas in replicas_tmp.items(): for replica in replicas: key = os.path.join(task, replica) new_paths[key] = replicas_tmp[task][replica] to_removed_replicas = set(old_paths) - set(new_paths) if self._conf.deploy_type == DeployType.MIXED or self._conf.deploy_type == DeployType.PS: server_type = TFSServerType.PS port_grpc, port_archon = self._conf.tfs_port_grpc, self._conf.tfs_ps_archon_port elif self._conf.deploy_type == DeployType.DENSE: server_type = TFSServerType.DENSE port_grpc, port_archon = self._conf.tfs_dense_port, self._conf.tfs_dense_archon_port else: server_type = None need_register_replicas: Dict[str, ReplicaMeta] = {} if server_type and to_removed_replicas: for i in self._conf.get_server_schedule_iter(server_type): for replica in to_removed_replicas: if f'/{server_type}:{i}/{self._conf.replica_id}' in replica: meta: ReplicaMeta = old_paths[replica] meta.address = f"{meta.address.split(':')[0]}:{port_grpc}" meta.address_ipv6 = f"{meta.address_ipv6.split(':')[0]}:{port_grpc}" meta.archon_address = f"{meta.archon_address.split(':')[0]}:{port_archon}" meta.archon_address_ipv6 = f"{meta.archon_address_ipv6.split(':')[0]}:{port_archon}" need_register_replicas[replica] = meta # update self.replicas self.replicas = replicas_tmp # register while need_register_replicas: zk_path, meta = need_register_replicas.popitem() replica_meta_bytes = bytes(meta.to_json(), encoding='utf-8') try: self.zk.retry(self.zk.create, path=zk_path, value=replica_meta_bytes, ephemeral=True, makepath=True, sequence=False) except NodeExistsError: logging.info(f'{zk_path} has already exists') except Exception as e: exc_type, exc_value, exc_traceback_obj = sys.exc_info() logging.log_every_n_seconds(logging.ERROR, f"exc_type: {exc_type}", 10 * 60) def get_all_replicas(self, server_type: ServerType, idc: str = None, cluster: str = None) -> Dict[str, List[str]]: st = ServerType.Name(server_type).lower() result 
= {} with self._lock: for path, replicas in self.replicas.items(): zk_path = ZKPath(path) dc_flag = zk_path.ship_in(idc, cluster) if self._conf.dc_aware else True if zk_path.server_type == st and dc_flag: key = os.path.join( zk_path.location, zk_path.task) if self._conf.dc_aware else zk_path.task addrs = [ pm.get_address(use_archon=self._use_archon, address_family=self._zk_watch_address_family) for pm in replicas.values() if pm and pm.stat == ModelState.AVAILABLE ] if key in result: result[key].extend(addrs) else: result[key] = addrs if (server_type == ServerType.PS and len(result) == 0) or (server_type == ServerType.DENSE and self._conf.dense_alone and len(result) == 0): logging.error(f'empty replicas {self.path_prefix}-{st}') logging.info('all replicas is ' + str(self.replicas)) return result def get_replicas(self, server_type: ServerType, task: int, idc: str = None, cluster: str = None) -> List[str]: st = ServerType.Name(server_type).lower() with self._lock: addrs = [] for path, replicas in self.replicas.items(): zk_path = ZKPath(path) dc_flag = zk_path.ship_in(idc, cluster) if self._conf.dc_aware else True if zk_path.server_type == st and int(zk_path.index) == task and dc_flag: if replicas: addrs.extend([ meta.get_address(use_archon=self._use_archon, address_family=self._zk_watch_address_family) for meta in replicas.values() if meta.stat == ModelState.AVAILABLE ]) return addrs def get_replica(self, server_type: ServerType, task: int, replica: int, idc: str = None, cluster: str = None) -> Optional[Union[List[str], str]]: st = ServerType.Name(server_type).lower() result = [] with self._lock: for path, replicas in self.replicas.items(): zk_path = ZKPath(path) dc_flag = zk_path.ship_in(idc, cluster) if self._conf.dc_aware else True if zk_path.server_type == st and int(zk_path.index) == task and dc_flag: for replica_id, meta in replicas.items(): if int(replica_id) == replica: if meta is not None and meta.stat == ModelState.AVAILABLE: result.append( meta.get_address( 
use_archon=self._use_archon, address_family=self._zk_watch_address_family)) if result: if len(result) == 1: return result[0] else: return result else: return None def get_replicas_with_extra_info(self, server_type: ServerType, task: int, idc: str = None, cluster: str = None) -> List[str]: st = ServerType.Name(server_type).lower() with self._lock: addr_dict = {} for path, replicas in self.replicas.items(): zk_path = ZKPath(path) dc_flag = zk_path.ship_in(idc, cluster) if self._conf.dc_aware else True if zk_path.server_type == st and int(zk_path.index) == task and dc_flag: if replicas: addr_dict.update({ meta.get_address(use_archon=self._use_archon, address_family=self._zk_watch_address_family): ClientConfig.TargetExtraInfo(idc=zk_path.idc, cluster=zk_path.cluster, replica_id=int(replica_id)) for replica_id, meta in replicas.items() if meta.stat == ModelState.AVAILABLE }) return addr_dict def to_sync_wrapper(self) -> SyncBackend: return SyncBackendWrapper(self) class ReplicaUpdater(object): def __init__(self, zk_client: MonolithKazooClient, config: AgentConfig): self._zk = zk_client self._conf: AgentConfig = config self.path_prefix = config.path_prefix self.model_monitor = TFSMonitor(config) self.meta = {} self._has_stop = False self._should_reregister = False self._should_update = True self._thread = None self._reregister_thread = None self._watch_update_thread = None self._entry_last_update_version = None self._metrics_cli = None self._metrics_cli_global = None self._tagkv = {'status': 'OK'} self._model_latest_version = {} self._model_last_update_ts = {} if self._conf.use_metrics: try: self.init_metrics() except: logging.error('init metrics error') exc_type, exc_value, exc_traceback_obj = sys.exc_info() logging.error(f"exc_type: {exc_type}") logging.error(f"exc_value: {exc_value}") traceback.print_tb(exc_traceback_obj, limit=10) else: logging.info("conf.use_metrics is false") def init_metrics(self): if "MONOLITH_METRIC_PREFIX" in os.environ: # In MLP Env prefix = 
os.environ.get("MONOLITH_METRIC_PREFIX") elif "TCE_PSM" in os.environ: # In Byterec Env prefix = os.environ.get("TCE_PSM") else: prefix = "data.monolith_serving." + self._conf.base_name self._metrics_cli = cli.get_cli(prefix=prefix) self._metrics_cli_global = cli.get_cli(prefix="data.monolith_serving.global") logging.info(f"after init_metrics, prefix is {prefix}") @property def zk(self): return self._zk @property def model_names(self): names = [] if self._conf.deploy_type == DeployType.MIXED or self._conf.deploy_type == DeployType.PS: for task_id in range(self._conf.num_ps): if task_id % self._conf.num_shard == self._conf.shard_id: names.append(f'{TFSServerType.PS}_{task_id}') if self._conf.deploy_type == DeployType.MIXED or self._conf.deploy_type == DeployType.ENTRY: names.append(TFSServerType.ENTRY) if self._conf.dense_alone and (self._conf.deploy_type == DeployType.MIXED or self._conf.deploy_type == DeployType.DENSE): names.append(f'{TFSServerType.DENSE}_0') return names @property def entry_path(self): return os.path.join(self.path_prefix, f'{TFSServerType.ENTRY}:0', str(self._conf.replica_id)) def ps_path(self, task_id: int): return os.path.join(self.path_prefix, f'{TFSServerType.PS}:{task_id}', str(self._conf.replica_id)) def dense_path(self): return os.path.join(self.path_prefix, f'{TFSServerType.DENSE}:0', str(self._conf.replica_id)) def _do_register(self, replica_path: str, grpc_port: int, archon_port: int): try: host = os.environ.get("MY_HOST_IP", socket.gethostbyname(socket.gethostname())) except: host = '0.0.0.0' try: defalut_host_ipv6 = socket.getaddrinfo(socket.gethostname(), None, socket.AF_INET6)[0][4][0] except: defalut_host_ipv6 = '::' host_ipv6 = os.environ.get("MY_HOST_IPV6") if not host_ipv6: host_ipv6 = defalut_host_ipv6 host_ipv6 = '[{}]'.format(host_ipv6) replica_meta = ReplicaMeta(address=f'{host}:{grpc_port}', address_ipv6=f'{host_ipv6}:{grpc_port}', stat=ModelState.UNKNOWN, archon_address=f'{host}:{archon_port}', 
archon_address_ipv6=f'{host_ipv6}:{archon_port}') self.meta[replica_path] = replica_meta replica_meta_bytes = bytes(replica_meta.to_json(), encoding='utf-8') node_stat = self.zk.exists(replica_path) if not node_stat: try: sequence = True if TFSServerType.ENTRY in replica_path and self._conf.replica_id == -1 else False real_path = self.zk.retry(self.zk.create, path=replica_path, value=replica_meta_bytes, ephemeral=True, makepath=True, sequence=sequence) if self._conf.replica_id == -1: self._conf.replica_id = int(os.path.basename(real_path)) del self.meta[replica_path] self.meta[real_path] = replica_meta except NodeExistsError: logging.info(f'{replica_path} has already exists') self.zk.retry(self.zk.set, path=replica_path, value=replica_meta_bytes) else: value, _ = self.zk.get(replica_path) if value != replica_meta_bytes: self.zk.retry(self.zk.set, path=replica_path, value=replica_meta_bytes) def register(self): if self._conf.deploy_type == DeployType.MIXED or self._conf.deploy_type == DeployType.ENTRY: if self._conf.replica_id == -1: replica_path = f'{self.path_prefix}/{TFSServerType.ENTRY}:0/0' else: replica_path = f'{self.path_prefix}/{TFSServerType.ENTRY}:0/{self._conf.replica_id:011d}' self._do_register(replica_path, self._conf.tfs_entry_port, self._conf.tfs_entry_archon_port) if self._conf.deploy_type == DeployType.MIXED or self._conf.deploy_type == DeployType.PS: for task_id in range(self._conf.num_ps): if task_id % self._conf.num_shard == self._conf.shard_id: self._do_register(self.ps_path(task_id), self._conf.tfs_ps_port, self._conf.tfs_ps_archon_port) if self._conf.dense_alone and (self._conf.deploy_type == DeployType.MIXED or self._conf.deploy_type == DeployType.DENSE): self._do_register(self.dense_path(), self._conf.tfs_dense_port, self._conf.tfs_dense_archon_port) def _do_update(self, name: str): if name.startswith(TFSServerType.ENTRY): replica_path = f'{self.path_prefix}/{TFSServerType.ENTRY}:0/{self._conf.replica_id:011d}' elif 
name.startswith(TFSServerType.PS): replica_path = self.ps_path(int(name.split("_")[1])) else: replica_path = self.dense_path() try: model_status = self.model_monitor.get_model_status(name) except Exception as e: replica_meta = self.meta[replica_path] if replica_meta.stat != ModelState.UNKNOWN: replica_meta.stat = ModelState.UNKNOWN replica_meta_bytes = bytes(replica_meta.to_json(), encoding='utf-8') try: self.zk.retry(self.zk.set, path=replica_path, value=replica_meta_bytes) except NoNodeError: self.zk.retry(self.zk.create, path=replica_path, value=replica_meta_bytes, ephemeral=True, makepath=True) return if model_status is not None and len(model_status) > 0: model_version_status = None if len(model_status) > 1: model_status.sort(key=lambda mvs: mvs.version, reverse=True) for m_status in model_status: if m_status.state == ModelState.AVAILABLE: model_version_status = m_status break if model_version_status is None: # check model version status model_version_status = model_status[0] status = model_version_status.status if status.error_code != ErrorCode.OK: raise Exception(status.error_message) # update state if changed stat = model_version_status.state replica_meta = self.meta[replica_path] if replica_meta.stat != stat: replica_meta.stat = stat replica_meta_bytes = bytes(replica_meta.to_json(), encoding='utf-8') try: self.zk.retry(self.zk.set, path=replica_path, value=replica_meta_bytes) except NoNodeError: self.zk.retry(self.zk.create, path=replica_path, value=replica_meta_bytes, ephemeral=True, makepath=True) def _updater(self): while not self._has_stop: curr_name = None time.sleep(1) if not self._should_update: continue try: for name in self.model_names: curr_name = name self._do_update(name) except Exception as e: exc_type, exc_value, exc_traceback_obj = sys.exc_info() logging.error(f"exc_type: {exc_type}") logging.error(f"exc_value: {exc_value}") traceback.print_tb(exc_traceback_obj, limit=10) logging.error(f"{e}, when model {curr_name} update") except 
(SystemExit, KeyboardInterrupt, GeneratorExit) as e: self._has_stop = True logging.error(f"{e}, when model {curr_name} update") def _get_latest_version_in_fs(self, name): exported_models_dir = os.path.join(self._conf.base_path, name) state = export_state_utils.get_export_saver_listener_state( exported_models_dir) if state.entries: return state.entries[-1].export_dir else: return None def _check_version(self): if not self._metrics_cli: return # metrics_v2 for name in self.model_names: model_status = self.model_monitor.get_model_status(name) req_ts = int(time.time()) if model_status is None or len(model_status) == 0: continue model_status.sort(key=lambda x: x.version, reverse=True) latest_model_status = model_status[0] latest_version = latest_model_status.version if latest_model_status.status.error_code != ErrorCode.OK: raise Exception(latest_model_status.status.error_message) tags = { "model_name": name, # It's safe when clueter or idc is None "idc": f"{self._conf.idc}:{self._conf.cluster}", "replica_id": str(self._conf.replica_id), "shard_id": str(self._conf.shard_id), "base_name": self._conf.base_name } self._metrics_cli.emit_store("serving_model.latest_version", latest_version, tags) if name in self._model_last_update_ts: interval = req_ts - self._model_last_update_ts[name] self._metrics_cli_global.emit_store("serving_model.since_last_update", interval, tags) if name not in self._model_latest_version or self._model_latest_version[ name] < latest_version: self._metrics_cli.emit_store("serving_model.update_ts", req_ts, tags) self._model_latest_version[name] = latest_version self._model_last_update_ts[name] = req_ts self._metrics_cli.flush() return def _watch_update(self): if not self._metrics_cli: return while not self._has_stop: time.sleep(60) try: self._check_version() except: exc_type, exc_value, exc_traceback_obj = sys.exc_info() logging.error(f"exc_type: {exc_type}") logging.log_every_n_seconds(logging.WARNING, traceback.format_exc(), 600) def 
_reregister(self): while not self._has_stop: time.sleep(10) if self._should_reregister: self.register() self._should_update = True # self._should_reregister = False def start(self): self.model_monitor.start() self._has_stop = False if self._thread is None: self._thread = threading.Thread(target=self._updater, daemon=True) self._thread.start() if self._reregister_thread is None: self._reregister_thread = threading.Thread(target=self._reregister, daemon=True) self._reregister_thread.start() if self._watch_update_thread is None and self._metrics_cli is not None: self._watch_update_thread = threading.Thread(target=self._watch_update, daemon=True) self._watch_update_thread.start() def stop(self): try: self._has_stop = True if self._thread is not None: try: self._thread.join() finally: self._thread = None if self._reregister_thread is not None: try: self._reregister_thread.join() finally: self._reregister_thread = None if self._watch_update_thread is not None: try: self._watch_update_thread.join() finally: self._watch_update_thread = None finally: self.model_monitor.stop() self.meta.clear() class ZKListener(object): def __init__(self, watcher: ReplicaWatcher, updater: ReplicaUpdater): self._watcher: ReplicaWatcher = watcher self._updater: ReplicaUpdater = updater self._has_lost = False def __call__(self, state: KazooState) -> bool: if state == KazooState.LOST: # The connection has been confirmed dead logging.warning( "Any ephemeral nodes will need to be recreated upon re-establishing a connection." 
) self._has_lost = True self._watcher._should_poll = False self._updater._should_update = False elif state == KazooState.SUSPENDED: # Handle being disconnected from Zookeeper return False else: # Handle being connected/reconnected to Zookeeper if self._has_lost: logging.info( "connected/reconnected after lost, restart updater and watcher") self._updater._should_reregister = True time.sleep(5) # wait for updater reregister self._watcher._should_poll = True self._has_lost = False return False class ReplicaManager: def __init__(self, zk_client: MonolithKazooClient, config: AgentConfig): self._watcher = ReplicaWatcher(zk_client, config, DEFAULT_USE_ARCHON) self._updater = ReplicaUpdater(zk_client, config) self._conf = config listener = ZKListener(self._watcher, self._updater) zk_client.add_listener(listener) @property def watcher(self): return self._watcher @property def updater(self): return self._updater def start(self): self._updater.register() self._watcher.watch_data() self._updater.start() def stop(self): self._updater.stop() self._watcher.stop() def get_all_replicas(self, server_type: ServerType, idc: str = None, cluster: str = None) -> Dict[str, List[str]]: return self._watcher.get_all_replicas(server_type, idc, cluster) def get_replicas(self, server_type: ServerType, task: int, idc: str = None, cluster: str = None) -> List[str]: return self._watcher.get_replicas(server_type, task, idc, cluster) def get_replica(self, server_type: ServerType, task: int, replica: int, idc: str = None, cluster: str = None) -> Optional[Union[List[str], str]]: return self._watcher.get_replica(server_type, task, replica, idc, cluster) def is_ps_set_started(self): for i in range(self._conf.num_ps): replicas = self._watcher.get_replicas(ServerType.PS, i, self._conf.idc, self._conf.cluster) if replicas is None or len(replicas) == 0: return False logging.info( f"get_all_replicas: {self._watcher.get_all_replicas(ServerType.PS)}") return True def is_dense_set_started(self): replicas = 
self._watcher.get_replicas(ServerType.DENSE, 0) if replicas is None or len(replicas) == 0: return False logging.info(f"get_replicas: {replicas}") return True class SyncBackendWrapper(SyncBackend): def __init__(self, watcher: ReplicaWatcher): super(SyncBackendWrapper, self).__init__() self._watcher = watcher self._model_name = None def subscribe_model(self, model_name: str): self._model_name = model_name def get_sync_targets(self, sub_graph: str) -> Tuple[str, Dict]: ps, i = sub_graph.split("_")[:2] assert ps == "ps" return sub_graph, self._watcher.get_replicas_with_extra_info( ServerType.PS, int(i)) def start(self): self._watcher.watch_data() def stop(self): self._watcher.stop() ================================================ FILE: monolith/agent_service/replica_manager_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging from kazoo.exceptions import NodeExistsError import socket import time import os import threading import unittest from monolith.agent_service import constants from monolith.agent_service import utils from monolith.agent_service.agent_service_pb2 import ServerType from monolith.agent_service.mocked_tfserving import FakeTFServing from monolith.agent_service.mocked_zkclient import FakeKazooClient from monolith.agent_service.replica_manager import ReplicaUpdater, ReplicaWatcher, \ ReplicaManager, ReplicaMeta, ModelState MODEL_NAME = 'test_model' BASE_PATH = f'/model/{MODEL_NAME}/saved_models' NUM_REPLICAS = 3 class ReplicaMgrTest(unittest.TestCase): tfs: FakeTFServing = None agent_conf: utils.AgentConfig = None @classmethod def setUpClass(cls) -> None: os.environ[constants.HOST_SHARD_ENV] = '5' os.environ['SHARD_ID'] = '1' os.environ['REPLICA_ID'] = '2' os.environ['TCE_INTERNAL_IDC'] = 'lf' os.environ['TCE_LOGICAL_CLUSTER'] = 'default' cls.agent_conf = utils.AgentConfig(bzid='bzid', base_name=MODEL_NAME, deploy_type='mixed', base_path=BASE_PATH, num_ps=20, num_shard=5, dc_aware=True) entry_cmd = cls.agent_conf.get_cmd('tensorflow_serving', server_type=utils.TFSServerType.ENTRY) start = entry_cmd.index('model_config_file') + len('model_config_file') + 1 end = entry_cmd.find(' ', start) cls.tfs_entry = FakeTFServing(model_config_file=entry_cmd[start:end], num_versions=1, port=cls.agent_conf.tfs_entry_port) ps_cmd = cls.agent_conf.get_cmd('tensorflow_serving', server_type=utils.TFSServerType.PS) start = ps_cmd.index('model_config_file') + len('model_config_file') + 1 end = ps_cmd.find(' ', start) cls.tfs_ps = FakeTFServing(model_config_file=ps_cmd[start:end], num_versions=1, port=cls.agent_conf.tfs_ps_port) cls.threads = [ threading.Thread(target=lambda: cls.tfs_entry.start()), threading.Thread(target=lambda: cls.tfs_ps.start()) ] for thread in cls.threads: thread.start() time.sleep(1) @classmethod def tearDownClass(cls) -> None: 
cls.tfs_entry.stop() cls.tfs_ps.stop() for thread in cls.threads: thread.join() def register(self, zk): path_prefix = self.agent_conf.path_prefix path_to_meta, idx = {}, 2 for replica_id in range(NUM_REPLICAS): for shard_id in range(self.agent_conf.num_shard): if shard_id == self.agent_conf.shard_id and replica_id == self.agent_conf.replica_id: continue for task_id in range(self.agent_conf.num_ps): if task_id % self.agent_conf.num_shard == shard_id: meta = ReplicaMeta( address=f'192.168.1.{idx}:{utils.find_free_port()}', stat=ModelState.AVAILABLE) replica_path = f'{path_prefix}/ps:{task_id}/{replica_id}' path_to_meta[replica_path] = meta idx += 1 replica_path = f'{path_prefix}/entry:0/{replica_id}' meta = ReplicaMeta(address=f'192.168.1.{idx}:{utils.find_free_port()}', stat=ModelState.AVAILABLE) path_to_meta[replica_path] = meta idx += 1 for replica_path, meta in path_to_meta.items(): replica_meta_bytes = bytes(meta.to_json(), encoding='utf-8') try: zk.retry(zk.create, path=replica_path, value=replica_meta_bytes, ephemeral=True, makepath=True) except NodeExistsError: logging.info(f'{replica_path} has already exists') zk.retry(zk.set, path=replica_path, value=replica_meta_bytes) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/agent_service/resource_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging import os import time import subprocess import re import psutil from google.protobuf import text_format import tensorflow as tf from typing import Dict, Union, List import monolith.agent_service.utils from monolith.agent_service.data_def import SubModelName, VersionPath, SubModelSize from monolith.native_training.model_export import export_pb2 from monolith.native_training.model_export import export_state_utils ROW = re.compile(r"^.+_(\d+)-\d+-of-\d+$") def _get_pod_cgroup_path(): cmd = ["cat", "/proc/1/cgroup"] try: out_bytes = subprocess.check_output(cmd) out_list = out_bytes.decode('utf-8').strip().split('\n') for line in out_list: if ':memory:' in line: return line.strip().split(':')[-1].strip('/') except Exception as e: return None _POD_CGROUP_PATH = _get_pod_cgroup_path() def exists(dirname: str) -> bool: return tf.io.gfile.isdir(dirname) or tf.io.gfile.exists(dirname) def open_hdfs(fname: Union[str, List[str]]): cmd = [_HADOOP_BIN, 'fs', '-text'] if isinstance(fname, (list, tuple)): cmd.extend(fname) else: cmd.append(fname) out_list = None cnt, max_try = 0, 3 while cnt < max_try: try: out_bytes = subprocess.check_output(cmd) out_list = out_bytes.decode('utf-8').strip().split('\n') break except Exception as e: logging.info(e) cnt += 1 assert out_list is not None for line in out_list: line = line.strip() if len(line) > 0: yield line def cal_model_info_v2( exported_models_path: str, ckpt: str = None, version: str = None) -> 'Dict[SubModelName, (SubModelSize, VersionPath)]': # 1) get all names of saved_models if os.path.isabs(exported_models_path): exported_models_path = exported_models_path.rstrip('/') else: exported_models_path = os.path.abspath(exported_models_path.rstrip('/')) if not tf.io.gfile.exists(exported_models_path): raise Exception(f"{exported_models_path} is not exists ") model_info = { sub_model_name: 0 for sub_model_name in tf.io.gfile.listdir(exported_models_path) if not sub_model_name.startswith('.') } # 2) ensure 
checkpoint ckpt_base_path = os.path.dirname(exported_models_path) if ckpt is None: checkpoint_state = tf.train.get_checkpoint_state(ckpt_base_path) if checkpoint_state is not None: ckpt = os.path.basename(checkpoint_state.model_checkpoint_path) global_step = -1 if ckpt is None else int(ckpt.split('-')[-1]) # 3) ensure version if version is None: com_versions = None for sub_model_name in tf.io.gfile.listdir(exported_models_path): if sub_model_name.startswith('.'): continue tfs_base_path = os.path.join(exported_models_path, sub_model_name) state = export_state_utils.get_export_saver_listener_state(tfs_base_path) if global_step >= 0 and state is not None: versions = set() for se in state.entries: _version = int(os.path.basename(se.export_dir)) versions.add(_version) if se.global_step == global_step: if version is None: version = _version else: assert version == _version break else: versions = set( int(num) for num in tf.io.gfile.listdir(tfs_base_path) if num.isnumeric()) if com_versions is None: com_versions = versions else: com_versions &= versions assert com_versions is not None and len(com_versions) > 0 version = version or sorted(com_versions)[-1] else: version = int(version) # 4) get dense part size of all saved_models for sub_model_name in model_info: version_path = os.path.join(exported_models_path, sub_model_name, str(version)) assert tf.io.gfile.exists(version_path) for (dir_name, _, file_names) in tf.io.gfile.walk(version_path): for fn in file_names: stat = tf.io.gfile.stat(os.path.join(dir_name, fn)) model_info[sub_model_name] += stat.length # 5) add assets length (sparse part size) assets_path = os.path.join(ckpt_base_path, f'{ckpt}.assets') if tf.io.gfile.exists(assets_path): for fn in tf.io.gfile.listdir(assets_path): matched = ROW.match(fn) if matched: key = f'ps_{matched.group(1)}' stat = tf.io.gfile.stat(os.path.join(assets_path, fn)) model_info[key] += stat.length return { sub_model_name: (size, os.path.join(exported_models_path, sub_model_name, 
str(version))) for sub_model_name, size in model_info.items() } def total_memory() -> int: memory_base = os.path.join("/sys/fs/cgroup/memory", _POD_CGROUP_PATH) limit_in_bytes = 0 with open(os.path.join(memory_base, 'memory.limit_in_bytes'), 'r') as stream: for line in stream: limit_in_bytes = int(line.strip()) if limit_in_bytes == 0: return int(os.environ.get('MY_MEM_LIMIT')) else: return limit_in_bytes def total_memory_v2() -> int: mem = psutil.virtual_memory() return mem.total def cal_available_memory() -> int: memory_base = os.path.join("/sys/fs/cgroup/memory", _POD_CGROUP_PATH) usage_in_bytes = 0 with open(os.path.join(memory_base, 'memory.usage_in_bytes'), 'r') as stream: for line in stream: usage_in_bytes = int(line.strip()) limit_in_bytes = 0 with open(os.path.join(memory_base, 'memory.limit_in_bytes'), 'r') as stream: for line in stream: limit_in_bytes = int(line.strip()) return limit_in_bytes - usage_in_bytes def cal_available_memory_v2() -> int: mem = psutil.virtual_memory() return mem.available class CPU(object): def __init__(self, cpuacct_file): self.cpuacct_file = cpuacct_file self.last_wall_clock = self.wall_clock() self.last_cpu_clock = self.cpu_clock() def wall_clock(self): try: # time_ns() only supported by python 3.7 total_time = time.time_ns() except Exception as e: total_time = subprocess.check_output(['date', '+%s%N']) return int(total_time) def cpu_clock(self): with open(self.cpuacct_file, 'r') as f: use_time = int(f.read()) return use_time def cpu_usage(self): current_wall_clock = self.wall_clock() current_cpu_clock = self.cpu_clock() delta_cpu_time = current_cpu_clock - self.last_cpu_clock delta_wall_time = current_wall_clock - self.last_wall_clock usage = delta_cpu_time / delta_wall_time self.last_wall_clock = current_wall_clock self.last_cpu_clock = current_cpu_clock return usage def num_cpu(): cpu_base = os.path.join("/sys/fs/cgroup/cpu", _POD_CGROUP_PATH) cfs_quota_us = 0 with open(os.path.join(cpu_base, 'cpu.cfs_quota_us'), 'r') as 
stream: for line in stream: cfs_quota_us = int(line.strip()) cfs_period_us = 0 with open(os.path.join(cpu_base, 'cpu.cfs_period_us'), 'r') as stream: for line in stream: cfs_period_us = int(line.strip()) if cfs_period_us == 0: return int(os.environ.get('MY_CPU_LIMIT')) else: return int(cfs_quota_us / cfs_period_us) def cal_cpu_usage(): cpu_base = os.path.join("/sys/fs/cgroup/cpu", _POD_CGROUP_PATH) cpuacct_file = os.path.join(cpu_base, 'cpuacct.usage') cpu = CPU(cpuacct_file) cpu_usages, cnt, max_try = [], 0, 5 while cnt < max_try: time.sleep(1) cpu_usages.append(round(cpu.cpu_usage() * 100, 2)) cnt += 1 return sum(cpu_usages) / max_try def cal_cpu_usage_v2() -> float: return psutil.cpu_percent() ================================================ FILE: monolith/agent_service/resource_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging import unittest from monolith.agent_service.resource_utils import cal_available_memory_v2, total_memory_v2, \ cal_cpu_usage_v2, cal_model_info_v2, cal_available_memory class UtilTest(unittest.TestCase): def test_cal_avaiable_memory_v2(self): total = total_memory_v2() available = cal_available_memory_v2() logging.info(f'the total memory is {total}, and {available} is available') self.assertTrue(0 < available < total) def test_cal_cpu_usage_v2(self): usage = cal_cpu_usage_v2() logging.info(f'the cpu usage is {usage}') self.assertTrue(0 <= usage <= 100) if __name__ == '__main__': unittest.main() ================================================ FILE: monolith/agent_service/run.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import app, flags, logging from monolith.agent_service.agent import main as agent_main from monolith.agent_service.agent_client import main as agent_client_main from monolith.agent_service.tfs_client import main as tfs_client_main FLAGS = flags.FLAGS flags.DEFINE_enum("bin_name", "agent", ["agent", "agent_client", "tfs_client"], "bin_name: agent, agent_client") def main(_): if FLAGS.bin_name == 'agent': agent_main(_) elif FLAGS.bin_name == 'agent_client': agent_client_main(_) elif FLAGS.bin_name == 'tfs_client': tfs_client_main(_) else: raise ValueError("Unknown bin: {}".format(FLAGS.bin_name)) if __name__ == '__main__': app.run(main) ================================================ FILE: monolith/agent_service/svr_client.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging, flags import grpc import os import socket from monolith.agent_service import utils from monolith.agent_service.agent_service_pb2 import HeartBeatRequest, ServerType, \ GetReplicasRequest from monolith.agent_service.agent_service_pb2_grpc import AgentServiceStub class SvrClient(object): def __init__(self, config) -> None: if isinstance(config, str): self.agent_conf = utils.AgentConfig.from_file(config) else: self.agent_conf = config self._stub = None @property def stub(self): if self._stub is None: local_host = socket.gethostbyname(socket.gethostname()) target = f'{os.environ.get("MY_HOST_IP", local_host)}:{self.agent_conf.agent_port}' channel = grpc.insecure_channel(target) self._stub = AgentServiceStub(channel) return self._stub def get_server_type(self, st): if isinstance(st, str): if FLAGS.server_type == 'ps': return ServerType.PS elif FLAGS.server_type == 'entry': return ServerType.ENTRY elif FLAGS.server_type == 'dense': return ServerType.DENSE else: raise Exception('server_type error') else: return st def heart_beat(self, server_type): server_type = self.get_server_type(server_type) request = HeartBeatRequest(server_type=server_type) resp = self.stub.HeartBeat(request) print(resp.addresses, flush=True) return resp def get_replicas(self, server_type, task): server_type = self.get_server_type(server_type) request = GetReplicasRequest(server_type=server_type, task=task) resp = self.stub.GetReplicas(request) print(resp.address_list.address, flush=True) return resp ================================================ FILE: monolith/agent_service/test_data/BUILD ================================================ package(default_visibility = [ "//monolith/integration_test:__subpackages__", "//monolith/agent_service:__subpackages__", ]) filegroup( name = "test_data", srcs = [ "inst.pbtext", "inst.json", "inst.dump", ] ) ================================================ FILE: monolith/agent_service/test_data/inst.dump 
================================================ "root": "feature": 0: "fid": 0: "1" "name": "fc_a" 1: "fid": 0: "2" "name": "fc_b" 2: "fid": 0: "3" 1: "4" label": 0: 0 1: 1 "line_id": "actions": 0: 1 ================================================ FILE: monolith/agent_service/test_data/inst.json ================================================ { "fid": [ 12345 ], "label": [ 0, 1 ], "line_id": { "actions": [ 1 ] } } ================================================ FILE: monolith/agent_service/test_data/inst.pbtext ================================================ fid: 1 fid: 2 label: 0.0 label: 1.0 line_id { actions: 1 } feature { name: "a" fid: 3 } ================================================ FILE: monolith/agent_service/tfs_client.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import app, flags, logging from datetime import datetime import grpc import json import random import socket import os import sys import uuid import time from struct import unpack from typing import List from google.protobuf import text_format from idl.matrix.proto.proto_parser_pb2 import Instance from multiprocessing import Pool import threading from tensorflow_serving.apis.get_model_metadata_pb2 import GetModelMetadataRequest from tensorflow_serving.apis.get_model_status_pb2 import GetModelStatusRequest from tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceStub from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub from tensorflow_serving.apis.predict_pb2 import PredictRequest from tensorflow_serving.apis.model_management_pb2 import ReloadConfigRequest from tensorflow_serving.config.model_server_config_pb2 import ModelServerConfig from monolith.agent_service import utils from monolith.agent_service.client import FLAGS from monolith.native_training.data.utils import enable_tob_env from monolith.native_training.data.feature_list import FeatureList, get_feature_name_and_slot from idl.matrix.proto.example_pb2 import Example, ExampleBatch, FeatureListType, Feature from idl.matrix.proto.line_id_pb2 import LineId from monolith.native_training import env_utils from monolith.native_training.model_export import data_gen_utils VALID_SLOTS = [] _NUM_SLOTS = 6 _VOCAB_SIZES = [5, 5, 5, 5, 5, 5] flags.DEFINE_string("signature_name", "serving_default", "signature name") flags.DEFINE_string("feature_list", None, "feature_list for prediction") flags.DEFINE_enum("file_type", 'pb', ['pb', 'pbtxt'], "The input file type") flags.DEFINE_integer("batch_size", 8, "batch_size for prediction") flags.DEFINE_bool("lagrangex_header", False, "wheather has lagrangex_header") flags.DEFINE_bool("has_sort_id", False, "wheather has sort_id") flags.DEFINE_bool("kafka_dump", False, "wheather has kafka_dump") 
flags.DEFINE_bool("kafka_dump_prefix", False, "wheather has kafka_dump_prefix") flags.DEFINE_integer("parallel_num", 1, "parallel_num for profile") flags.DEFINE_integer("profile_duration", 600, "second for profile") flags.DEFINE_string("profile_data_dir", None, "profile input dir") SKIP_LIST = { '-', '_lt_', '_st_', '_lt', '_st', '_cp_', '_recent_', '_cp', '_recent' } def read_header(stream): int_size = 8 if FLAGS.lagrangex_header: stream.read(int_size) else: aggregate_page_sortid_size = 0 if FLAGS.kafka_dump_prefix: size = unpack(" str: eb = ExampleBatch() eb.batch_size = batch_size for feature in feature_list: flag = False for s in SKIP_LIST: if s in feature.name: flag = True break if flag: continue if not ("_id" in feature.name or "_name" in feature.name): continue named_feature_list = eb.named_feature_list.add() named_feature_list.name = feature.name for _ in range(batch_size): _feature = named_feature_list.feature.add() if feature.method.lower().startswith( 'vectortop') and feature.args is not None: if len(feature.args) > 0 and feature.args[0].isnumeric(): num = int(feature.args[0]) if num > 0: num = random.randint(1, num) fids = [(feature.slot << 48) | random.randint(1, sys.maxsize - 1) for _ in range(num)] _feature.fid_v2_list.value.extend(fids) else: fid = (feature.slot << 48) | random.randint(1, (1 << 48) - 1) _feature.fid_v2_list.value.append(fid) named_feature_list = eb.named_feature_list.add() named_feature_list.name = '__LINE_ID__' for _ in range(batch_size): _feature = named_feature_list.feature.add() line_id = LineId() line_id.sample_rate = 0.001 line_id.req_time = int(datetime.now().timestamp() - random.randint(1, 1000)) line_id.actions.extend([random.randint(1, 3), random.randint(3, 5)]) _feature.bytes_list.value.append(line_id.SerializeToString()) return eb.SerializeToString() def get_instance_proto(input_file: str = None, batch_size: int = 256): if input_file is None: instances = [generate_random_instance() for _ in range(batch_size)] else: 
assert os.path.exists(input_file) with open(input_file, 'rb') as stream: instances = [] for _ in range(batch_size): inst = Instance() inst.ParseFromString(read_data(stream)) instances.append(inst.SerializeToString()) return utils.make_tensor_proto(instances) def get_example_batch_proto(input_file: str = None, feature_list: FeatureList = None, batch_size: int = 256, file_type: str = 'pb'): if input_file is None: example_batch = generate_random_example_batch(feature_list, batch_size) else: assert os.path.exists(input_file) eb = ExampleBatch() if file_type == 'pb': with open(input_file, 'rb') as stream: eb.ParseFromString(read_data(stream)) else: with open(input_file, 'r') as stream: txt = stream.read() text_format.Parse(txt, eb) example_batch = eb.SerializeToString() return utils.make_tensor_proto([example_batch]) def gen_random_file(input_file, variant_type="example_batch"): assert input_file is not None assert len(VALID_SLOTS) > 0 parser_args = data_gen_utils.ParserArgs( sparse_features=[ get_feature_name_and_slot(slot)[0] for slot in VALID_SLOTS ], extra_features=[ 'uid', 'sample_rate', 'req_time', 'actions', 'stay_time', 'item_id', 'page', 'chnid' ], extra_feature_shapes=[1, 1, 1, 1, 1, 1, 1, 1], batch_size=FLAGS.batch_size, variant_type=variant_type) actions = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] data_gen_utils.gen_random_data_file(input_file, parser_args, sort_id=FLAGS.has_sort_id, kafka_dump=FLAGS.kafka_dump, num_batch=1, actions=actions) return input_file def get_example_batch_proto_v2(input_file: str): if not os.path.exists(input_file): gen_random_file(input_file) eb = ExampleBatch() with open(input_file, 'rb') as stream: eb.ParseFromString(read_data(stream)) # use same user feature in one example_batch user_fname_set = set( [get_feature_name_and_slot(slot)[0] for slot in user_features]) for named_feature_list in eb.named_feature_list: if named_feature_list.name in user_fname_set: named_feature_list.type = FeatureListType.SHARED new_feature = 
[named_feature_list.feature[0]] * eb.batch_size del named_feature_list.feature[:] named_feature_list.feature.extend(new_feature) example_batch = eb.SerializeToString() return utils.make_tensor_proto([example_batch]) def get_example_batch_to_instance(input_file: str, file_type: str): assert os.path.exists(input_file) eb = ExampleBatch() if file_type == 'pb': with open(input_file, 'rb') as stream: eb.ParseFromString(read_data(stream)) else: with open(input_file, 'r') as stream: txt = stream.read() text_format.Parse(txt, eb) inst_list = [] mask = (1 << 48) - 1 for i in range(eb.batch_size): inst = Instance() for named_feature_list in eb.named_feature_list: if named_feature_list.type == FeatureListType.SHARED: efeat = named_feature_list.feature[0] else: efeat = named_feature_list.feature[i] if named_feature_list.name == '__LABEL__': inst.label.extend(efeat.float_list.value) elif named_feature_list.name == '__LINE_ID__': inst.line_id.ParseFromString(efeat.bytes_list.value[0]) elif len(efeat.fid_v1_list.value) > 0: ifeat = inst.feature.add() ifeat.name = named_feature_list.name fid = efeat.fid_v1_list.value[0] slot_id = fid >> 54 fid_v2 = [(slot_id << 48) | (mask & v) for v in efeat.fid_v1_list.value] ifeat.fid.extend(fid_v2) elif len(efeat.fid_v2_list.value) > 0: ifeat = inst.feature.add() ifeat.name = named_feature_list.name ifeat.fid.extend(efeat.fid_v2_list.value) elif len(efeat.float_list.value) > 0: ifeat = inst.feature.add() ifeat.name = named_feature_list.name ifeat.float_value.extend(efeat.float_list.value) elif len(efeat.double_list.value) > 0: ifeat = inst.feature.add() ifeat.name = named_feature_list.name ifeat.float_value.extend(efeat.double_list.value) elif len(efeat.int64_list.value) > 0: ifeat = inst.feature.add() ifeat.name = named_feature_list.name ifeat.int64_value.extend(efeat.int64_list.value) elif len(efeat.bytes_list.value) > 0: ifeat = inst.feature.add() ifeat.name = named_feature_list.name ifeat.bytes_value.extend(efeat.bytes_list.value) else: 
pass inst_list.append(inst.SerializeToString()) return utils.make_tensor_proto(inst_list) class ProfileThread(threading.Thread): def __init__(self, job_id, model_name, stub_list, repeat_time, data_cache): super(ProfileThread, self).__init__() self._job_id = job_id self._model_name = model_name self._stub_list = stub_list self._repeat_time = repeat_time self._data_cache = data_cache self._data_size = len(data_cache) self._req_count = 0 self._req_time_ms_list = [] def run(self): run_st = int(time.time()) run_ed = run_st show_count = 0 while run_ed - run_st < self._repeat_time: try: if self._job_id == 0 and (run_ed - run_st) >= 60 * show_count: logging.info("Processing {}. Time: {}/{}(s)".format(self._req_count, run_ed - run_st, self._repeat_time)) show_count += 1 request = PredictRequest() request.model_spec.CopyFrom( utils.gen_model_spec(self._model_name, signature_name=FLAGS.signature_name)) select = random.randint(0, self._data_size - 1) request.inputs["example_batch"].CopyFrom(self._data_cache[select]) st = time.time() * 1000 # ms select = random.randint(0, len(self._stub_list) - 1) response = self._stub_list[select].Predict(request, 30) ed = time.time() * 1000 # ms req_time_ms = ed - st self._req_time_ms_list.append(req_time_ms) self._req_count += 1 except Exception as e: logging.info("Warning! call request failed. 
{}".format(repr(e))) pass run_ed = int(time.time()) def get_result(self): self.join() return self._req_time_ms_list def main(_): enable_tob_env() env_utils.setup_host_ip() agent_conf = utils.AgentConfig.from_file(FLAGS.conf) host = os.environ.get("MY_HOST_IP", socket.gethostbyname(socket.gethostname())) model_name = FLAGS.model_name if model_name is None: if agent_conf.deploy_type == utils.DeployType.PS: model_name = 'ps_{}'.format(agent_conf.shard_id) elif agent_conf.deploy_type == utils.DeployType.DENSE: model_name = utils.TFSServerType.DENSE else: model_name = utils.TFSServerType.ENTRY if agent_conf.deploy_type == utils.DeployType.PS: target = FLAGS.target or f"{host}:{agent_conf.tfs_ps_port}" elif agent_conf.deploy_type == utils.DeployType.DENSE: target = FLAGS.target or f"{host}:{agent_conf.tfs_dense_port}" else: target = FLAGS.target or f"{host}:{agent_conf.tfs_entry_port}" target_list = target.split(',') channel_list = [grpc.insecure_channel(tg) for tg in target_list] channel = grpc.insecure_channel(target_list[0]) if FLAGS.cmd_type == 'status': stub = ModelServiceStub(channel) request = GetModelStatusRequest() request.model_spec.CopyFrom( utils.gen_model_spec(model_name, signature_name=FLAGS.signature_name)) print(stub.GetModelStatus(request)) elif FLAGS.cmd_type == 'meta': stub = PredictionServiceStub(channel) request = GetModelMetadataRequest() request.model_spec.CopyFrom( utils.gen_model_spec(model_name, signature_name=FLAGS.signature_name)) request.metadata_field.extend( ['base_path', 'num_versions', 'signature_name']) response = stub.GetModelMetadata(request) print(response) elif FLAGS.cmd_type == 'load': request = ReloadConfigRequest() model_configs = ModelServerConfig() with open(FLAGS.input_file, 'r') as stream: txt = stream.read() text_format.Parse(txt, model_configs) request.config.CopyFrom(model_configs) stub = ModelServiceStub(channel) response = stub.HandleReloadConfigRequest(request) logging.info(f'{model_configs} load done!') return 
response.status elif FLAGS.cmd_type == 'profile': assert len(VALID_SLOTS) > 0 # ./tfs_client --conf=/path/agent.conf --cmd_type="get" --input_type="example_batch" --batch_size=128 --has_sort_id data_path_list = [] base_data_dir = FLAGS.profile_data_dir; assert base_data_dir is not None if not os.path.exists(base_data_dir): os.makedirs(base_data_dir) for file_name in os.listdir(base_data_dir): data_path_list.append(os.path.join(base_data_dir, file_name)) data_num = 500 from tqdm import tqdm if len(data_path_list) < data_num: add_num = data_num - len(data_path_list) for i in tqdm(range(add_num), desc="gen_random_file"): data_path = os.path.join(base_data_dir, "{}.pb".format(uuid.uuid1())) gen_random_file(data_path) data_path_list.append(data_path) data_path_list = data_path_list[:data_num] data_cache = [] for data_path in tqdm(data_path_list, desc="read_data"): data_cache.append(get_example_batch_proto_v2(data_path)) parallel_num = FLAGS.parallel_num repeat_time = FLAGS.profile_duration # stub = PredictionServiceStub(channel) stub_list = [PredictionServiceStub(chnl) for chnl in channel_list] thread_list = [] e2e_st = time.time() * 1000 # ms for i in range(parallel_num): thread = ProfileThread(i, model_name, stub_list, repeat_time, data_cache) thread.start() thread_list.append(thread) total_req_time_ms_list = [] for thread in thread_list: req_time_ms_list = thread.get_result() total_req_time_ms_list.extend(req_time_ms_list) e2e_ed = time.time() * 1000 # ms if len(total_req_time_ms_list) > 0: avg_req_time_ms = sum(total_req_time_ms_list) / len( total_req_time_ms_list) total_req_time_ms_list.sort() p99_req_time_ms = total_req_time_ms_list[int( round((len(total_req_time_ms_list) - 1) * 0.99))] qps = len(total_req_time_ms_list) * 1000 / (e2e_ed - e2e_st) else: avg_req_time_ms = 0 p99_req_time_ms = 0 qps = 0 logging.info("[Profile] Count: {}, Avg Latency: {}, P99 Latency: {}, QPS: {}".format( len(total_req_time_ms_list), avg_req_time_ms, p99_req_time_ms, qps)) else: # get 
# url = f"http://{target}/v1/models/{model_name}:{FLAGS.signature_name}" # cmd = ['curl', '-d', f"'{FLAGS.inputs}'", '-X', 'POST', url] # output = subprocess.check_output(cmd, shell=True) # print(output) stub = PredictionServiceStub(channel) request = PredictRequest() request.model_spec.CopyFrom( utils.gen_model_spec(model_name, signature_name=FLAGS.signature_name)) if FLAGS.input_type == 'instance': if not os.path.exists(FLAGS.input_file): gen_random_file(FLAGS.input_file, "instance") tensor_proto = get_instance_proto(FLAGS.input_file, FLAGS.batch_size) request.inputs["instances"].CopyFrom(tensor_proto) elif FLAGS.input_type == 'example_batch': if FLAGS.input_file is None: input_file = "{}.pb".format(uuid.uuid1()) else: input_file = FLAGS.input_file tensor_proto = get_example_batch_proto_v2(input_file) request.inputs["example_batch"].CopyFrom(tensor_proto) else: tensor_proto = get_example_batch_to_instance(FLAGS.input_file, FLAGS.file_type) request.inputs["instances"].CopyFrom(tensor_proto) response = stub.Predict(request, 30) print(response) if __name__ == "__main__": app.run(main) ================================================ FILE: monolith/agent_service/tfs_client_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging, app, flags import json import os import tempfile import unittest from monolith.agent_service.tfs_client import get_instance_proto from monolith.agent_service.tfs_client import get_example_batch_to_instance FLAGS = flags.FLAGS class TFSClientTest(unittest.TestCase): def test_get_instance_proto(self): tensor_proto = get_instance_proto() self.assertEqual(tensor_proto.dtype, 7) self.assertEqual(tensor_proto.tensor_shape.dim[0].size, 256) def test_get_example_batch_to_instance_from_pb(self): file_name = "monolith/native_training/data/training_instance/examplebatch.data" FLAGS.lagrangex_header = True get_example_batch_to_instance(file_name, 'pb') def test_get_example_batch_to_instance_from_pbtxt(self): file_name = "monolith/agent_service/example_batch.pbtxt" FLAGS.lagrangex_header = True get_example_batch_to_instance(file_name, 'pbtxt') def main(_): unittest.main() if __name__ == "__main__": app.run(main) ================================================ FILE: monolith/agent_service/tfs_monitor.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging import os import grpc from threading import RLock from functools import singledispatchmethod from typing import Dict, Iterable, Union, List from tensorflow.core.protobuf.error_codes_pb2 import Code from tensorflow_serving.util.status_pb2 import StatusProto from tensorflow_serving.apis.get_model_status_pb2 import ModelVersionStatus, GetModelStatusRequest from tensorflow_serving.apis.model_management_pb2 import ReloadConfigRequest from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub from tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceStub from tensorflow_serving.config.model_server_config_pb2 import ModelServerConfig from monolith.agent_service.data_def import PublishMeta, SubModelName, TFSModelName, \ VersionPath, PublishType as PType from monolith.agent_service.utils import AgentConfig, gen_model_spec, gen_model_config, \ DeployType, TFSServerType, get_local_ip, DEFAULT_MODEL_CONFIG State = ModelVersionStatus.State class TFSMonitor(object): def __init__(self, config: AgentConfig): self._conf: AgentConfig = config self._host = None self._lock = RLock() self.stubs = { TFSServerType.ENTRY: {}, TFSServerType.PS: {}, TFSServerType.DENSE: {} } @property def host(self): if self._host is None or self._host in {'', 'localhost', '127.0.0.1'}: self._host = get_local_ip() return self._host def get_addr(self, sub_model_name: SubModelName) -> str: if self._conf.deploy_type == DeployType.MIXED: if self.is_entry(sub_model_name): return f"{self.host}:{self._conf.tfs_entry_port}" elif self.is_ps(sub_model_name): return f"{self.host}:{self._conf.tfs_ps_port}" else: return f"{self.host}:{self._conf.tfs_dense_port}" elif self._conf.deploy_type == DeployType.ENTRY: assert self.is_entry(sub_model_name) return f"{self.host}:{self._conf.tfs_entry_port}" elif self._conf.deploy_type == DeployType.PS: assert self.is_ps(sub_model_name) return f"{self.host}:{self._conf.tfs_ps_port}" elif self._conf.deploy_type == 
DeployType.DENSE: assert self.is_dense(sub_model_name) return f"{self.host}:{self._conf.tfs_dense_port}" else: raise RuntimeError(f'deploy_type {self._conf.deploy_type} is error') def get_service_type(self, sub_model_name: SubModelName): if self._conf.deploy_type == DeployType.ENTRY: return TFSServerType.ENTRY if sub_model_name.startswith( TFSServerType.ENTRY) else None elif self._conf.deploy_type == DeployType.PS: return TFSServerType.PS if sub_model_name.startswith( TFSServerType.PS) else None elif self._conf.deploy_type == DeployType.DENSE: return TFSServerType.DENSE if sub_model_name.startswith( TFSServerType.DENSE) else None else: assert self._conf.deploy_type == DeployType.MIXED if sub_model_name.startswith(TFSServerType.ENTRY): return TFSServerType.ENTRY elif sub_model_name.startswith(TFSServerType.PS): return TFSServerType.PS elif sub_model_name.startswith(TFSServerType.DENSE): return TFSServerType.DENSE else: return None def is_entry(self, sub_model_name: str): return sub_model_name.startswith('entry') def is_ps(self, sub_model_name: str): return sub_model_name.startswith('ps') def is_dense(self, sub_model_name: str): return sub_model_name.startswith('dense') def connect(self): if self._conf.deploy_type in {DeployType.MIXED, DeployType.ENTRY}: entry_channel = grpc.insecure_channel( f'{self.host}:{self._conf.tfs_entry_port}') self.stubs[TFSServerType.ENTRY]['channel'] = entry_channel self.stubs[TFSServerType.ENTRY]['model_service'] = ModelServiceStub( entry_channel) self.stubs[TFSServerType.ENTRY][ 'prediction_service'] = PredictionServiceStub(entry_channel) if self._conf.deploy_type in {DeployType.MIXED, DeployType.PS}: ps_channel = grpc.insecure_channel( f'{self.host}:{self._conf.tfs_ps_port}') self.stubs[TFSServerType.PS]['channel'] = ps_channel self.stubs[TFSServerType.PS]['model_service'] = ModelServiceStub( ps_channel) self.stubs[TFSServerType. 
PS]['prediction_service'] = PredictionServiceStub(ps_channel) if self._conf.dense_alone and self._conf.deploy_type in { DeployType.MIXED, DeployType.DENSE }: dense_channel = grpc.insecure_channel( f'{self.host}:{self._conf.tfs_dense_port}') self.stubs[TFSServerType.DENSE]['channel'] = dense_channel self.stubs[TFSServerType.DENSE]['model_service'] = ModelServiceStub( dense_channel) self.stubs[TFSServerType.DENSE][ 'prediction_service'] = PredictionServiceStub(dense_channel) def start(self): self.stubs = { TFSServerType.ENTRY: {}, TFSServerType.PS: {}, TFSServerType.DENSE: {} } self.connect() def stop(self): if len(self.stubs) > 0: for stub in self.stubs.values(): if 'channel' in stub: try: stub['channel'].close() stub['model_service'] = None stub['prediction_service'] = None except: logging.error('stop channel fail!') self.stubs.clear() @singledispatchmethod def get_model_status(self, arg): raise NotImplementedError("get_model_status is not implemented!") @get_model_status.register def _(self, pm: PublishMeta, fix_dense_version: bool = False): # return 'Dict[TFSModelName, (VersionPath, ModelVersionStatus)]' with self._lock: model_status: Dict[str, State] = {} for sub_model_name, smvpath in pm.sub_models.items(): service_type = self.get_service_type(sub_model_name) if service_type is None: continue tfs_model_name = f'{pm.model_name}:{sub_model_name}' request = GetModelStatusRequest() # TODO(ltli): 这步修改不确定 is_dense_node = ( (not self._conf.dense_alone and self.is_entry(sub_model_name)) or (self._conf.dense_alone and self.is_dense(sub_model_name))) if not fix_dense_version and is_dense_node: request.model_spec.CopyFrom(gen_model_spec(tfs_model_name)) else: version = int(os.path.basename(smvpath)) request.model_spec.CopyFrom(gen_model_spec(tfs_model_name, version)) stub: ModelServiceStub = self.stubs[service_type]['model_service'] try: model_version_status = stub.GetModelStatus( request).model_version_status if model_version_status is None or len(model_version_status) 
== 0: status = ModelVersionStatus( state=State.UNKNOWN, status=StatusProto(error_code=Code.NOT_FOUND, error_message=f'{tfs_model_name} not found')) else: # if there are more than one version, select the latest one model_version_status = sorted(model_version_status, key=lambda mvs: mvs.version) status = model_version_status[-1] except grpc._channel._InactiveRpcError as e: logging.info(repr(e)) status = ModelVersionStatus(state=State.UNKNOWN, status=StatusProto( error_code=e.code().value[0], error_message=e.details())) model_status[tfs_model_name] = (smvpath, status) return model_status @get_model_status.register def _(self, name: str, version: Union[int, str] = None, signature_name: str = None) -> List[ModelVersionStatus]: """Get model version status :param name: The model name :param version: The version of model. If not specify version, information about all versions of the model will be returned. :return a list of ModelVersionStatus, which has three attribute: - version: int, Model version. - state: State, Model state, A Enum of UNKNOWN, START, LOADING, AVAILABLE, UNLOADING, END. - status: StatusProto, Model status. 
""" with self._lock: service_type = self.get_service_type(SubModelName(name)) if service_type is None: return [] else: request = GetModelStatusRequest() request.model_spec.CopyFrom( gen_model_spec(name, version, signature_name)) stub = self.stubs[service_type]['model_service'] return stub.GetModelStatus(request).model_version_status def gen_model_config( self, pms: Iterable[PublishMeta], fix_dense_version: bool = False) -> Dict[str, ModelServerConfig]: model_configs = { TFSServerType.ENTRY: ModelServerConfig(), TFSServerType.PS: ModelServerConfig(), TFSServerType.DENSE: ModelServerConfig() } for pm in pms: if pm.ptype == PType.UNLOAD: continue for sub_model_name, smv_path in pm.sub_models.items(): tfs_model_name = f'{pm.model_name}:{sub_model_name}' service_type = self.get_service_type(sub_model_name) if service_type is None: continue base_path = os.path.dirname(smv_path) # TODO(ltli): 这步修改不确定 is_dense_node = ( (not self._conf.dense_alone and self.is_entry(sub_model_name)) or (self._conf.dense_alone and self.is_dense(sub_model_name))) if is_dense_node: version_policy = 'specific' if fix_dense_version else 'latest' version_data = int( os.path.basename(smv_path)) if fix_dense_version else 1 else: version_policy = 'specific' version_data = int(os.path.basename(smv_path)) model_config = gen_model_config(tfs_model_name, base_path, version_policy, version_data) model_configs[service_type].model_config_list.config.append( model_config) return model_configs def handle_reload_config_request( self, service_type: str, model_configs: ModelServerConfig) -> StatusProto: with self._lock: request = ReloadConfigRequest() # keep default model in memory, incase no model in tfs model_config_list = model_configs.model_config_list.config if not any(mc.name == 'default' for mc in model_config_list): model_config = model_config_list.add() model_config.CopyFrom(DEFAULT_MODEL_CONFIG) request.config.CopyFrom(model_configs) if service_type == TFSServerType.ENTRY: port = 
self._conf.tfs_entry_port elif service_type == TFSServerType.PS: port = self._conf.tfs_ps_port else: port = self._conf.tfs_dense_port logging.info(f'{service_type} load ({port}): \n{request}') try: response = self.stubs[service_type][ 'model_service'].HandleReloadConfigRequest(request) logging.info(f'{service_type} load done!') except Exception as e: logging.info(repr(e)) raise e return response.status ================================================ FILE: monolith/agent_service/tfs_monitor_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import socket import time import threading import random import unittest from monolith.agent_service import constants from monolith.agent_service import utils from monolith.agent_service.tfs_monitor import TFSMonitor from monolith.agent_service.mocked_tfserving import FakeTFServing from tensorflow_serving.config.model_server_config_pb2 import ModelServerConfig from monolith.agent_service.data_def import PublishMeta, SubModelName, TFSModelName, \ VersionPath, PublishType as PType from monolith.agent_service.utils import AgentConfig, gen_model_spec, gen_model_config, \ TFSServerType, get_local_ip MODEL_NAME = 'test_model' BASE_PATH = f'/tmp/{MODEL_NAME}/monolith' version = '1634631496' path = '/tmp/monolith/agent_service/test_data/ckpt/exported_models/{}/{}' class TFSMonitorTest(unittest.TestCase): tfs: FakeTFServing = None monitor: TFSMonitor = None @classmethod def setUpClass(cls) -> None: os.environ[constants.HOST_SHARD_ENV] = '10' os.environ['SHARD_ID'] = '1' os.environ['REPLICA_ID'] = '2' cls.agent_conf = utils.AgentConfig(bzid='bzid', deploy_type='mixed') cls.tfs_entry = FakeTFServing(num_versions=2, port=cls.agent_conf.tfs_entry_port, model_config_file=ModelServerConfig()) cls.tfs_ps = FakeTFServing(num_versions=2, port=cls.agent_conf.tfs_ps_port, model_config_file=ModelServerConfig()) entry = threading.Thread(target=lambda: cls.tfs_entry.start()) entry.start() ps = threading.Thread(target=lambda: cls.tfs_ps.start()) ps.start() time.sleep(2) cls.monitor = TFSMonitor(cls.agent_conf) cls.monitor.connect() cls.data = {} @classmethod def tearDownClass(cls) -> None: cls.monitor.stop() cls.tfs_entry.stop() cls.tfs_ps.stop() def setUp(self): sub_models: Dict[SubModelName, VersionPath] = { 'entry': path.format('entry', version), 'ps_0': path.format('ps_0', version), 'ps_3': path.format('ps_3', version), 'ps_5': path.format('ps_5', version) } pm = PublishMeta(shard_id=self.agent_conf.shard_id, replica_id=self.agent_conf.replica_id, model_name='test_1', 
num_ps=5, sub_models=sub_models) self.data['setUp'] = self.monitor.get_model_status(pm) def tearDown(self): sub_models: Dict[SubModelName, VersionPath] = { 'entry': path.format('entry', version), 'ps_0': path.format('ps_0', version), 'ps_3': path.format('ps_3', version), 'ps_5': path.format('ps_5', version) } pm = PublishMeta(shard_id=self.agent_conf.shard_id, replica_id=self.agent_conf.replica_id, model_name='test_1', num_ps=5, sub_models=sub_models) time.sleep(1) before_status = self.data['setUp'] after_status = self.monitor.get_model_status(pm) self.assertEqual(len(before_status), len(after_status)) if self.data['execute'] == 'reload_config': for tfs_model_name, (bvp, bstate) in before_status.items(): (avp, astate) = after_status[tfs_model_name] self.assertEqual(bvp, avp) self.assertTrue(bstate.version == -1 and bstate.status.error_code == 5) # NOT_FOUND if astate.version == -1: pass elif astate.version == 1: self.assertTrue(tfs_model_name.endswith('entry')) else: self.assertEqual(astate.version, int(os.path.basename(bvp))) else: for tfs_model_name, (bvp, bstate) in before_status.items(): (avp, astate) = after_status[tfs_model_name] self.assertEqual(astate.version, -1) self.assertEqual(bvp, avp) if bstate.version == -1: self.assertTrue(bstate.status.error_code == 5) # NOT_FOUND else: self.assertTrue(bstate.version > 0) def test_reload_config(self): pms = [] for i in range(10): num_ps = random.randint(5, 20) sub_models = { f'ps_{i}': path.format(f'ps_{i}', version) for i in range(num_ps) if i % 3 == 0 } sub_models[TFSServerType.ENTRY] = path.format(f'entry', version) pm = PublishMeta(shard_id=self.agent_conf.shard_id, replica_id=self.agent_conf.replica_id, model_name=f'test_{i}', num_ps=num_ps, sub_models=sub_models) pms.append(pm) model_configs: Dict[str, ModelServerConfig] = self.monitor.gen_model_config(pms) for service_type, model_config in model_configs.items(): if len(model_config.model_config_list.config) > 0: status = 
self.monitor.handle_reload_config_request( service_type, model_config) self.data['execute'] = 'reload_config' def test_remove_config(self): pms = [] for i in range(5, 10): num_ps = random.randint(5, 20) sub_models = { f'ps_{i}': path.format(f'ps_{i}', version) for i in range(num_ps) if i % 3 == 0 } sub_models[TFSServerType.ENTRY] = path.format(f'entry', version) pm = PublishMeta(shard_id=self.agent_conf.shard_id, replica_id=self.agent_conf.replica_id, model_name=f'test_{i}', num_ps=num_ps, sub_models=sub_models) pms.append(pm) model_configs: Dict[str, ModelServerConfig] = self.monitor.gen_model_config(pms) for service_type, model_config in model_configs.items(): if len(model_config.model_config_list.config) > 0: status = self.monitor.handle_reload_config_request( service_type, model_config) self.data['execute'] = 'remove_config' if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/agent_service/tfs_wrapper.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import subprocess from typing import get_type_hints import grpc from absl import app, flags, logging from google.protobuf import text_format from tensorflow_serving.util.status_pb2 import StatusProto from tensorflow_serving.config import model_server_config_pb2 from tensorflow_serving.apis import model_pb2, get_model_status_pb2 from tensorflow_serving.apis.get_model_status_pb2 import GetModelStatusRequest, ModelVersionStatus from tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceStub from monolith.utils import find_main from monolith.agent_service.utils import TFS_HOME, TfServingConfig State = ModelVersionStatus.State TFS_BINARY = os.environ.get('MONOLITH_TFS_BINARY', None) class TFSWrapper(object): def __init__(self, archon_port: int, grpc_port: int, http_port: int, model_config_file: str, binary_config: TfServingConfig, log_file: str): self._archon_port = archon_port self._grpc_port = grpc_port self._http_port = http_port self._model_config_file = model_config_file self._binary_config = binary_config self._log_file = log_file self._proc = None # model service self._channel = None self._stub = None cp = subprocess.run(f"strings {TFS_BINARY} | grep PredictionServiceGrpc", shell=True) self._is_grpc_remote_op = cp.returncode == 0 def _prepare_cmd(self): flags = [] flags.append(f'--model_config_file={self._model_config_file}') flags.append(f"--port={self._grpc_port}") flags.append(f"--rest_api_port={self._http_port}") flags.append("--model_config_file_poll_wait_seconds=60") psm = os.environ.get("TCE_PSM", "") cluster = os.environ.get("TCE_CLUSTER", "") prefix = psm flags.append(f"--archon_port={self._archon_port}") flags.append(f"--archon_rpc_psm={psm}") flags.append(f"--archon_rpc_cluster={cluster}") flags.append(f"--metrics_namespace_prefix={prefix}") if not self._is_grpc_remote_op: flags.append( f'--archon_entry_to_ps_rpc_timeout={self._binary_config.fetch_ps_timeout_ms}' ) # set some dummy config for archon 
flags.append("--conf_file=conf/service.conf") flags.append("--log_conf=conf/log4j.properties") for key, clz in get_type_hints(TfServingConfig).items(): default = getattr(TfServingConfig, key) value = getattr(self._binary_config, key) if key == 'platform_config_file': platform_config_file = value or default if platform_config_file is None: flags.append('--platform_config_file=conf/platform_config_file.cfg') else: flags.append(f'--{key}={platform_config_file}') elif value != default: if clz == bool: flags.append(f'--{key}={str(value).lower()}') else: flags.append(f'--{key}={value}') return f'{TFS_BINARY} {" ".join(flags)}' @property def is_grpc_remote_op(self): return self._is_grpc_remote_op def start(self): os.chdir(find_main()) tfs_cmd = self._prepare_cmd() logging.info( f"starting {'grpc' if self._is_grpc_remote_op else 'archon'} tfs in {os.getcwd()} using command {tfs_cmd}" ) with open(self._log_file, "w") as log_stdout: self._proc = subprocess.Popen(tfs_cmd.split(), shell=False, stderr=subprocess.STDOUT, stdout=log_stdout, env=os.environ) self._channel = grpc.insecure_channel(f'localhost:{self._grpc_port}') self._stub = ModelServiceStub(self._channel) def stop(self): logging.info("stoping tfs") try: self._channel.close() if self._proc is not None and self._proc.stdout is not None: self._proc.stdout.close() except Exception as e: logging.info(e) finally: self._proc.kill() def poll(self): self._proc.poll() return self._proc.returncode def model_config_text(self): with open(self._model_config_file, "r") as output: return output.read() def list_saved_models(self): model_server_config = text_format.Parse( self.model_config_text(), model_server_config_pb2.ModelServerConfig()) model_config_list = model_server_config.model_config_list.config return [config.name for config in model_config_list] def list_saved_models_status(self): saved_models = self.list_saved_models() model_status = {} for saved_model in saved_models: model_spec = model_pb2.ModelSpec(name=saved_model) 
request = GetModelStatusRequest() request.model_spec.CopyFrom(model_spec) try: model_version_status = self._stub.GetModelStatus( request).model_version_status if model_version_status is None or len(model_version_status) == 0: status = State.UNKNOWN else: # if there are more than one version, select the available one model_version_status = sorted(model_version_status, key=lambda mvs: mvs.version) available_version_status = [ mvs for mvs in model_version_status if mvs.state == State.AVAILABLE ] if available_version_status: status = available_version_status[-1] else: status = model_version_status[-1] except grpc.RpcError as e: logging.error(repr(e)) status = ModelVersionStatus(state=State.UNKNOWN, status=StatusProto( error_code=e.code().value[0], error_message=e.details())) model_status[saved_model] = status return model_status class FakeTFSWrapper(object): def __init__(self, model_config_file: str): self._model_config_file = model_config_file def start(self): logging.info("starting tfs") def stop(self): logging.info("stoping tfs") def poll(self): return None def model_config_text(self): with open(self._model_config_file, "r") as output: return output.read() def list_saved_models(self): model_server_config = text_format.Parse( self.model_config_text(), model_server_config_pb2.ModelServerConfig()) model_config_list = model_server_config.model_config_list.config return [config.name for config in model_config_list] def list_saved_models_status(self): saved_models = self.list_saved_models() model_status = {} for saved_model in saved_models: status = ModelVersionStatus(state=State.AVAILABLE) model_status[saved_model] = status return model_status ================================================ FILE: monolith/agent_service/utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import logging, flags from contextlib import closing from dataclasses import dataclass import google.protobuf.text_format as text_format import google.protobuf.json_format as json_format import json import os import re import socket import tempfile from typing import Dict, List, Union, Optional, get_type_hints from idl.matrix.proto.proto_parser_pb2 import Instance from tensorflow.core.framework.tensor_pb2 import TensorProto from tensorflow.core.framework.types_pb2 import DataType from tensorflow.core.protobuf.error_codes_pb2 import Code as ErrorCode from tensorflow_serving.apis import model_pb2 from tensorflow_serving.apis.get_model_status_pb2 import ModelVersionStatus from tensorflow_serving.config import model_server_config_pb2 from tensorflow_serving.sources.storage_path.file_system_storage_path_source_pb2 import \ FileSystemStoragePathSourceConfig from tensorflow_serving.util.status_pb2 import StatusProto from tensorflow_serving.config import platform_config_pb2 from tensorflow_serving.servables.tensorflow import session_bundle_config_pb2 from tensorflow_serving.servables.tensorflow import saved_model_bundle_source_adapter_pb2 from tensorflow.core.protobuf.config_pb2 import ConfigProto from monolith.agent_service import constants from monolith.native_training.zk_utils import default_zk_servers, _HOSTS, _PORT, is_ipv6_only import hashlib ModelState = ModelVersionStatus.State SEQ = re.compile(r"[ =\t]+") ServableVersionPolicy = FileSystemStoragePathSourceConfig.ServableVersionPolicy FeatureKeys = { 'name', 'fid', 'float_value', 'int64_value', 
'bytes_value', 'fid_list', 'float_list', 'int64_list', 'bytes_list' } flags.DEFINE_string("conf", "", "agent conf file") TFS_HOME = "/opt/tiger/monolith_serving" DEFAULT_MODEL_CONFIG = None DEFAULT_PLATFORM_CONFIG_FILE = "{}/conf/platform_config_file.cfg".format(TFS_HOME) old_isabs = os.path.isabs def isabs(path: str): if path.startswith('hdfs:/'): return True else: return old_isabs(path) os.path.isabs = isabs DefaultRoughSortModelLocalPath = None DefaultRoughSortModelP2PPath = None class TFSServerType: PS = 'ps' ENTRY = 'entry' DENSE = 'dense' UNIFIED = 'unified' class DeployType(TFSServerType): MIXED = 'mixed' # bath ps anf entry are host in one tfs def __init__(self, dtype: str): assert dtype.lower() in { self.ENTRY, self.PS, self.DENSE, self.MIXED, self.UNIFIED } self._dtype = dtype.lower() def __str__(self): return self._dtype def __hash__(self): return hash(self._dtype) def __eq__(self, o): if isinstance(o, str): return self._dtype == o elif isinstance(o, DeployType): return self._dtype == o._dtype else: return False def compat_server_type(self, server_type: str): if server_type is None or server_type == DeployType.MIXED: if self._dtype == DeployType.MIXED: raise RuntimeError('DeployType and ServerType is not compatable!') else: return self._dtype elif self._dtype == DeployType.MIXED: return server_type else: assert self._dtype == server_type return server_type class RoughSortModelLoadedServer: NONE = 'none' ENTRY = 'entry' PS = 'ps' DENSE = 'dense' class RoughSortModelPrefix: PS = 'ps_item_embedding' ENTRY = 'entry_item_embedding' DENSE = 'dense_item_embedding' def conf_parser(file_name: str, args: dict): if not os.path.exists(file_name): return with open(file_name) as f: for line in f: line = line.strip() if line.startswith('#') or len(line) == 0: continue else: idx = line.find('#') if idx > 0: line = line[0:idx] if line.startswith('include'): conf_parser(line.split()[-1], args) else: try: key, value = SEQ.split(line, maxsplit=1) if key in args: if 
type(args[key]) is not list: args[key] = [args[key]] + [value] elif value is not None and len(value) > 0: args[key] = value except Exception as e: logging.error(f'{e}') def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(('localhost', 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) host, port = s.getsockname() return port def check_port_open(port): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: s.connect(('127.0.0.1', port)) s.close() except Exception: logging.info(f'port {port} not open!') return False logging.info(f'port {port} opened!') return True def write_to_tmp_file(content) -> str: fd, path = tempfile.mkstemp() with os.fdopen(fd, 'w') as fp: fp.write(str(content)) return path def replica_id_from_pod_name() -> int: try: if 'MY_POD_NAME' in os.environ: md5 = hashlib.md5() pod_name = os.environ.get('MY_POD_NAME', 'pod_name') md5.update(pod_name.encode('utf-8')) return int(md5.hexdigest()[10:20], base=16) else: return -1 except Exception as e: return -1 @dataclass class TfServingConfig: '''TfServingConfig attributes: :param enable_batching: enable batching :param allow_version_labels_for_unavailable_models: If true, allows assigning unused version labels to models that are not available yet. :param batching_parameters_file: If non-empty, read an ascii BatchingParameters protobuf from the supplied file name and use the contained values instead of the defaults. :param num_load_threads: The number of threads in the thread-pool used to load servables. If set as 0, we don't use a thread-pool, and servable loads are performed serially in the manager's main work loop, may casue the Serving request to be delayed. Default: 0 :param num_unload_threads: The number of threads in the thread-pool used to unload servables. If set as 0, we don't use a thread-pool, and servable loads are performed serially in the manager's main work loop, may casue the Serving request to be delayed. 
Default: 0 :param max_num_load_retries: maximum number of times it retries loading a model after the first failure, before giving up. If set to 0, a load is attempted only once. Default: 5 :param load_retry_interval_micros: The interval, in microseconds, between each servable load retry. If set negative, it doesn't wait. Default: 1 minute :param file_system_poll_wait_seconds: Interval in seconds between each poll of the filesystem for new model version. If set to zero poll will be exactly done once and not periodically. Setting this to negative value will disable polling entirely causing ModelServer to indefinitely wait for a new model at startup. Negative values are reserved for testing purposes only. :param flush_filesystem_caches: If true (the default), filesystem caches will be flushed after the initial load of all servables, and after each subsequent individual servable reload (if the number of load threads is 1). This reduces memory consumption of the model server, at the potential cost of cache misses if model files are accessed after servables are loaded. :param tensorflow_session_parallelism: Number of threads to use for running a Tensorflow session. Auto-configured by default. Note that this option is ignored if --platform_config_file is non-empty. :param tensorflow_intra_op_parallelism: Number of threads to use to parallelize the execution of an individual op. Auto-configured by default. Note that this option is ignored if --platform_config_file is non-empty. :param tensorflow_inter_op_parallelism: Controls the number of operators that can be executed simultaneously. Auto-configured by default. Note that this option is ignored if --platform_config_file is non-empty. 
:param ssl_config_file: If non-empty, read an ascii SSLConfig protobuf from the supplied file name and set up a secure gRPC channel :param per_process_gpu_memory_fraction: Fraction that each process occupies of the GPU memory space the value is between 0.0 and 1.0 (with 0.0 as the default) If 1.0, the server will allocate all the memory when the server starts, If 0.0, Tensorflow will automatically select a value. :param allow_growth: allow gpu growth :param saved_model_tags: Comma-separated set of tags corresponding to the meta graph def to load from SavedModel. :param grpc_channel_arguments: A comma separated list of arguments to be passed to the grpc server. (e.g. grpc.max_connection_age_ms=2000) :param grpc_max_threads: Max grpc server threads to handle grpc messages. :param enable_model_warmup: Enables model warmup, which triggers lazy initializations (such as TF optimizations) at load time, to reduce first request latency. :param version: Display version :param remove_unused_fields_from_bundle_metagraph: Removes unused fields from MetaGraphDef proto message to save memory. :param enable_signature_method_name_check: Enable method_name check for SignatureDef. Disable this if agent_service native TF2 regression/classification models. :param xla_cpu_compilation_enabled: EXPERIMENTAL; CAN BE REMOVED ANYTIME! Enable XLA:CPU JIT (default is disabled). With XLA:CPU JIT disabled, models utilizing this feature will return bad Status on first compilation request. :param enable_profiler: Enable profiler service. 
''' enable_batching: bool = False allow_version_labels_for_unavailable_models: bool = False batching_parameters_file: str = None num_load_threads: int = 0 num_unload_threads: int = 0 max_num_load_retries: int = 5 load_retry_interval_micros: int = 60 * 1000 * 1000 file_system_poll_wait_seconds: int = 1 flush_filesystem_caches: bool = True tensorflow_session_parallelism: int = 0 tensorflow_intra_op_parallelism: int = 0 tensorflow_inter_op_parallelism: int = 0 ssl_config_file: str = None platform_config_file: str = None per_process_gpu_memory_fraction: float = 0 allow_growth: bool = True saved_model_tags: str = None grpc_channel_arguments: str = None grpc_max_threads: int = 0 enable_model_warmup: bool = True version: str = None remove_unused_fields_from_bundle_metagraph: bool = True enable_signature_method_name_check: bool = False xla_cpu_compilation_enabled: bool = False enable_profiler: bool = True @dataclass class AgentConfig(TfServingConfig): '''AgentConfig attributes: :param bzid: business id of this agent_service, cannot be None. :param base_name: base name of model :param base_path: path to export :param num_ps: The number of ps. :param num_shard: The total number of shard. :param deploy_type: Server type, can be ps/entry/dense/mixed. :param stand_alone_serving: Whether is stand alone agent_service. :param zk_servers: The zk servers. :param proxy_port: TODO :param tfs_entry_port: TODO :param tfs_entry_http_port: TODO :param tfs_ps_port: TODO :param tfs_ps_http_port: TODO :param dense_alone: whether dense alone :param dense_service_num: dense service num for mps :param tfs_dense_port: TODO :param tfs_dense_http_port: TODO :param agent_port: TODO :param update_model_status_interval: Update model status interval. 
:param max_waiting_sec: The waiting second for PS/DENSE to load, default 600 :param agent_version: Version of Agent, default 1 :param version_policy: Tensorflow version_policy, can be latest/specific/all :param version_data: saved_model version :param preload_jemalloc: preload jemalloc.so :param rough_sort_model_name: model name for deep rough sort, which is generated by FeynmanTob :param rough_sort_model_local_path: load deep rough sort model from this dir :param rough_sort_model_loaded_server: load rough sort model on which server: ps or entry or dense :param layout_pattern: layout path format :param layout_filters: filter saved_models under layout_pattern to load :param tfs_port_archon: service archon port :param tfs_port_grpc: service grpc port :param tfs_port_http: service http port :param use_metrics: whether use metrics :param file_system_poll_wait_seconds_ps: Interval in seconds between each poll of the filesystem for new model version. If set to zero poll will be exactly done once and not periodically. Setting this to negative value will disable polling entirely causing ModelServer to indefinitely wait for a new model at startup. Negative values are reserved for testing purposes only for ps. 
''' bzid: str = None base_name: str = None base_path: str = None num_ps: int = 1 num_shard: int = None deploy_type: str = None replica_id: int = None stand_alone_serving: bool = False zk_servers: str = None proxy_port: int = None tfs_entry_port: int = None tfs_entry_http_port: int = None tfs_entry_archon_port: int = None tfs_ps_port: int = None tfs_ps_http_port: int = None tfs_ps_archon_port: int = None dense_alone: bool = False dense_service_num: int = 3 tfs_dense_port: int = None tfs_dense_http_port: int = None tfs_dense_archon_port: int = None agent_port: int = None update_model_status_interval: int = 1 model_config_file = None agent_version: int = 1 max_waiting_sec: int = 1200 preload_jemalloc: bool = True version_policy: str = 'latest' version_data: int = 1 fetch_ps_timeout_ms: int = 200 fetch_ps_long_conn_num: int = 100 fetch_ps_long_conn_enable: bool = True fetch_ps_retry: int = 2 aio_thread_num: int = 30 file_system_poll_wait_seconds_ps: int = 0 # for deep rough sort rough_sort_model_name: str = None rough_sort_model_local_path: str = DefaultRoughSortModelLocalPath rough_sort_model_loaded_server: str = RoughSortModelLoadedServer.ENTRY rough_sort_model_p2p_path: str = DefaultRoughSortModelP2PPath rough_sort_resource_constrained: bool = False dc_aware: bool = False # for unified container layout_pattern: str = None layout_filters: List = None tfs_port_archon: int = None tfs_port_grpc: int = None tfs_port_http: int = None use_metrics: bool = True def __post_init__(self): self.zk_servers = self._update_zk_servers( self.zk_servers, is_ipv6_only()) if self.stand_alone_serving: self.deploy_type = DeployType(DeployType.MIXED) else: assert self.deploy_type is not None self.deploy_type = DeployType(self.deploy_type) if self.num_shard is None: self.num_shard = self.num_tce_shard else: assert self.num_shard == self.num_tce_shard # PORT1 reserve for p2p # PORT2 reserve for agent if self.deploy_type == DeployType.MIXED: self.proxy_port = find_free_port() 
self.tfs_entry_archon_port = int(os.environ.get('PORT', find_free_port())) self.tfs_entry_port = int(os.environ.get('PORT3', find_free_port())) self.tfs_entry_http_port = int(os.environ.get('PORT4', find_free_port())) self.tfs_ps_port = int(os.environ.get('PORT5', find_free_port())) self.tfs_ps_http_port = int(os.environ.get('PORT6', find_free_port())) self.tfs_ps_archon_port = int(os.environ.get('PORT7', find_free_port())) if self.dense_alone: dense_service_idx = int(os.environ.get('DENSE_SERVICE_IDX', '0')) if dense_service_idx == 0: self.tfs_dense_port = int(os.environ.get('PORT8', find_free_port())) self.tfs_dense_http_port = int(os.environ.get('PORT9', find_free_port())) self.tfs_dense_archon_port = int(os.environ.get('PORT10', find_free_port())) else: self.tfs_dense_archon_port = find_free_port() self.tfs_dense_port = find_free_port() self.tfs_dense_http_port = find_free_port() elif self.deploy_type == DeployType.ENTRY: self.proxy_port = find_free_port() self.tfs_ps_archon_port = find_free_port() self.tfs_ps_port = find_free_port() self.tfs_ps_http_port = find_free_port() if self.dense_alone: self.tfs_dense_port = find_free_port() self.tfs_dense_http_port = find_free_port() self.tfs_dense_archon_port = find_free_port() self.tfs_entry_archon_port = int(os.environ.get('PORT', find_free_port())) self.tfs_entry_port = int(os.environ.get('PORT3', find_free_port())) self.tfs_entry_http_port = int(os.environ.get('PORT4', find_free_port())) elif self.deploy_type == DeployType.PS: self.proxy_port = find_free_port() self.tfs_entry_archon_port = find_free_port() self.tfs_entry_port = find_free_port() self.tfs_entry_http_port = find_free_port() if self.dense_alone: self.tfs_dense_port = find_free_port() self.tfs_dense_http_port = find_free_port() self.tfs_dense_archon_port = find_free_port() self.tfs_ps_archon_port = int(os.environ.get('PORT', find_free_port())) self.tfs_ps_port = int(os.environ.get('PORT3', find_free_port())) self.tfs_ps_http_port = 
int(os.environ.get('PORT4', find_free_port())) elif self.deploy_type == DeployType.DENSE: assert self.dense_alone == True self.proxy_port = find_free_port() self.tfs_entry_archon_port = find_free_port() self.tfs_entry_port = find_free_port() self.tfs_entry_http_port = find_free_port() self.tfs_ps_archon_port = find_free_port() self.tfs_ps_port = find_free_port() self.tfs_ps_http_port = find_free_port() dense_service_idx = int(os.environ.get('DENSE_SERVICE_IDX', '0')) if dense_service_idx == 0: self.tfs_dense_archon_port = int(os.environ.get('PORT', find_free_port())) self.tfs_dense_port = int(os.environ.get('PORT3', find_free_port())) self.tfs_dense_http_port = int(os.environ.get('PORT4', find_free_port())) else: self.tfs_dense_archon_port = find_free_port() self.tfs_dense_port = find_free_port() self.tfs_dense_http_port = find_free_port() else: assert self.deploy_type == DeployType.UNIFIED self.tfs_port_archon = int(os.environ.get('PORT', find_free_port())) self.tfs_port_grpc = int(os.environ.get('PORT3', find_free_port())) self.tfs_port_http = int(os.environ.get('PORT4', find_free_port())) if self.agent_port is None: self.agent_port = int(os.environ.get('PORT2', find_free_port())) if self.agent_version == 1: self.replica_id = replica_id_from_pod_name() else: replica_id = int(os.environ.get('REPLICA_ID', -1)) if replica_id == -1: replica_id = replica_id_from_pod_name() self.replica_id = replica_id if not self.platform_config_file: self.platform_config_file = DEFAULT_PLATFORM_CONFIG_FILE self.generate_platform_config_file() def generate_platform_config_file(self): try: session_config = ConfigProto() session_config.intra_op_parallelism_threads = ( self.tensorflow_intra_op_parallelism or int(os.getenv("MY_CPU_LIMIT", "0"))) or 16 session_config.inter_op_parallelism_threads = ( self.tensorflow_inter_op_parallelism or int(os.getenv("MY_CPU_LIMIT", "0"))) or 16 session_config.allow_soft_placement = True session_config.gpu_options.allow_growth = self.allow_growth if 
self.dense_alone and self.enable_batching: batching_parameters = session_bundle_config_pb2.BatchingParameters() batching_parameters.max_batch_size.value = 1024 batching_parameters.batch_timeout_micros.value = 800 batching_parameters.max_enqueued_batches.value = 100000 batching_parameters.num_batch_threads.value = 8 batching_parameters.support_diff_dim_size_inputs = True legacy_config = session_bundle_config_pb2.SessionBundleConfig( session_config=session_config, batching_parameters=batching_parameters) else: legacy_config = session_bundle_config_pb2.SessionBundleConfig( session_config=session_config) legacy_config.enable_model_warmup = self.enable_model_warmup adapter = saved_model_bundle_source_adapter_pb2.SavedModelBundleSourceAdapterConfig( legacy_config=legacy_config) config_map = platform_config_pb2.PlatformConfigMap() config_map.platform_configs['tensorflow'].source_adapter_config.Pack( adapter) text_config_map = text_format.MessageToString(config_map) with open(self.platform_config_file, 'w') as f: f.write(text_config_map) except Exception as e: logging.info(e) try: if os.path.isfile(self.platform_config_file): os.remove(self.platform_config_file) except Exception as e2: logging.info(e2) @property def num_tce_shard(self) -> int: return int(os.environ.get(constants.HOST_SHARD_ENV, 1)) @property def shard_id(self) -> int: return int(os.environ.get('SHARD_ID', -1)) @property def idc(self) -> Optional[str]: idc = os.environ.get('TCE_INTERNAL_IDC') if idc is None: return None else: return idc.lower() @property def cluster(self) -> Optional[str]: cluster = (os.environ.get('TCE_LOGICAL_CLUSTER') or os.environ.get('TCE_CLUSTER') or os.environ.get('TCE_PHYSICAL_CLUSTER')) if cluster is None: return None else: return cluster.lower() @property def location(self) -> Optional[str]: idc, cluster = self.idc, self.cluster if idc is None or cluster is None: return None else: return f'{idc}:{cluster}' @property def path_prefix(self) -> str: if self.dc_aware: return 
@property
def container_id(self) -> str:
  """Identifier of the container this agent is running in.

  Returns:
    The value of the ``MY_POD_NAME`` environment variable when it is set,
    otherwise whatever ``get_local_ip()`` returns.
  """
  # Resolve the fallback lazily. Passing get_local_ip() as the default
  # argument of os.environ.get() evaluates it on every access -- including
  # its hostname lookup -- even when MY_POD_NAME is set and the result is
  # thrown away.
  pod_name = os.environ.get("MY_POD_NAME")
  if pod_name is None:
    return get_local_ip()
  return pod_name
f'--archon_entry_to_ps_long_conn_num={self.fetch_ps_long_conn_num}') flags.append(f'--archon_entry_to_ps_rpc_retry={self.fetch_ps_retry}') flags.append(f'--archon_async_dispatcher_threads={self.aio_thread_num}') if not self.fetch_ps_long_conn_enable: flags.append(f'--archon_entry_to_ps_long_conn_enable=false') port = self.tfs_dense_port else: flags.append(f"--port={self.tfs_entry_port}") flags.append(f"--rest_api_port={self.tfs_entry_http_port}") flags.append(f'--archon_port={self.tfs_entry_archon_port}') flags.append( f'--archon_entry_to_ps_rpc_timeout={self.fetch_ps_timeout_ms}') flags.append( f'--archon_entry_to_ps_long_conn_num={self.fetch_ps_long_conn_num}') flags.append(f'--archon_entry_to_ps_rpc_retry={self.fetch_ps_retry}') flags.append(f'--archon_async_dispatcher_threads={self.aio_thread_num}') if not self.fetch_ps_long_conn_enable: flags.append(f'--archon_entry_to_ps_long_conn_enable=false') port = self.tfs_entry_port if self.agent_version != 1: flags.append("--model_config_file_poll_wait_seconds=0") for key, clz in get_type_hints(TfServingConfig).items(): default = getattr(TfServingConfig, key) value = getattr(self, key) if key == 'file_system_poll_wait_seconds': if self.agent_version == 1: if server_type == TFSServerType.PS: flags.append(f'--file_system_poll_wait_seconds={self.file_system_poll_wait_seconds_ps}') elif value != default: # entry,dense flags.append(f'--file_system_poll_wait_seconds={value}') elif value != default: if clz == bool: flags.append(f'--{key}={str(value).lower()}') else: flags.append(f'--{key}={value}') return f'{binary} {" ".join(flags)}', port def get_cmd(self, binary, server_type: str = None) -> str: cmd, port = self.get_cmd_and_port(binary, server_type) return cmd def get_server_schedule_iter(self, server_type): if self.deploy_type == DeployType.MIXED or self.deploy_type == DeployType.PS: if server_type == TFSServerType.PS: for i in range(self.num_ps): if i % self.num_shard == self.shard_id: yield i else: yield None elif 
self.deploy_type == DeployType.DENSE and server_type == TFSServerType.DENSE: # [TODO] (fitz) maybe there is a bug, fix it later yield self.replica_id else: yield None def _gen_model_server_config( self, server_type: str = None, ) -> model_server_config_pb2.ModelServerConfig: version_policy: str = self.version_policy version_data: int = self.version_data server_type = self.deploy_type.compat_server_type(server_type) assert server_type is not None model_server_config = model_server_config_pb2.ModelServerConfig() model_config_list = model_server_config.model_config_list.config if server_type == TFSServerType.PS: for i in self.get_server_schedule_iter(server_type): name = f'{server_type}_{i}' model_config = model_config_list.add() model_config.CopyFrom( gen_model_config(name=name, base_path=os.path.join(self.base_path, name), version_policy=version_policy, version_data=version_data)) if self.rough_sort_model_name and self.rough_sort_model_loaded_server == RoughSortModelLoadedServer.PS: name = f'{RoughSortModelPrefix.PS}_{i}' model_config = model_config_list.add() rough_model_path = os.path.join(self.rough_sort_model_local_path, self.rough_sort_model_name, name) model_config.CopyFrom( gen_model_config(name=name, base_path=rough_model_path, version_policy=version_policy, version_data=version_data)) else: if server_type == TFSServerType.DENSE: name = f'{server_type}_0' else: name = server_type model_config = model_config_list.add() model_config.CopyFrom( gen_model_config(name=name, base_path=os.path.join(self.base_path, name), version_policy=version_policy, version_data=version_data)) if self.rough_sort_resource_constrained and self.rough_sort_model_loaded_server == RoughSortModelLoadedServer.ENTRY: name = f'{RoughSortModelPrefix.ENTRY}_0' model_config = model_config_list.add() rough_model_path = os.path.join(self.base_path, name) model_config.CopyFrom( gen_model_config(name=name, base_path=rough_model_path, version_policy=version_policy, version_data=version_data)) elif 
@classmethod
def from_file(cls, fname):
  """Build a config object from the configuration file `fname`.

  `conf_parser` fills a dict from the file; every key that matches an
  annotated field of `AgentConfig` is converted to the field's declared type.
  Keys that do not match a field are ignored, and a value that fails to
  convert is logged and skipped so that the field keeps its default.

  Args:
    fname: path of the configuration file handed to `conf_parser`.

  Returns:
    An instance of `cls` constructed from the converted values.
  """
  kwarg = {}
  conf_parser(fname, kwarg)

  truthy = {'true', 'y', 't', 'yes', '1'}
  falsy = {'false', 'n', 'f', 'no', '0'}

  args = {}
  for key, clz in get_type_hints(AgentConfig).items():
    if key not in kwarg:
      continue
    raw = kwarg[key]
    try:
      if clz == bool and raw.lower() in truthy:
        args[key] = True
      elif clz == bool and raw.lower() in falsy:
        args[key] = False
      elif clz in {int, float}:
        # SECURITY NOTE(review): eval() executes arbitrary expressions taken
        # from the config file. It is kept because values may be written as
        # arithmetic (the defaults in this module use forms such as
        # `60 * 1000 * 1000`); only load files from trusted sources.
        args[key] = clz(eval(raw))
      elif clz == str:
        # The literal text "none" (any case) stands for a missing value.
        args[key] = None if raw.lower() == 'none' else raw
      elif clz == List:
        # A single value is promoted to a one-element list.
        args[key] = raw if type(raw) is list else [raw]
      else:
        # Also reached for a bool whose text is neither truthy nor falsy;
        # bool('anything') is True for any non-empty string.
        args[key] = clz(raw)
    except Exception as e:
      # Catch Exception rather than using a bare `except:` so that
      # KeyboardInterrupt / SystemExit still propagate, and report the cause.
      logging.error(
          f'type convert {key} error, the type is {clz}: {e!r}')

  # for compat: older config files use `server_type` instead of `deploy_type`
  if 'deploy_type' not in args:
    args['deploy_type'] = kwarg.pop('server_type', None)

  return cls(**args)
class ZKPath(object):
  """Parsed view of a service path stored in ZooKeeper.

  Accepted shapes (the location segment and the replica id are optional):

    /{bzid}/service/{base_name}/{idc}:{cluster}/{server_type}:{index}/{replica_id}
    /{bzid}/service/{base_name}/{server_type}:{index}

  The captured parts are exposed as attributes `bzid`, `base_name`, `idc`,
  `cluster`, `server_type`, `index` and `replica_id`. Each is a string, or
  None when the part is absent or the path did not match at all.
  """

  # Names of the regex groups that are exposed through __getattr__.
  _FIELDS = frozenset({
      'bzid', 'base_name', 'idc', 'cluster', 'server_type', 'index',
      'replica_id'
  })

  # Every group must be named: __getattr__ looks values up by group name.
  PAT = re.compile(
      r'^/(?P<bzid>[-_0-9A-Za-z]+)/service/(?P<base_name>[-_0-9A-Za-z]+)'
      r'(/(?P<idc>[-_0-9A-Za-z]+):(?P<cluster>[-_0-9A-Za-z]+))?'
      r'/(?P<server_type>\w+):(?P<index>\d+)(/(?P<replica_id>\d+))?$')

  def __init__(self, path: str):
    """Parse `path`; a None, empty or non-matching path yields no fields."""
    self.path = path
    self._group_dict = None
    # Only attempt a match for a non-empty string. The previous guard
    # (`path is None or len(path) != 0`) let None through to PAT.match(),
    # which raises TypeError.
    if path:
      matched = self.PAT.match(path)
      if matched:
        self._group_dict = matched.groupdict()
      else:
        logging.info(f"[INFO] path not matched: {path}")

  def __getattr__(self, name: str):
    """Return the captured group `name`, or None if nothing was captured.

    Raises:
      AttributeError: if `name` is not one of the path fields. Raising
        AttributeError (not AssertionError) keeps hasattr(), copy and pickle
        working, since they probe for optional attributes.
    """
    if name not in self._FIELDS:
      raise AttributeError(
          f'{type(self).__name__!r} object has no attribute {name!r}')
    if self._group_dict:
      return self._group_dict.get(name)
    return None

  @property
  def task(self) -> Optional[str]:
    """`{server_type}:{index}`, or None when either part is missing."""
    server_type, index = self.server_type, self.index
    if server_type is not None and index is not None:
      return f'{server_type}:{index}'
    return None

  @property
  def location(self) -> Optional[str]:
    """`{idc}:{cluster}`, or None when the path has no location segment."""
    idc, cluster = self.idc, self.cluster
    if idc is None or cluster is None:
      return None
    return f'{idc}:{cluster}'

  def ship_in(self, idc: str, cluster: str) -> bool:
    """Tell whether this path belongs to the given idc/cluster.

    When either argument is None no location filter applies and the result
    is True.
    """
    if idc is None or cluster is None:
      return True
    return idc == self.idc and cluster == self.cluster
def expand_pattern(pattern_str):
  """Expand every `[...]` index group of `pattern_str` into concrete strings.

  A group holds comma separated tokens. A token is either a single integer
  (`[3]`) or a range `start-stop`, which expands like `range(start, stop)`,
  i.e. `stop` itself is excluded: `'ps_[0-3]'` gives `ps_0`, `ps_1`, `ps_2`.
  When several groups are present the result is their cartesian product.

  Args:
    pattern_str: pattern text, e.g. `'ps_[0,2]'`.

  Returns:
    List of expanded strings; a pattern without groups expands to itself.
  """

  def comb(vals: List, p1: str, p2: str):
    # p2 is None for the trailing literal text that follows the last group.
    if p2 is None:
      return [val + p1 for val in vals]

    indices = []
    for token in p2.split(','):
      if '-' in token:
        start, stop = token.split('-')
        indices.extend(range(int(start), int(stop)))
      else:
        # A single index is appended; list.extend() needs an iterable and
        # raised TypeError when handed the int directly.
        indices.append(int(token))
    return [val + p1 + str(i) for val in vals for i in indices]

  return parse_pattern(pattern_str, [''], comb, '[', ']')
def to_json(self, fname: str = None) -> str:
  """Write the wrapped instance as JSON text.

  Args:
    fname: destination file. When None the text is handed to
      `write_to_tmp_file` (presumably returning the temp file path --
      confirm against that helper).

  Returns:
    `fname` when it was given, otherwise the result of `write_to_tmp_file`.
  """
  content = json_format.MessageToJson(self._inst)
  if fname is None:
    return write_to_tmp_file(content)
  # open() defaults to read mode; without 'w' the write below raises
  # io.UnsupportedOperation.
  with open(fname, 'w') as fid:
    fid.write(content)
  return fname

def to_pb_text(self, fname: str = None) -> str:
  """Write the wrapped instance in its `str()` (protobuf text) form.

  Args:
    fname: destination file. When None the instance is handed to
      `write_to_tmp_file`.

  Returns:
    `fname` when it was given, otherwise the result of `write_to_tmp_file`.
  """
  if fname is None:
    return write_to_tmp_file(self._inst)
  # Same fix as in to_json: the file has to be opened for writing.
  with open(fname, 'w') as fid:
    fid.write(str(self._inst))
  return fname
def get_local_ip() -> str:
  """Best-effort lookup of a non-loopback address for this host.

  Sources are tried in this order:
    1. the `MY_HOST_IP` environment variable,
    2. resolving the local hostname,
    3. the local address of a UDP socket connected towards 8.8.8.8.

  A candidate that is empty, 'localhost' or '127.0.0.1' is rejected and the
  next source is tried.

  Returns:
    The first acceptable address, or 'localhost' when every source fails.
  """
  rejected = {'', 'localhost', '127.0.0.1'}

  try:
    local_ip = os.environ.get("MY_HOST_IP")
    if local_ip is None:
      # Resolve the hostname only when the variable is absent. Used as the
      # default argument of os.environ.get() the lookup ran unconditionally,
      # so a resolution error discarded an explicitly configured MY_HOST_IP.
      local_ip = socket.gethostbyname(socket.gethostname())
    if local_ip is not None and local_ip not in rejected:
      return local_ip
  except Exception as e:
    logging.warning(e)

  skt = None
  try:
    # connect() on a datagram socket only selects the outgoing interface,
    # whose address getsockname() then reports.
    skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    skt.connect(('8.8.8.8', 80))
    local_ip = skt.getsockname()[0]
    if local_ip is not None and local_ip not in rejected:
      return local_ip
  except Exception as e:
    logging.warning(e)
  finally:
    if skt is not None:
      skt.close()

  return 'localhost'
utils.gen_status_proto(utils.ErrorCode.CANCELLED, error_message='CANCELLED') self.assertEqual(status_proto.error_code, utils.ErrorCode.CANCELLED) self.assertEqual(status_proto.error_message, 'CANCELLED') def test_gen_model_version_status(self): version, state = 1, utils.ModelState.START error_code, error_message = utils.ErrorCode.NOT_FOUND, "NOT_FOUND" model_version_status = utils.gen_model_version_status( version, state, error_code, error_message) self.assertEqual(model_version_status.version, version) self.assertEqual(model_version_status.state, state) def test_gen_from_file(self): conf = utils.AgentConfig.from_file( fname='monolith/agent_service/agent.conf') self.assertTrue(conf.stand_alone_serving) def test_list_field(self): conf = utils.AgentConfig.from_file( fname='monolith/agent_service/agent.conf') self.assertEqual(conf.layout_filters, ['ps_0', 'ps_1']) def test_instance_wrapper_from_json(self): iw = utils.InstanceFormater.from_json( 'monolith/agent_service/test_data/inst.json') tensor_proto = iw.to_tensor_proto(5) self.assertEqual(tensor_proto.dtype, 7) self.assertEqual(tensor_proto.tensor_shape.dim[0].size, 5) def test_instance_wrapper_from_pbtext(self): iw = utils.InstanceFormater.from_pb_text( 'monolith/agent_service/test_data/inst.pbtext') tensor_proto = iw.to_tensor_proto(5) self.assertEqual(tensor_proto.dtype, 7) self.assertEqual(tensor_proto.tensor_shape.dim[0].size, 5) def test_instance_wrapper_from_dump(self): iw = utils.InstanceFormater.from_dump( 'monolith/agent_service/test_data/inst.dump') tensor_proto = iw.to_tensor_proto(5) self.assertEqual(tensor_proto.dtype, 7) self.assertEqual(tensor_proto.tensor_shape.dim[0].size, 5) def test_get_cmd_and_port(self): conf = utils.AgentConfig.from_file( fname='monolith/agent_service/agent.conf') conf.agent_version = 2 cmd, port = conf.get_cmd_and_port(binary='tensorflow_model_server', server_type='ps') self.assertTrue('model_config_file_poll_wait_seconds' in cmd) def test_zk_path_full(self): zk_pzth = 
utils.ZKPath( '/bzid/service/base_name/idc:cluster/server_type:0/1') self.assertEqual(zk_pzth.bzid, 'bzid') self.assertEqual(zk_pzth.base_name, 'base_name') self.assertEqual(zk_pzth.idc, 'idc') self.assertEqual(zk_pzth.cluster, 'cluster') self.assertEqual(zk_pzth.server_type, 'server_type') self.assertEqual(zk_pzth.index, '0') self.assertEqual(zk_pzth.replica_id, '1') self.assertEqual(zk_pzth.location, 'idc:cluster') self.assertEqual(zk_pzth.task, 'server_type:0') self.assertTrue(zk_pzth.ship_in(None, None)) def test_zk_path_partial(self): zk_pzth = utils.ZKPath('/bzid/service/base_name/idc:cluster/server_type:0') self.assertEqual(zk_pzth.bzid, 'bzid') self.assertEqual(zk_pzth.base_name, 'base_name') self.assertEqual(zk_pzth.idc, 'idc') self.assertEqual(zk_pzth.cluster, 'cluster') self.assertEqual(zk_pzth.server_type, 'server_type') self.assertEqual(zk_pzth.index, '0') self.assertEqual(zk_pzth.replica_id, None) self.assertEqual(zk_pzth.location, 'idc:cluster') self.assertEqual(zk_pzth.task, 'server_type:0') self.assertTrue(zk_pzth.ship_in('idc', 'cluster')) def test_zk_path_old_full(self): zk_pzth = utils.ZKPath('/bzid/service/base_name/server_type:0/1') self.assertEqual(zk_pzth.bzid, 'bzid') self.assertEqual(zk_pzth.base_name, 'base_name') self.assertEqual(zk_pzth.idc, None) self.assertEqual(zk_pzth.cluster, None) self.assertEqual(zk_pzth.server_type, 'server_type') self.assertEqual(zk_pzth.index, '0') self.assertEqual(zk_pzth.replica_id, '1') self.assertEqual(zk_pzth.location, None) self.assertEqual(zk_pzth.task, 'server_type:0') self.assertTrue(zk_pzth.ship_in(None, None)) def test_zk_path_old_partial(self): zk_pzth = utils.ZKPath('/bzid/service/base_name/server_type:0') self.assertEqual(zk_pzth.bzid, 'bzid') self.assertEqual(zk_pzth.base_name, 'base_name') self.assertEqual(zk_pzth.idc, None) self.assertEqual(zk_pzth.cluster, None) self.assertEqual(zk_pzth.server_type, 'server_type') self.assertEqual(zk_pzth.index, '0') self.assertEqual(zk_pzth.replica_id, None) 
self.assertEqual(zk_pzth.location, None) self.assertEqual(zk_pzth.task, 'server_type:0') self.assertTrue(zk_pzth.ship_in(None, None)) def test_zk_path_old_partial2(self): zk_pzth = utils.ZKPath( '/1_20001223_44ce735e-d05c-11ec-ba29-00163e356637/service/20001223_zm_test_realtime_training_1328_v4_r982567_0/ps:1' ) self.assertEqual(zk_pzth.bzid, '1_20001223_44ce735e-d05c-11ec-ba29-00163e356637') self.assertEqual(zk_pzth.base_name, '20001223_zm_test_realtime_training_1328_v4_r982567_0') self.assertEqual(zk_pzth.idc, None) self.assertEqual(zk_pzth.cluster, None) self.assertEqual(zk_pzth.server_type, 'ps') self.assertEqual(zk_pzth.index, '1') self.assertEqual(zk_pzth.replica_id, None) self.assertEqual(zk_pzth.location, None) self.assertEqual(zk_pzth.task, 'ps:1') self.assertTrue(zk_pzth.ship_in(None, None)) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/agent_service/zk_mirror.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import socket import time import traceback from queue import Queue from absl import logging from collections import defaultdict from threading import Thread, RLock from typing import List, Optional, Dict, Union, Set from monolith.agent_service.data_def import ResourceSpec, ReplicaMeta from monolith.agent_service.data_def import PublishMeta, PublishType as PType from monolith.agent_service.data_def import ModelState, ModelName, ModelMeta from monolith.agent_service.data_def import Event, EventType as EType from kazoo.client import Election from kazoo.protocol.states import WatchedEvent, EventType, ZnodeStat, KazooState from kazoo.exceptions import NodeExistsError, ZookeeperError, \ NoNodeError, NotEmptyError, ConnectionClosedError from monolith.native_training.zk_utils import MonolithKazooClient from monolith.agent_service.utils import get_local_ip, DeployType, replica_id_from_pod_name class ZKMirror(object): _data_lock = RLock() _zk_lock = RLock() _sep = '/' def __init__( self, zk: MonolithKazooClient, bzid: str, queue: Queue = None, tce_shard_id: int = -1, # for entry deploy mode, tce_shard_id is -1 num_tce_shard: int = 1, deploy_type: str = DeployType.MIXED): self._data: Dict[str, bytes] = {} self._zk: MonolithKazooClient = zk self.queue: Queue = queue self._bzid: str = bzid self._is_leader = False self.tce_shard_id: int = tce_shard_id self.num_tce_shard: int = num_tce_shard self._local_host = get_local_ip() self._deploy_type = deploy_type self._leader = None self._zk_lock_path = f"/{bzid}/locks/" self._zk_election_path = f"/{bzid}/election/" # /{bzid}/resource/{shard_id}:{replica_id} -> ResourceSpec self.resource_path: str = f'/{bzid}/resource' # /{bzid}/portal/{model_name} -> ModelMeta self.portal_base_path: str = f'/{bzid}/portal' # /{bzid}/publish/{shard_id}:{model_name} -> PublishMeta self.publish_base_path: str = f'/{bzid}/publish' # /{bzid}/service/{model_name}/deploy_type:task_id/replica -> ReplicaMeta self.service_base_path: str = 
f'/{bzid}/service' @property def is_leader(self) -> bool: return self._is_leader def set_leader(self): self._is_leader = True def create(self, path, value=b"", acl=None, ephemeral=False, sequence=False, makepath=True, include_data=False): with self._zk_lock: try: self._zk.create(path, value, acl=acl, ephemeral=ephemeral, sequence=sequence, makepath=makepath, include_data=include_data) except NodeExistsError as e: self._zk.retry(self._zk.set, path=path, value=value) except Exception as e: raise e def ensure_path(self, path): with self._zk_lock: self._zk.retry(self._zk.ensure_path, path=path) def set(self, path: str, value: bytes = b''): with self._zk_lock: try: self._zk.retry(self._zk.set, path=path, value=value) except NoNodeError as e: self._zk.create(path=path, value=value, makepath=True) except Exception as e: raise e def exists(self, path: str) -> bool: with self._zk_lock: try: status = self._zk.exists(path=path) if isinstance(status, bool): return status else: return status is not None except ZookeeperError as e: return path in self._data def delete(self, path: str, recursive: bool = True): with self._zk_lock: try: self._zk.retry(self._zk.delete, path=path, recursive=recursive) except NoNodeError as e: logging.info(e) except NotEmptyError as e: self._zk.retry(self._zk.delete, path=path, recursive=True) except Exception as e: raise e def get(self, path) -> Optional[bytes]: with self._data_lock: return self._data.get(path) def get_children(self, path: str) -> List[str]: with self._data_lock: length = len(path.split(self._sep)) children = [] for p in self._data: if p.startswith(path): tl = p.split(self._sep) if len(tl) > length: children.append(tl[length]) return children def report_resource(self, recource: ResourceSpec): path = recource.get_path(self.resource_path) value = b'' if recource is None else recource.serialize() self.create(path=path, value=value, makepath=True, ephemeral=True) @property def resources(self) -> List[ResourceSpec]: # 
/{bzid}/resource/{shard_id}:{replica_id} -> ResourceSpec with self._data_lock: resource_paths = [ os.path.join(self.resource_path, child) for child in self.get_children(self.resource_path) ] return [ResourceSpec.deserialize(self.get(p)) for p in resource_paths] @property def num_tce_replica(self) -> Optional[int]: def get_replica_cnt(): replica_cnt = {} with self._data_lock: for path in self._data: if path.startswith(self.resource_path): replica_id = int(os.path.basename(path).split(':')[-1]) if replica_id == -1: continue # skip entry elif replica_id in replica_cnt: replica_cnt[replica_id] += 1 else: replica_cnt[replica_id] = 1 replicas = get_replica_cnt() while not all(cnt == self.tce_shard_id for cnt in replicas.values()): time.sleep(5) # log every minute logging.log_every_n(level=logging.INFO, msg='cluster autoscaler or broken node, keep waiting', n=12) replicas = get_replica_cnt() return len(replica_cnt) @property def tce_replica_id(self) -> int: replica_id = int(os.environ.get('REPLICA_ID', -1)) if replica_id == -1: replica_id = replica_id_from_pod_name() return replica_id def publish_loadding(self, info: Union[PublishMeta, List[PublishMeta]]): if isinstance(info, (list, tuple)): for pm in info: path = pm.get_path(self.publish_base_path) value = pm.serialize() loc = self.get(path) if loc is None or loc != value: self.create(path=path, value=value, makepath=True) else: path = info.get_path(self.publish_base_path) value = info.serialize() loc = self.get(path) if loc is None or loc != value: self.create(path=path, value=value, makepath=True) def expected_loading(self) -> Dict[ModelName, PublishMeta]: # /{bzid}/publish/{shard_id}:{model_name} -> PublishMeta with self._data_lock: nodes = self.get_children(self.publish_base_path) models, select, shortest_sub_model_pm = {}, [], {} for node in nodes: path = os.path.join(self.publish_base_path, node) pm = PublishMeta.deserialize(self.get(path)) shard_id, replica_id, model_name = node.split(':') if model_name in models: 
models[model_name] += 1 else: models[model_name] = 1 # record the most sub_model pm if model_name not in shortest_sub_model_pm: shortest_sub_model_pm[model_name] = pm elif len(shortest_sub_model_pm[model_name].sub_models) > len( pm.sub_models): shortest_sub_model_pm[model_name] = pm # the last one if models[model_name] == pm.total_publish_num: select.append(shortest_sub_model_pm[model_name]) expected: Dict[str, PublishMeta] = {} for pm in select: path = os.path.join( self.publish_base_path, f'{self.tce_shard_id}:{self.tce_replica_id}:{pm.model_name}') data = self.get(path) # for new replica or entry, data is None pm = pm if data is None else PublishMeta.deserialize(data) model_name = pm.model_name if pm.ptype != PType.LOAD: logging.info("ptype is not load, skip!") continue # note: the sceduler will not scedule entry submodel alone, # all the submodels are sceduled with ps submodel. # so there is an entry submodel in every PublishMeta. # the shard_id of every service is -1. if pm.shard_id == self.tce_shard_id and pm.replica_id == self.tce_replica_id: # if service type is 'entry', then shard_id is -1, # and no PublishMeta will fall in this branch # only ps/dense/mixed service type will hit this branch expected[model_name] = pm elif pm.shard_id == self.tce_shard_id and not pm.is_spec: # for autoscalar, new replica if model_name not in expected: pm.replica_id = self.tce_replica_id expected[model_name] = pm else: # all entry/ps/dense/mixed service type can hit this branch # and ps/dense submodels were filtered if model_name not in expected: pm.shard_id = self.tce_shard_id pm.replica_id = self.tce_replica_id pm.sub_models = { sub_model_name: vp for sub_model_name, vp in pm.sub_models.items() if sub_model_name.startswith('entry') } expected[model_name] = pm return expected def get_published_path(self, model_name: str) -> List[str]: with self._data_lock: paths = [] for path in self._data: if path.startswith( self.publish_base_path) and path.endswith(model_name): 
paths.append(path) return paths def update_service(self, replicas: List[ReplicaMeta]): # /{bzid}/service/{model_name}/deploy_type:task_id/replica -> ReplicaMeta need_create_or_update, local_load_paths = {}, set() for rm in replicas: path = rm.get_path(self._bzid, self._sep) value = rm.serialize() local_load_paths.add(path) loc = self.get(path) if loc is None or loc != value: # not exists or changed need_create_or_update[path] = value # only care about local replicas, remove first need_remove_paths = self.local_replica_paths - local_load_paths for path in need_remove_paths: self.delete(path) # create or update replicas if need_create_or_update: logging.info(f'need_create_or_update: {need_create_or_update}') for path, value in need_create_or_update.items(): try: self.create(path=path, value=value, ephemeral=True, makepath=True) except Exception as e: logging.info(repr(e)) @property def local_replica_paths(self) -> Set[str]: with self._data_lock: local_replicas = set() for path in self._data: if path.startswith(self.service_base_path): rm = ReplicaMeta.deserialize(self._data[path]) host = rm.address.split(':')[0] if host == self._local_host and rm.replica == self.tce_replica_id: local_replicas.add(path) return local_replicas def get_all_replicas(self, server_type: str) -> Dict[str, List[ReplicaMeta]]: # f'{model_name}:{server_type}:{task_id}' -> ReplicaMeta with self._data_lock: result: Dict[str, List[ReplicaMeta]] = defaultdict(list) for path, value in self._data.items(): if path.startswith(self.service_base_path): raw_key = path[len(self.service_base_path):].strip(self._sep) model_name, st, task, _ = raw_key.replace(self._sep, ':').split(':') if st == server_type: key = ':'.join([model_name, server_type, task]) rm = ReplicaMeta.deserialize(value) if rm.stat == ModelState.AVAILABLE: result[key].append(rm) return result def get_model_replicas(self, model_name: str, server_type: str) -> Dict[str, List[ReplicaMeta]]: # f'{model_name}:{server_type}:{task_id}' -> 
ReplicaMeta with self._data_lock: result: Dict[str, List[ReplicaMeta]] = defaultdict(list) base_path = os.path.join(self.service_base_path, model_name) for task in self.get_children(base_path): if task.startswith(server_type.lower()): task_path = os.path.join(base_path, task) for replica in self.get_children(task_path): path = os.path.join(task_path, replica) content = self._data.get(path) if content is not None: rm = ReplicaMeta.deserialize(content) if rm.stat == ModelState.AVAILABLE: result[f'{model_name}:{task}'].append(rm) return result def get_task_replicas(self, model_name: str, server_type: str, task: int) -> List[ReplicaMeta]: with self._data_lock: path = os.path.join(self.service_base_path, model_name, f'{server_type.lower()}:{task}') result: List[ReplicaMeta] = [] for child in self.get_children(path): content = self._data.get(os.path.join(path, child)) if content is not None: rm = ReplicaMeta.deserialize(content) if rm.stat == ModelState.AVAILABLE: result.append(rm) return result def get_replica(self, model_name: str, server_type: str, task: int, replica: int) -> Optional[ReplicaMeta]: with self._data_lock: path = os.path.join(self.service_base_path, model_name, f'{server_type.lower()}:{task}', str(replica)) content = self._data.get(path) if content is None: return None else: rm = ReplicaMeta.deserialize(content) if rm.stat == ModelState.AVAILABLE: return rm else: return None def watch_portal(self): # 1) check portal/publish conscience self._zk.ensure_path(path=self.portal_base_path) self._zk.ensure_path(path=self.publish_base_path) models_in_portal = set( self._zk.get_children(path=self.portal_base_path) or []) models_in_publish = { item.split(':')[-1] # {shard_id}:{model_name} -> model_name for item in (self._zk.get_children(path=self.publish_base_path) or []) } if len(models_in_publish) > 0: if len(models_in_portal) == 0: # just remove all remove = models_in_publish else: remove = models_in_publish - models_in_portal for model in remove: for key in 
self._zk.get_children(path=self.publish_base_path): if key.endswith(model): self.delete(path=os.path.join(self.publish_base_path, key), recursive=True) # 2) watch portal models = set() def create_data_watch(data_path: str): logging.info( f"add data watch for model {os.path.basename(data_path)} in portal") def data_watch(data: bytes, state: ZnodeStat, event: WatchedEvent): # info = ModelMeta.deserialize(data) with self._data_lock: if event is None or event.type in { EventType.CREATED, EventType.DELETED }: # in the first call, event is None if event is None: logging.info(f'call watch_portal when restart {data}') if event is None and data is None: action = EventType.DELETED else: action = EventType.NONE if event is None else event.type if data is not None and len(data) > 0: mm = ModelMeta.deserialize(data) mm.action = action else: mm = ModelMeta(model_name=os.path.basename(data_path), action=action) self.queue.put(Event(data_path, mm.serialize(), EType.PORTAL)) else: assert event.type in { EventType.CHILD, EventType.CHANGED, EventType.NONE } return data_watch def children_watch(children: List[str]): if children is None or len(children) == 0: return else: for model in children: if model not in models: models.add(model) path = os.path.join(self.portal_base_path, model) self._zk.DataWatch(path=path, func=create_data_watch(path)) logging.info(f"add children watch in portal") self._zk.ChildrenWatch(path=self.portal_base_path, func=children_watch) def watch_publish(self): publishs = set() self._zk.ensure_path(path=self.publish_base_path) def get_publish_cnt(model_name: str): cnt = 0 for path in self._data: if path.startswith( self.publish_base_path) and path.endswith(model_name): cnt += 1 return cnt def create_data_watch(data_path: str): def data_watch(data: bytes, state: ZnodeStat, event: WatchedEvent): data = data or self._data.get(data_path, None) if data is not None and len(data) > 0: pm = PublishMeta.deserialize(data) else: logging.info(f'watch_publish: data is None, 
{event}') return with self._data_lock: if event is None or event.type == EventType.CREATED: # in the first call, event is None if pm.ptype == PType.LOAD: self._data[data_path] = data else: del self._data[data_path] elif event.type == EventType.DELETED: if data_path in self._data: del self._data[data_path] else: assert event.type in { EventType.CHILD, EventType.CHANGED, EventType.NONE } cnt = get_publish_cnt(pm.model_name) if cnt == 0 or cnt == pm.total_publish_num: logging.info(f"all the publish of model {pm.model_name} arrived, " f"send event to {'unload' if cnt == 0 else 'load'}") load_path = pm.get_path(self.publish_base_path) data = self._data.get(load_path, data) self.queue.put(Event(data_path, data, EType.PUBLISH)) return data_watch def children_watch(children: List[str]): if children is None or len(children) == 0: return else: for pub in children: if pub not in publishs: publishs.add(pub) path = os.path.join(self.publish_base_path, pub) self._zk.DataWatch(path=path, func=create_data_watch(path)) self._zk.ChildrenWatch(path=self.publish_base_path, func=children_watch) def watch_resource(self): instances = set() def create_data_watch(data_path: str): def data_watch(data: bytes, state: ZnodeStat, event: WatchedEvent): data = data or self._data.get(data_path, None) with self._data_lock: if event is None or event.type == EventType.CREATED: # in the first call, event is None self._data[data_path] = data elif event.type == EventType.DELETED: del self._data[data_path] elif event.type == EventType.CHANGED: self._data[data_path] = data else: assert event.type in {EventType.CHILD, EventType.NONE} return data_watch def children_watch(children: List[str]): if children is None or len(children) == 0: return else: for inst in children: if inst not in instances: instances.add(inst) path = os.path.join(self.resource_path, inst) self._zk.DataWatch(path=path, func=create_data_watch(path)) self._zk.ChildrenWatch(path=self.resource_path, func=children_watch) def 
watch_service(self): # /{bzid}/service/{model_name}/deploy_type:task_id/replica -> ReplicaMeta children_set = set() self._zk.ensure_path(path=self.service_base_path) def create_data_watch(data_path: str): logging.info(f'data_path: {data_path}') def data_watch(data: bytes, state: ZnodeStat, event: WatchedEvent): logging.info(f'service data_watch: {data_path}: {data}, {event}') data = data or self._data.get(data_path, None) with self._data_lock: if event is None or event.type == EventType.CREATED: # in the first call, event is None self._data[data_path] = data elif event.type == EventType.DELETED: del self._data[data_path] elif event.type == EventType.CHANGED: self._data[data_path] = data else: assert event.type in {EventType.CHILD, EventType.NONE} return data_watch def create_replica_watch(task_path: str): logging.info(f'task_path: {task_path}') model = os.path.basename(os.path.dirname(task_path)) task = os.path.basename(task_path) def replica_watch(children: List[str]): if children is None or len(children) == 0: return else: for replica in children: key = f"{model}:{task}:{replica}" if key not in children_set: children_set.add(key) path = os.path.join(task_path, replica) self._zk.DataWatch(path=path, func=create_data_watch(path)) return replica_watch def create_task_watch(model_path: str): logging.info(f'model_path: {model_path}') model = os.path.basename(model_path) def task_watch(children: List[str]): if children is None or len(children) == 0: return else: for task in children: key = f"{model}:{task}" if key not in children_set: children_set.add(key) path = os.path.join(model_path, task) self._zk.ChildrenWatch(path=path, func=create_replica_watch(path)) return task_watch def model_watch(children: List[str]): if children is None or len(children) == 0: return else: for model in children: key = model if key not in children_set: children_set.add(key) path = os.path.join(self.service_base_path, model) self._zk.ChildrenWatch(path=path, func=create_task_watch(path)) 
self._zk.ChildrenWatch(path=self.service_base_path, func=model_watch) def election(self, leader, sched, identifier: str = None): self._leader = leader identifier = identifier or os.environ.get('MY_POD_NAME') if self._deploy_type == DeployType.ENTRY: logging.info('entry cannot be leader!') return def target(): try: election: Election = self._zk.Election(self._zk_election_path, identifier) election.run(leader, zk=self, sched=sched) except ConnectionClosedError as e: if self._zk.state in {KazooState.CONNECTED, KazooState.SUSPENDED}: logging.info(f"ConnectionClosedError, state is {self._zk.state}") pass else: logging.info(f"kazo {self._zk.state}, restart!") with self._data_lock: self._data = {} while not self.queue.empty(): self.queue.get_nowait() self.start() except Exception as e: logging.info(e) thread = Thread(target=target) thread.start() def start(self, is_client: bool = False): self._zk.start() self.watch_service() if not is_client: self.watch_publish() def stop(self): if self._leader is not None: self._leader.cancel() self._zk.stop() ================================================ FILE: monolith/agent_service/zk_mirror_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging import os from random import shuffle from queue import Queue from kazoo.exceptions import NodeExistsError import socket import time import threading import unittest from monolith.agent_service import constants from monolith.agent_service import utils from monolith.agent_service.agent_service_pb2 import ServerType from monolith.agent_service.mocked_tfserving import FakeTFServing from monolith.agent_service.mocked_zkclient import FakeKazooClient from monolith.agent_service.zk_mirror import ZKMirror from monolith.agent_service.data_def import PublishMeta, PublishType, ReplicaMeta, ResourceSpec, \ SubModelName, VersionPath, ModelMeta, EventType MODEL_NAME = 'model' BASE_PATH = f'/tmp/{MODEL_NAME}/saved_models' NUM_REPLICAS = 3 class ZKMirrorTest(unittest.TestCase): tfs: FakeTFServing = None agent_conf: utils.AgentConfig = None @classmethod def setUpClass(cls) -> None: os.environ[constants.HOST_SHARD_ENV] = '10' os.environ['SHARD_ID'] = '2' os.environ['REPLICA_ID'] = '2' cls.bzid = 'bzid' cls.shard_id = 2 cls.num_tce_shard = 10 cls.replica_id = 2 cls.zk = ZKMirror(zk=FakeKazooClient(), bzid=cls.bzid, queue=Queue(), tce_shard_id=cls.shard_id, num_tce_shard=cls.num_tce_shard) cls.zk.start() cls.resource = ResourceSpec( address=f'{utils.get_local_ip()}:1234', # host:port shard_id=cls.shard_id, replica_id=cls.replica_id, memory=12345, cpu=5.6, network=3.2, work_load=0.7) @classmethod def tearDownClass(cls) -> None: cls.zk.stop() def test_crud(self): # ensure_path self.zk.ensure_path(path='/model/crud') # exists self.assertTrue(self.zk.exists(path='/model/crud')) # create self.zk.create(path='/model/crud/data', value=b'test', makepath=True) # get/set value, _ = self.zk._zk.get(path='/model/crud/data') self.assertEqual(value, b'test') self.zk.set(path='/model/crud/data', value=b'new_test') value, _ = self.zk._zk.get(path='/model/crud/data') self.assertEqual(value, b'new_test') # delete self.zk.delete(path='/model/crud', recursive=False) 
self.assertFalse(self.zk.exists(path='/model/crud')) # porperties self.assertEqual(self.zk.num_tce_shard, 10) self.assertEqual(self.zk.tce_replica_id, 2) self.assertEqual(self.zk.tce_shard_id, 2) def test_zk_mirror(self): # 0) test_step0_request_loading self.zk.watch_portal() self.zk.watch_resource() path = os.path.join(self.zk.portal_base_path, MODEL_NAME) mm = ModelMeta(model_name=MODEL_NAME, model_dir=BASE_PATH, num_shard=5) self.zk.create(path, mm.serialize()) # 1) test_step1_scheduler path = os.path.join(self.zk.portal_base_path, MODEL_NAME) event = self.zk.queue.get() self.assertEqual(event.etype, EventType.PORTAL) self.assertEqual(event.path, path) mm = ModelMeta.deserialize(event.data) version, num_ps, num_tce_shard = 123456, 10, self.zk.num_tce_shard pms = [] tce_shards = list(range(self.zk.num_tce_shard)) shuffle(tce_shards) # scheduler for i in range(mm.num_shard): sub_models: Dict[SubModelName, VersionPath] = { f'ps_{k}': f'{mm.model_dir}/ps_{k}/{version}' for k in range(num_ps) if k % mm.num_shard == i } sub_models['entry'] = f'{mm.model_dir}/entry/{version}' # random schedule, and ensure current shard included if i == 0: shard_id = self.shard_id else: shard_id = tce_shards.pop() if shard_id == self.shard_id: shard_id = tce_shards.pop() for replica_id in range(NUM_REPLICAS): pm = PublishMeta(shard_id=shard_id, replica_id=replica_id, model_name=mm.model_name, num_ps=10, sub_models=sub_models) pms.append(pm) for pm in pms: pm.total_publish_num = len(pms) self.zk.publish_loadding(pms) # 2) test_step2_loading expected_loading = self.zk.expected_loading() for model_name, pm in expected_loading.items(): self.assertEqual(model_name, MODEL_NAME) self.assertEqual(self.shard_id, pm.shard_id) self.assertTrue('entry' in pm.sub_models) # 3) test_step3_update_service expected_loading = self.zk.expected_loading() for model_name, pm in expected_loading.items(): replicas = [] for sub_model_name, vp in pm.sub_models.items(): if sub_model_name == 'entry': server_type, 
task = 'entry', 0 else: server_type, task = sub_model_name.split('_') task = int(task) rm = ReplicaMeta( address=f'{utils.get_local_ip()}:8080', # host:port model_name=model_name, server_type=server_type, task=task, replica=self.replica_id, stat=utils.ModelState.AVAILABLE) replicas.append(rm) self.zk.update_service(replicas) # 4) test_step4_replicas_ops local_ip = utils.get_local_ip() entry_replica = ReplicaMeta(address=f'{local_ip}:8080', model_name='model', server_type='entry', task=0, replica=2, stat=30) ps0_replica = ReplicaMeta(address=f'{local_ip}:8080', model_name='model', server_type='ps', task=0, replica=2, stat=30) ps5_replica = ReplicaMeta(address=f'{local_ip}:8080', model_name='model', server_type='ps', task=5, replica=2, stat=30) all_replicas = self.zk.get_all_replicas(server_type='ps') self.assertEqual(all_replicas['model:ps:0'][0], ps0_replica) self.assertEqual(all_replicas['model:ps:5'][0], ps5_replica) model_replicas = self.zk.get_model_replicas(model_name=MODEL_NAME, server_type='entry') self.assertEqual(model_replicas['model:entry:0'][0], entry_replica) task_replicas = self.zk.get_task_replicas(model_name=MODEL_NAME, server_type='ps', task=0) self.assertEqual(task_replicas[0], ps0_replica) self.assertEqual( ps5_replica, self.zk.get_replica(model_name=MODEL_NAME, server_type='ps', task=5, replica=2)) local_replica_paths = { '/bzid/service/model/ps:0/2', '/bzid/service/model/entry:0/2', '/bzid/service/model/ps:5/2' } self.assertSetEqual(local_replica_paths, self.zk.local_replica_paths) # 5) test_step5_report_resources self.zk.report_resource(self.resource) # 6) test_step6_get_resources self.assertEqual(self.zk.resources[0], self.resource) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/base_runner.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Base class for all jobs.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf class BaseRunner(object): """Base class for all jobs.""" def __init__(self, *args, **kwargs): """Construct a new BaseRunner. Args: params: Params object containing model configuration. model_dir: String path to the log directory to output to. """ pass def run(self): raise NotImplementedError def write_summary(self, logs, summary_writer, current_step): """Write out summaries of current training step for the checkpoint.""" with tf.compat.v1.Graph().as_default(): summaries = [ tf.compat.v1.Summary.Value(tag=tag, simple_value=value) for tag, value in logs.items() ] tf_summary = tf.compat.v1.Summary(value=summaries) summary_writer.add_summary(tf_summary, current_step) ================================================ FILE: monolith/common/python/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_library") filegroup( name = "libtcmalloc", srcs = ["@gperftools//:libtcmalloc"], ) filegroup( name = "mem_profiling_internal_deps", ) py_library( name = "mem_profiling", srcs = ["mem_profiling.py"], data = [ ":libtcmalloc", ":mem_profiling_internal_deps", ], visibility = ["//visibility:public"], deps = [ "//monolith:utils", "//monolith/native_training:mlp_utils", ], ) ================================================ FILE: 
monolith/common/python/mem_profiling.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from monolith import utils from monolith.native_training.mlp_utils import MLPEnv def enable_tcmalloc(): libs = os.environ.get("LD_PRELOAD", "").split(":") libs.append( utils.get_libops_path("../gperftools/libtcmalloc/lib/libtcmalloc.so")) os.environ["LD_PRELOAD"] = ":".join(libs) def setup_heap_profile(heap_profile_inuse_interval=104857600, heap_profile_allocation_interval=1073741824, heap_profile_time_interval=0, sample_ratio=1.0, heap_profile_mmap=False, heap_pro_file=None): """See https://gperftools.github.io/gperftools/heapprofile.html for the meaning of each parameters meaning. Args: sample_ratio: ratio of new we tracked in the heap profiler. Since the full profiler is very slow, usually can be set something like 1/64. 
""" enable_tcmalloc() mlp_env = MLPEnv() os.environ["HEAPPROFILE"] = os.path.join(heap_pro_file or utils.find_main(), f"hprof_{mlp_env.index}") os.environ["HEAP_PROFILE_INUSE_INTERVAL"] = str( int(heap_profile_inuse_interval / sample_ratio)) os.environ["HEAP_PROFILE_ALLOCATION_INTERVAL"] = str( int(heap_profile_allocation_interval / sample_ratio)) os.environ["HEAP_PROFILE_SAMPLE_RATIO"] = str(sample_ratio) os.environ["HEAP_PROFILE_TIME_INTERVAL"] = str(heap_profile_time_interval) os.environ["HEAP_PROFILE_MMAP"] = str(heap_profile_mmap).lower() ================================================ FILE: monolith/core/BUILD ================================================ load("@pip_deps//:requirements.bzl", "requirement") load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") package( default_visibility = ["//visibility:public"], ) py_library( name = "base_embedding_task", srcs = ["base_embedding_task.py"], srcs_version = "PY3", deps = [ ":auto_checkpoint_feed_hook", ":base_embedding_host_call", ":base_task", ":feature", ":util", ], ) py_library( name = "base_layer", srcs = ["base_layer.py"], srcs_version = "PY3", deps = [ ":hyperparams", ":py_utils", ], ) py_library( name = "base_host_call", srcs = ["base_host_call.py"], srcs_version = "PY3", deps = [ ], ) py_library( name = "base_embedding_host_call", srcs = ["base_embedding_host_call.py"], srcs_version = "PY3", deps = [ ":base_host_call", ":tpu_variable", ], ) py_test( name = "base_embedding_host_call_test", srcs = ["base_embedding_host_call_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":base_embedding_host_call", ], ) py_library( name = "host_call", srcs = ["host_call.py"], srcs_version = "PY3", ) py_library( name = "mixed_emb_op_comb_nws", srcs = ["mixed_emb_op_comb_nws.py"], srcs_version = "PY3", ) py_test( name = "base_layer_test", srcs = ["base_layer_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":base_layer", ], ) py_library( name = 
"base_model_params", srcs = ["base_model_params.py"], srcs_version = "PY3", ) py_library( name = "base_task", srcs = ["base_task.py"], srcs_version = "PY3", deps = [ ":base_layer", ":hyperparams", ], ) py_test( name = "core_test_suite", srcs = ["core_test_suite.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":base_embedding_host_call_test", ":base_layer_test", ":hyperparams_test", ":util_test", ], ) py_library( name = "dense", srcs = ["dense.py"], srcs_version = "PY3", deps = [ ":base_layer", ":variance_scaling", ], ) py_test( name = "dense_test", srcs = ["dense_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":dense", ":testing_utils", ], ) py_library( name = "feature", srcs = ["feature.py"], srcs_version = "PY3", ) py_test( name = "hyperparams_test", srcs = ["hyperparams_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":hyperparams", ], ) py_library( name = "model", srcs = ["model.py"], srcs_version = "PY3", deps = [ ":feature", ], ) py_library( name = "hyperparams", srcs = ["hyperparams.py"], srcs_version = "PY3", ) py_library( name = "model_imports_no_params", srcs = ["model_imports.py"], srcs_version = "PY3", ) py_library( name = "model_imports", srcs_version = "PY3", deps = [ ":model_imports_no_params", ], ) py_library( name = "model_registry", srcs = ["model_registry.py"], srcs_version = "PY3", deps = [ ":base_model_params", ":model_imports_no_params", ], ) py_library( name = "optimizers", srcs = ["optimizers.py"], srcs_version = "PY3", ) py_library( name = "py_utils", srcs = ["py_utils.py"], srcs_version = "PY3", deps = [], ) py_library( name = "tpu_variable", srcs = ["tpu_variable.py"], srcs_version = "PY3", deps = [], ) py_library( name = "testing_utils", srcs = ["testing_utils.py"], srcs_version = "PY3", ) py_library( name = "util", srcs = ["util.py"], srcs_version = "PY3", deps = [ requirement("google-cloud-storage"), ], ) py_library( name = "variance_scaling", srcs = ["variance_scaling.py"], 
srcs_version = "PY3", ) py_library( name = "auto_checkpoint_feed_hook", srcs = ["auto_checkpoint_feed_hook.py"], srcs_version = "PY3", ) py_test( name = "feature_test", srcs = ["feature_test.py"], srcs_version = "PY3", deps = [ "feature", "hyperparams", ], ) py_library( name = "util_test", srcs = ["util_test.py"], srcs_version = "PY3", deps = [":util"], ) ================================================ FILE: monolith/core/__init__.py ================================================ ================================================ FILE: monolith/core/auto_checkpoint_feed_hook.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from __future__ import division from __future__ import print_function import threading import time import os from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow.compat.v1 as tf from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result from tensorflow.python.ops import summary_ops_v2 as contrib_summary _USER_PROVIDED_SIGNAL_NAME = "_user_provided_signal_name" _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' class PeriodicLogger(object): def __init__(self, seconds): self._log_every_n_seconds = seconds self._last_log_time = 0 def log(self, msg, *args, **kw): if time.time() - self._last_log_time > self._log_every_n_seconds: self._last_log_time = time.time() tf.compat.v1.logging.info(msg, *args, **kw) class _SIGNAL(object): """Signal used to control the thread of infeed/outfeed. All preserved signals must be negative numbers. Positive numbers are used to indicate the number of iterations for next training/evaluation loop. 
""" NEXT_BATCH = -1 STOP = -2 class _OpQueueContext(object): """Manages work queue and thread for a infeed/outfeed thread.""" def __init__(self, name, target, args): self._name = name self._queue = Queue.Queue() args = (self,) + args self._thread = threading.Thread(name=name, target=target, args=args) self._thread.daemon = True self._thread.start() def stop(self): self._queue.put(_SIGNAL.STOP) def send_next_batch_signal(self, iterations): self._queue.put(iterations) def read_iteration_counts(self): while True: iterations = self._queue.get(block=True) tf.compat.v1.logging.debug('%s read iterations %s', self._name, iterations) if iterations == _SIGNAL.STOP: tf.compat.v1.logging.info('%s received shutdown signal, stopping.', self._name) return yield iterations def join(self): tf.compat.v1.logging.info('Shutting down %s thread.', self._name) self.stop() self._thread.join() class _OpSignalOnceQueueContext(_OpQueueContext): """Manages work queue and thread for a infeed/outfeed thread. This subclass only signals once. """ def __init__(self, name, target, args): super(_OpSignalOnceQueueContext, self).__init__(name, target, args) self._has_signaled = False def send_next_batch_signal(self, iterations): if not self._has_signaled: self._queue.put(iterations) self._has_signaled = True class TPUInfeedOutfeedSessionWithEndOfStreamHandlingHook( tf.estimator.SessionRunHook): """A Session hook setting up the TPU initialization, infeed, and outfeed. This hook does two major things: 1. initialize and shutdown TPU system. 2. launch and join the threads for infeed enqueue and (optional) outfeed dequeue. 
""" def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, run_infeed_loop_on_coordinator=True, rendezvous=None, master=None, session_config=None, tpu_init_ops=None, outfeed_every_n_steps=1): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops self._rendezvous = rendezvous self._master = master self._session_config = session_config self._init_ops = list(tpu_init_ops or []) if ctx.embedding_config is None: self._embedding_layer_config = None else: self._embedding_layer_config = ( ctx.embedding_config.tpu_embedding.config_proto) self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) self._tpu_compile_op = tpu_compile_op # When using model parallelism, the TPU is pre-initialized at startup to # fetch mesh information. We skip re-initializing it here for # MeshTensorFlow since it places variables on TPU directly. Reinitialize tpu # is causing the variable corruption since the previous allocated memory # might be overwritten for other purpose. if (ctx.model_parallelism_enabled and (ctx.config.tpu_config.per_host_input_for_training is tpu_config.InputPipelineConfig.BROADCAST)): self._should_initialize_tpu = False else: self._should_initialize_tpu = True self._outfeed_every_n_steps = outfeed_every_n_steps self.stopping_signal = False def _create_or_get_iterations_per_loop(self): """Creates or gets the iterations_per_loop variable. In TPUEstimator, the user provided computation, the model_fn, is wrapped inside a tf.while_loop for peak performance. The iterations of the loop are specified by this variable, which adjusts its value on the CPU after each TPU program execution and before the next TPU execution. The purpose of using a variable, rather then a constant, is to allow TPUEstimator adapt the TPU training iterations according to the final steps specified by users. 
For example, if the user sets the iterations_per_loop as 4 in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop variable will have the following value before each TPU training. - 1-th TPU execution: iterations_per_loop = 4 - 2-th TPU execution: iterations_per_loop = 4 - 3-th TPU execution: iterations_per_loop = 2 As model_fn increases the global step once per train_op invocation, the global step is 10 after all TPU executions, matching the steps=10 inputs passed in by users. Returns: A TF non-trainable resource variable. Raises: RuntimeError: If multi iterations_per_loop variables were found. """ graph = tf.compat.v1.get_default_graph() collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) iter_vars = graph.get_collection(collection_name) if len(iter_vars) == 1: return iter_vars[0] elif len(iter_vars) > 1: raise RuntimeError('Multiple iterations_per_loop_var in collection.') with ops.colocate_with(tf.compat.v1.train.get_global_step()): with tf.compat.v1.variable_scope(_TPU_ESTIMATOR, reuse=tf.compat.v1.AUTO_REUSE): return tf.compat.v1.get_variable( _ITERATIONS_PER_LOOP_VAR, initializer=tf.compat.v1.initializers.zeros(), shape=[], dtype=tf.dtypes.int32, trainable=False, collections=[ collection_name, tf.compat.v1.GraphKeys.LOCAL_VARIABLES ], use_resource=True) def begin(self): tf.compat.v1.logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = self._create_or_get_iterations_per_loop() if self._should_initialize_tpu: self._finalize_ops = [ tf.compat.v1.tpu.shutdown_system(job=self._master_job) ] else: self._finalize_ops = [] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() self._init_ops.extend(summary_writer_init_ops) # Get all the writer resources from the initializer, so we know what to # flush. 
for op in summary_writer_init_ops: self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) def _run_infeed(self, queue_ctx, session): tf.compat.v1.logging.info('Starting infeed thread controller.') if self._initial_infeed_sleep_secs: tf.compat.v1.logging.info('Infeed thread sleeping for %d seconds.', self._initial_infeed_sleep_secs) time.sleep(self._initial_infeed_sleep_secs) tf.compat.v1.logging.info('Infeed thread starting after sleep') with self._rendezvous.catch_errors(source='infeed', session=session): if self._run_infeed_loop_on_coordinator: for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): tf.compat.v1.logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) session.run(self._enqueue_ops) else: for _ in queue_ctx.read_iteration_counts(): session.run(self._enqueue_ops) tf.compat.v1.logging.info('Infeed thread finished, shutting down.') def _run_outfeed(self, queue_ctx, session): tf.compat.v1.logging.info('Starting outfeed thread controller.') status_logger = PeriodicLogger(seconds=60) with self._rendezvous.catch_errors(source='outfeed', session=session): stopping_signals = False for count, steps in enumerate(queue_ctx.read_iteration_counts()): step_counter = 0 for i in xrange(steps): tf.compat.v1.logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) if step_counter % self._outfeed_every_n_steps == 0: ret = session.run(self._dequeue_ops) if _USER_PROVIDED_SIGNAL_NAME in ret: if 'stopping' not in ret[_USER_PROVIDED_SIGNAL_NAME]: raise RuntimeError('ret[{}] must contain key \'stopping\'.' 
).format(_USER_PROVIDED_SIGNAL_NAME) if ret[_USER_PROVIDED_SIGNAL_NAME]['stopping'][0] == True \ and stopping_signals == False: stopping_signals = True tf.compat.v1.logging.info( 'Encountered stop signal at iteration (%d, %d).', count, i) step_counter += 1 status_logger.log('Outfeed finished for iteration (%d, %d)', count, i) if stopping_signals == True: tf.compat.v1.logging.info( 'Set shared stop signal at iteration (%d, %d).', count, i) self.stopping_signal = True tf.compat.v1.logging.info('Outfeed thread finished, shutting down.') def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) def _assertCompilationSucceeded(self, result, coord): proto = tpu_compilation_result.CompilationResultProto() proto.ParseFromString(result) if proto.status_error_message: tf.compat.v1.logging.error('Compilation failed: {}'.format( proto.status_error_message)) coord.request_stop() else: tf.compat.v1.logging.info('Compilation succeeded') def after_create_session(self, session, coord): if self._should_initialize_tpu: tf.compat.v1.logging.info('Init TPU system') start = time.time() with tf.Graph().as_default(): with tf.compat.v1.Session(self._master, config=self._session_config) as sess: sess.run( tf.compat.v1.tpu.initialize_system( job=self._master_job, embedding_config=self._embedding_layer_config)) tf.compat.v1.logging.info('Initialized TPU in %d seconds', time.time() - start) session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=30 * 60 * 1000)) if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1': tf.compat.v1.logging.info( 'Compiling user program: this may take a while...') self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord) self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) self._outfeed_controller = _OpQueueContext(name='OutfeedController', target=self._run_outfeed, args=(session,)) # Enable the 
worker watchdog to terminate workers on coordinator exit. watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0')) if watchdog_timeout > 0: session_support.start_worker_watchdog(session, shutdown_timeout=watchdog_timeout) def before_run(self, run_context): if self.stopping_signal == True: tf.compat.v1.logging.info( 'Throw OutOfRangeError error due to encountering stopping signal in before_run.' ) raise tf.errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') iterations = run_context.session.run(self._iterations_per_loop_var) tf.compat.v1.logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) self._infeed_controller.send_next_batch_signal(iterations) tf.compat.v1.logging.info( 'Dequeue next (%d) batch(es) of data from outfeed.', iterations) self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): tf.compat.v1.logging.info('Stop infeed thread controller') self._infeed_controller.join() self._rendezvous.record_done('infeed') tf.compat.v1.logging.info('Stop output thread controller') self._outfeed_controller.join() self._rendezvous.record_done('outfeed') tf.compat.v1.logging.info('Shutdown TPU system.') session.run(self._finalize_ops) @staticmethod def get_stopping_signals_and_name(features): stopping_signals = None if _USER_PROVIDED_SIGNAL_NAME in features: tf.compat.v1.logging.info("Get stopping signals and name.") sum_stopping_signals = tf.compat.v1.tpu.cross_replica_sum( tf.cast(features[_USER_PROVIDED_SIGNAL_NAME], tf.int32)) stopping_signals = {'stopping': sum_stopping_signals > 0} return stopping_signals, _USER_PROVIDED_SIGNAL_NAME ================================================ FILE: monolith/core/base_embedding_host_call.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys from absl import logging import tensorflow.compat.v1 as tf import tensorflow as tf2 from monolith.core.base_host_call import BaseHostCall from monolith.core.tpu_variable import ReplicatedVariable _LABLES_FOR_AUC_CALCULATION = "labels_for_auc_calculation" _Y_PRED_FOR_AUC_CALCULATION = "y_pred_for_auc_calculation" _REQ_TIME = "req_time" _SAMPLE_RATE = "sample_rate" _UID = "uid" _UID_BUCKET = 'uid_bucket' _DEEPINSIGHT_SAMPLE_RATES = "di_example_sample_rates" _DEEPINSIGHT_LABELS = "di_labels" _DEEPINSIGHT_PREDS = "di_preds" _DEEPINSIGHT_REQ_TIMES = "di_req_times" _RATIO_N = 1000 _UID_SAMPLE_RATE = 0.01 _HOST_CALL_AUC_METRICS = set([ _LABLES_FOR_AUC_CALCULATION, _Y_PRED_FOR_AUC_CALCULATION, _SAMPLE_RATE, _REQ_TIME, _UID_BUCKET ]) # TPU variables names _LABELS_TPU_VARIABLE = "labels_tpu_variable" _PREDS_TPU_VARIABLE = "preds_tpu_variable" _UID_BUCKETS_TPU_VARIABLE = "uid_buckets_tpu_variable" _REQ_TIMES_TPU_VARIABLE = "req_times_tpu_variable" _SAMPLE_RATES_TPU_VARIABLE = "sample_rates_tpu_variable" _ACCUMULATED_COUNTER_TPU_VARIABLE = "tpu_variables_accumulated_times" _DEPRECATED_METRIC_NAMES = [ _LABLES_FOR_AUC_CALCULATION, _Y_PRED_FOR_AUC_CALCULATION, _UID_BUCKET, _REQ_TIME, _SAMPLE_RATE ] class TPUVariableRestoreHook(tf.estimator.SessionRunHook): """Initialize variables on TPU devices.""" def __init__(self, op): self._op = op def after_create_session(self, session, coord): logging.info("Initialize variables on TPU devices.") session.run(self._op) class BaseEmbeddingHostCall(BaseHostCall): def __init__(self, output_dir, enable_host_call, enable_deepinsight, 
enable_host_call_scalar_metrics, enable_caching_with_tpu_var_mode, top_k_sampling_num_per_core, params): super(BaseEmbeddingHostCall, self).__init__(output_dir, enable_host_call) self._enable_host_call = params["enable_host_call"] self._enable_deepinsight = enable_deepinsight self._enable_host_call_scalar_metrics = enable_host_call_scalar_metrics self._enable_caching_with_tpu_var_mode = enable_caching_with_tpu_var_mode self._top_k_sampling_num_per_core = top_k_sampling_num_per_core if params["cpu_test"] is True: self._context = None else: self._context = params["context"] self._host_call_every_n_steps = params["host_call_every_n_steps"] # Each TPU core uses these tpu variables to record metrics. # labels tpu variable, shape is (topk_num * host_call_steps, ) # preds tpu variable, shape is (topk_num * host_call_steps, ) # uid_buckets tpu variable, shape is (topk_num * host_call_steps, ) # req_times tpu variable, shape is (host_call_steps, ) # sample_rates tpu variable, shape is (host_call_steps, ) self._labels_tpu_variable = None self._preds_tpu_variable = None self._uid_buckets_tpu_variable = None self._req_times_tpu_variable = None self._sample_rates_tpu_variable = None # Counter of accumulating times for next host call. self._accumulated_counter_tpu_variable = None self.tpu_var_restore_hooks = [] # Create TPU variables. self._create_all_tpu_variables() # Use TPU variables to reach each step's metrics and process all metrics # accumulated in each host call for deepinsight usage. def _create_all_tpu_variables(self): if self._enable_host_call is False: logging.info("enable_host_call is False, do not create tpu variables.") return if self._enable_caching_with_tpu_var_mode is False: logging.info( "enable_caching_with_tpu_var_mode is False, do not create tpu variables." ) return assert self._host_call_every_n_steps > 1, "If tpu variables caching is enabled, we need host_call_every_n_steps bigger than 1." 
logging.info("Create all tpu variables.") # Create TPU variables for metrics. max_accumulated_samples_per_host_call = self._top_k_sampling_num_per_core * self._host_call_every_n_steps self._labels_tpu_variable = self._create_tpu_var( _LABELS_TPU_VARIABLE, [max_accumulated_samples_per_host_call], tf.float32) self._preds_tpu_variable = self._create_tpu_var( _PREDS_TPU_VARIABLE, [max_accumulated_samples_per_host_call], tf.float32) self._uid_buckets_tpu_variable = self._create_tpu_var( _UID_BUCKETS_TPU_VARIABLE, [max_accumulated_samples_per_host_call], tf.int32) self._req_times_tpu_variable = self._create_tpu_var( _REQ_TIMES_TPU_VARIABLE, [self._host_call_every_n_steps], tf.int64) self._sample_rates_tpu_variable = self._create_tpu_var( _SAMPLE_RATES_TPU_VARIABLE, [self._host_call_every_n_steps], tf.float32) # Create meta to maintain TPU varaibles. self._accumulated_counter_tpu_variable = self._create_tpu_var( _ACCUMULATED_COUNTER_TPU_VARIABLE, [], tf.int32) def _create_tpu_var(self, var_name, var_shape, var_type): ctx = self._context master = ctx._internal_ctx.master_job job_device = '' if master is None else ('/job:%s' % master) slices = [] tpu_host_placement_fn = ctx.tpu_host_placement_function with tf.control_dependencies(None): assign_ops = [] for h in range(ctx.num_hosts): with tf.device(tpu_host_placement_fn(h)): zero_tensor = tf.zeros(shape=var_shape, dtype=var_type) for d in range(ctx.num_of_replicas_per_host): with tf.device('%s/task:%d/device:TPU:%d' % (job_device, h, d)): slice_var = tf.Variable(initial_value=zero_tensor, trainable=False, name="slice_{}_{}_{}".format( var_name, h, d), dtype=var_type, expected_shape=var_shape, collections=["TPU_VAR"]) slices.append(slice_var) assign_ops.append(tf.assign(slice_var, zero_tensor)) tpu_var = ReplicatedVariable(var_name, slices) group_assign_op = tf.group(assign_ops) tpu_var_restore_hook = TPUVariableRestoreHook(group_assign_op) self.tpu_var_restore_hooks.append(tpu_var_restore_hook) logging.info("Created TPU 
variable, name: {}, shape: {}, type: {}".format( var_name, var_shape, var_type)) return tpu_var def _compute_new_value(self, base_tpu_var, delta_value, update_offset): # Need assert shape.rank is 1 for both base_tpu_var and delta_value. assert base_tpu_var.get_shape().rank == 1, \ "base_tpu_var's rank must be 1, base_tpu_var shape: {}".format(base_tpu_var.get_shape()) assert delta_value.get_shape().rank == 1, \ "delta_value's rank must be 1, delta_value shape: {}".format(delta_value.get_shape()) assert base_tpu_var.dtype == delta_value.dtype, "base_tpu_var dtype: {} must be same as delta_value dtype: {}" \ .format(base_tpu_var.dtype, delta_value.dtype) # Padding in the end of delta_value so that it has same shape with base_tpu_var. # And then right shift delta_value so that its valid data starts from update_offset. # Note: Here we can't pad directly to delta_value in the begining and end. Because the pad position # depends on update_offset which is a tensor. And that will make the pad op encounter a compilation error # as following: # Compilation failure: Input 1 to node `Pad` with op Pad must be a compile-time constant. # XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to # concrete values at compile time. This error means that a shape or dimension argument could not be # evaluated at compile time, usually because the value of the argument depends on a parameter to the # computation, on a variable, or on a stateful operation such as a random number generator. base_len = base_tpu_var.shape.as_list()[0] delta_len = delta_value.shape.as_list()[0] paddings = [[0, base_len - delta_len]] delta_value = tf.pad(delta_value, paddings, 'CONSTANT', constant_values=0) # Right shift delta_tpu_var according to update_offset in base_tpu_var. 
delta_value = tf.roll(delta_value, shift=update_offset, axis=0) return base_tpu_var + delta_value def _clear_value_at_index_1(self, tpu_var, var_type, index): return tf.where(tf.math.equal(index, 1), tf.zeros_like(tpu_var, dtype=var_type), tpu_var) def update_tpu_variables_ops(self, global_step, labels, preds, uid_buckets, req_times, sample_rates): if self._enable_host_call is False: logging.info("enable_host_call is False, do not update tpu variables.") return [] if self._enable_caching_with_tpu_var_mode is False: logging.info( "enable_caching_with_tpu_var_mode is False, do not update tpu variables." ) return [] logging.info("Update tpu variables.") assert labels is not None assert preds is not None assert uid_buckets is not None expected_shape = [self._top_k_sampling_num_per_core] assert labels.get_shape() == expected_shape, \ "Expect shape: {}, but shape is {}".format(expected_shape, labels.get_shape()) assert preds.get_shape() == expected_shape, \ "Expect shape: {}, but shape is {}".format(expected_shape, preds.get_shape()) assert uid_buckets.get_shape() == expected_shape, \ "Expect shape: {}, but shape is {}".format(expected_shape, uid_bkcets.get_shape()) # We do host call host_call_every_n_steps steps. At step 0, host_call_every_n_steps, # 2 * host_call_every_n_steps, ..., we need a completed data until this step, and we do a host call # to dump those data. At step 1, host_call_every_n_steps + 1, 2 * host_call_every_n_steps + 1, # We will need clear any accumulated data firstly and then start accumulating new data again. # We use index = tf.math.floormod(global_step, host_call_every_n_steps) to represent the step index where we are. # If index is 1, we will clear everything with TPU variables before accumulating new data. 
index = tf.math.floormod(global_step, self._host_call_every_n_steps) old_accumulated_counter_value = self._clear_value_at_index_1( self._accumulated_counter_tpu_variable, tf.int32, index) old_labels_value = self._clear_value_at_index_1(self._labels_tpu_variable, tf.float32, index) old_preds_value = self._clear_value_at_index_1(self._preds_tpu_variable, tf.float32, index) old_uid_buckets_value = self._clear_value_at_index_1( self._uid_buckets_tpu_variable, tf.int32, index) old_req_times_tpu_variable = self._clear_value_at_index_1( self._req_times_tpu_variable, tf.int64, index) old_sample_rates_value = self._clear_value_at_index_1( self._sample_rates_tpu_variable, tf.float32, index) # Update labels, preds, and uid_buckets which have self._top_k_sampling_num_per_core elements in the last dimension. tpu_var_offset = old_accumulated_counter_value * self._top_k_sampling_num_per_core new_labels_value = self._compute_new_value(old_labels_value, labels, tpu_var_offset) new_preds_value = self._compute_new_value(old_preds_value, preds, tpu_var_offset) new_uid_buckets_tpu_value = self._compute_new_value(old_uid_buckets_value, uid_buckets, tpu_var_offset) new_req_times_value = self._compute_new_value( old_req_times_tpu_variable, req_times, old_accumulated_counter_value) new_sample_rates_value = self._compute_new_value( old_sample_rates_value, sample_rates, old_accumulated_counter_value) # Increment tpu variable counter. new_accumulated_counter_value = tf.math.add(old_accumulated_counter_value, 1) # Update tpu variables. update_tpu_var_ops = [ tf.assign(self._labels_tpu_variable, new_labels_value), tf.assign(self._preds_tpu_variable, new_preds_value), tf.assign(self._uid_buckets_tpu_variable, new_uid_buckets_tpu_value), tf.assign(self._req_times_tpu_variable, new_req_times_value), tf.assign(self._sample_rates_tpu_variable, new_sample_rates_value), ] # Update tpu variable counter should only happen after updating tpu variables. 
with tf.control_dependencies(update_tpu_var_ops): update_tpu_var_counter_op = tf.assign( self._accumulated_counter_tpu_variable, new_accumulated_counter_value) return [update_tpu_var_counter_op] def record_summary_tpu_variables(self): if self._enable_host_call is False: logging.info( "enable_host_call is False, do not record summary tpu variables.") return if self._enable_caching_with_tpu_var_mode is False: logging.info( "enable_caching_with_tpu_var_mode is False, record summary tpu variables." ) return logging.info("Record tpu variables.") self.record_summary_tensor(_LABELS_TPU_VARIABLE, self._labels_tpu_variable.read_value()) self.record_summary_tensor(_PREDS_TPU_VARIABLE, self._preds_tpu_variable.read_value()) self.record_summary_tensor(_UID_BUCKETS_TPU_VARIABLE, self._uid_buckets_tpu_variable.read_value()) self.record_summary_tensor(_REQ_TIMES_TPU_VARIABLE, self._req_times_tpu_variable.read_value()) self.record_summary_tensor(_SAMPLE_RATES_TPU_VARIABLE, self._sample_rates_tpu_variable.read_value()) self.record_summary_tensor( _ACCUMULATED_COUNTER_TPU_VARIABLE, self._accumulated_counter_tpu_variable.read_value()) def record_summary_tensor(self, name, tensor): if self._enable_host_call_scalar_metrics is False and name not in _HOST_CALL_AUC_METRICS: return if self._enable_caching_with_tpu_var_mode is True and name in _DEPRECATED_METRIC_NAMES: return super(BaseEmbeddingHostCall, self).record_summary_tensor(name, tensor) def _verify_shape_and_dtype(self, tensor, shape_list, dtype): assert tensor is not None assert tensor.shape.as_list( ) == shape_list, "Expect shape: {}, but actual shape: {}".format( shape_list, tensor.shape.as_list()) assert tensor.dtype == dtype, "Expect dtype {}, but actual dtype: {}".format( dtype, tensor.dtype) def _slice_tensor(self, tensor, indices, expect_shape, expect_dtype): """Select elements from a given tensor using given indices. Args: tensor: The Tensor whose elements are selected using the indices. 
indices: The Tensor storing the indices to be sliced. expect_shape: The expected shape of tensor. expect_dtype: The expected dtype of tensor. Return: The sliced tensor. """ self._verify_shape_and_dtype(tensor, expect_shape, expect_dtype) # Flatten the tensor here and simplify its data format using reshape, # which is low cost without real data copy. # Each tensor has shape (n, ), n equals to core_number * batch_size_per_core tensor = tf.reshape(tensor, [-1]) sliced_tensor = tf.gather(tensor, indices) return sliced_tensor def _serialize_tensor(self, sampled_tensor, gs, message_name): tf2.summary.text(message_name, data=tf.io.serialize_tensor(sampled_tensor), step=gs) # This function is deprecated for host_call.py. def _serialize_messages(self, labels, y_preds, sample_rates, req_times, uid_buckets, gs): assert labels is not None assert y_preds is not None assert sample_rates is not None assert req_times is not None assert uid_buckets is not None # For sample_rates and req_times, we only need to keep the first one. 
expect_shape = sample_rates.shape.as_list() assert len( expect_shape ) == 1, "Expect sample_rates shape rank to be 1, but its shape is {}".format( expect_shape) self._serialize_tensor(sample_rates, [0], expect_shape, tf.float32, gs, _DEEPINSIGHT_SAMPLE_RATES) self._serialize_tensor(req_times, [0], expect_shape, tf.int64, gs, _DEEPINSIGHT_REQ_TIMES) def _write_summary_ops(self, gs, labels, y_preds, uid_buckets, req_times, sample_rates, stopping_signals_sum=None): if labels is not None and y_preds is not None: # Filter labels and y_preds by uids to ensure that only a fraction of # uids are selected for AUC calculation if uid_buckets is not None: # Filter out data with uid_bucket < _UID_SAMPLE_RATE * _RATIO_N to write to summary file reshaped_uid_buckets = tf.reshape(uid_buckets, [-1]) if stopping_signals_sum is None: indices = tf.squeeze( tf.where(reshaped_uid_buckets < int(_UID_SAMPLE_RATE * _RATIO_N))) else: indices = tf.squeeze( tf.where( tf.math.logical_and( reshaped_uid_buckets < int(_UID_SAMPLE_RATE * _RATIO_N), tf.math.equal(stopping_signals_sum, 0)))) expect_shape = labels.shape.as_list() labels = self._slice_tensor(labels, indices, expect_shape, tf.float32) y_preds = self._slice_tensor(y_preds, indices, expect_shape, tf.float32) self._serialize_tensor(labels, gs, _DEEPINSIGHT_LABELS) self._serialize_tensor(y_preds, gs, _DEEPINSIGHT_PREDS) if self._enable_deepinsight is True and req_times is not None: assert req_times.get_shape().rank == 2, "req_times shape: {}".format( req_times.get_shape()) # Repeat each element self._top_k_sampling_num_per_core times in dim(1). 
req_times = tf.repeat(req_times, [self._top_k_sampling_num_per_core], axis=1) req_times = self._slice_tensor(req_times, indices, expect_shape, tf.int64) self._serialize_tensor(req_times, gs, _DEEPINSIGHT_REQ_TIMES) # Calculate AUC based on filtered labels and y_preds auc, auc_op = tf.metrics.auc(labels=labels, predictions=y_preds) tf2.summary.scalar("auc", data=auc, step=gs) # Serialize message if self._enable_deepinsight is True: if sample_rates is not None: expect_shape = sample_rates.shape.as_list() assert len( expect_shape ) == 1, "Expect sample_rates shape rank to be 1, but its shape is {}".format( expect_shape) # For sample_rates and req_times, we only need to keep the first element. sampled_sample_rates = self._slice_tensor(sample_rates, [0], expect_shape, tf.float32) self._serialize_tensor(sampled_sample_rates, gs, _DEEPINSIGHT_SAMPLE_RATES) else: auc_op = None tf2.summary.scalar("sampled_labels_variable_avg", data=tf.reduce_mean(labels), step=gs) tf2.summary.scalar("sampled_preds_variable_avg", data=tf.reduce_mean(y_preds), step=gs) tf2.summary.scalar("req_times_variable_avg", data=tf.reduce_mean(req_times), step=gs) tf2.summary.scalar("sample_rates_variable_avg", data=tf.reduce_mean(sample_rates), step=gs) return auc_op def generate_host_call_hook(self): def _host_call(*args): gs, tensors = self.decompress_tensors(args) summary_writer = tf2.summary.create_file_writer(self._output_dir + "/host_call", flush_millis=10000, max_queue=5000) with summary_writer.as_default(): labels = None y_preds = None req_times = None sample_rates = None uid_buckets = None for i, t in enumerate(tensors): if i == 0: continue name = self._tensor_names[i] data = None if "_avg" in name: data = tf.reduce_mean(t) elif "_max" in name: data = tf.reduce_max(t) elif "_sum" in name: data = tf.reduce_sum(t) elif _LABLES_FOR_AUC_CALCULATION in name: labels = t elif _Y_PRED_FOR_AUC_CALCULATION in name: y_preds = t elif _REQ_TIME in name: req_times = tf.expand_dims(t, -1) elif _SAMPLE_RATE 
def generate_host_call_hook(self):
  """Builds the (host_call_fn, tensors) pair for this model.

  Returns:
    None when host call is disabled (self._enable_host_call is not True).
    Otherwise a tuple `(fn, self._tensors)` where `self._tensors` has been
    compressed by `self.compress_tensors()` and `fn` is `_host_call`, or
    `_host_call_with_tpu` when self._enable_caching_with_tpu_var_mode is set.

  NOTE(review): block nesting was reconstructed from a copy of this file whose
  indentation was lost; confirm against the original before relying on the
  exact extent of the `with summary_writer.as_default()` blocks.
  """

  def _host_call(*args):
    # `args` are the compressed tensors; decompress_tensors returns the step
    # value together with the list of per-name tensors.
    gs, tensors = self.decompress_tensors(args)
    summary_writer = tf2.summary.create_file_writer(self._output_dir +
                                                    "/host_call",
                                                    flush_millis=10000,
                                                    max_queue=5000)
    with summary_writer.as_default():
      labels = None
      y_preds = None
      req_times = None
      sample_rates = None
      uid_buckets = None
      for i, t in enumerate(tensors):
        # Index 0 is skipped: it is not written as a summary itself.
        if i == 0:
          continue
        name = self._tensor_names[i]
        data = None
        # Dispatch on substrings of the tensor name: reductions become scalar
        # summaries, AUC inputs are captured for _write_summary_ops, anything
        # else is reported via its first element.
        if "_avg" in name:
          data = tf.reduce_mean(t)
        elif "_max" in name:
          data = tf.reduce_max(t)
        elif "_sum" in name:
          data = tf.reduce_sum(t)
        elif _LABLES_FOR_AUC_CALCULATION in name:
          labels = t
        elif _Y_PRED_FOR_AUC_CALCULATION in name:
          y_preds = t
        elif _REQ_TIME in name:
          # _write_summary_ops asserts req_times is rank 2, so add a trailing axis.
          req_times = tf.expand_dims(t, -1)
        elif _SAMPLE_RATE in name:
          sample_rates = t
        elif _UID_BUCKET in name:
          uid_buckets = t
        else:
          data = t[0]
        if data is not None:
          tf2.summary.scalar(name, data=data, step=gs)
      auc_op = self._write_summary_ops(gs, labels, y_preds, uid_buckets,
                                       req_times, sample_rates)
      # Group the AUC update op with the summary ops so both get executed.
      if auc_op is not None:
        return tf.group(tf.compat.v1.summary.all_v2_summary_ops(), auc_op)
      else:
        return tf.compat.v1.summary.all_v2_summary_ops()

  def get_used_slice(tpu_variable, used_elements_count):
    # Keep every row and only the first `used_elements_count` columns.
    return tf.slice(tpu_variable, [0, 0], [-1, used_elements_count])

  def _host_call_with_tpu(*args):
    gs, tensors = self.decompress_tensors(args)
    summary_writer = tf2.summary.create_file_writer(self._output_dir +
                                                    "/host_call",
                                                    flush_millis=10000,
                                                    max_queue=5000)
    stopping_signals_sum = None
    auc_op = None
    with summary_writer.as_default():
      labels_value = None
      preds_value = None
      uid_buckets_value = None
      req_times_value = None
      sample_rates_value = None
      accumulated_counter_value = None
      for i, t in enumerate(tensors):
        # Index 0 is skipped: it is not written as a summary itself.
        if i == 0:
          continue
        name = self._tensor_names[i]
        data = None
        if "_avg" in name:
          data = tf.reduce_mean(t)
        elif "_max" in name:
          data = tf.reduce_max(t)
        elif "_sum" in name:
          data = tf.reduce_sum(t)
        elif _LABELS_TPU_VARIABLE in name:
          labels_value = t
        elif _PREDS_TPU_VARIABLE in name:
          preds_value = t
        elif _UID_BUCKETS_TPU_VARIABLE in name:
          uid_buckets_value = t
        elif _REQ_TIMES_TPU_VARIABLE in name:
          req_times_value = t
        elif _SAMPLE_RATES_TPU_VARIABLE in name:
          sample_rates_value = t
        elif _ACCUMULATED_COUNTER_TPU_VARIABLE in name:
          accumulated_counter_value = t
        elif "stopping_signals" in name:
          stopping_signals_sum = tf.reduce_sum(tf.cast(t, tf.int32))
        elif name not in _DEPRECATED_METRIC_NAMES:
          data = t[0]
        if data is not None:
          tf2.summary.scalar(name, data=data, step=gs)
      # Check labels, preds, uid_buckets shape is as expected.
      expected_multiple_values_per_step_shape = [
          self._context.num_replicas,
          self._host_call_every_n_steps * self._top_k_sampling_num_per_core
      ]
      assert labels_value.get_shape() == expected_multiple_values_per_step_shape, \
        "labels_tpu_variable shape: {}, expectd_shape: {}." \
        .format(labels_value.get_shape(), expected_multiple_values_per_step_shape)
      assert preds_value.get_shape() == expected_multiple_values_per_step_shape, \
        "preds_tpu_variable shape: {}, expectd_shape: {}." \
        .format(preds_value.get_shape(), expected_multiple_values_per_step_shape)
      assert uid_buckets_value.get_shape() == expected_multiple_values_per_step_shape, \
        "uid_buckets_tpu_variable shape: {}, expectd_shape: {}." \
        .format(uid_buckets_value.get_shape(), expected_multiple_values_per_step_shape)
      # Check tpu_variable_accumulated_times_scalar is as expected.
      expected_scalar_shape = [self._context.num_replicas]
      assert accumulated_counter_value.get_shape()== expected_scalar_shape, \
        "tpu_variable_accumulated_times_scalar shape: {}, expectd_shape: {}." \
        .format(accumulated_counter_value.get_shape(), expected_scalar_shape)
      # Only replica 0's counter is read to size the slices below.
      used_slice_len = accumulated_counter_value[
          0] * self._top_k_sampling_num_per_core
      # Get the used parts of TPU variable. Then reshape all tpu variables to be similar shape with same rank and same batch dimension
      # as other non-tpu variables.
      labels = get_used_slice(labels_value, used_slice_len)
      y_preds = get_used_slice(preds_value, used_slice_len)
      uid_buckets = get_used_slice(uid_buckets_value, used_slice_len)
      # Check req_times, sample_rates shape is as expected.
      expected_single_value_per_step_shape = [
          self._context.num_replicas, self._host_call_every_n_steps
      ]
      assert req_times_value.get_shape()== expected_single_value_per_step_shape, \
        "req_times_tpu_variable shape: {}, expectd_shape: {}." \
        .format(req_times_value.get_shape(), expected_single_value_per_step_shape)
      # Get the used part of the req_times TPU variable.
      req_times = get_used_slice(req_times_value, accumulated_counter_value[0])
      assert sample_rates_value.get_shape()== expected_single_value_per_step_shape, \
        "sample_rates_tpu_variable shape: {}, expectd_shape: {}." \
        .format(sample_rates_value.get_shape(), expected_single_value_per_step_shape)
      # Attention: for performance we use one sample_rate per replica (column 0)
      # to represent all examples in this host_call. We will evaluate from the
      # deepinsight showing side to see if this has big impact when users use it.
      sample_rates = tf.squeeze(tf.slice(sample_rates_value, [0, 0], [-1, 1]))
      tf2.summary.scalar("uid_buckets_tpu_variable_avg",
                         data=tf.reduce_mean(uid_buckets),
                         step=gs)
      tf2.summary.scalar("accumulated_times_tpu_variable_avg",
                         data=tf.reduce_mean(accumulated_counter_value),
                         step=gs)
      # NOTE(review): the summary is named "..._avg" but the value written is
      # tf.reduce_min — confirm which one is intended.
      tf2.summary.scalar("used_slice_len_avg",
                         data=tf.reduce_min(used_slice_len),
                         step=gs)
      if stopping_signals_sum is not None:
        tf2.summary.scalar("stopping_signals_sum",
                           data=stopping_signals_sum,
                           step=gs)
      auc_op = self._write_summary_ops(gs, labels, y_preds, uid_buckets,
                                       req_times, sample_rates,
                                       stopping_signals_sum)
    # Group the AUC update op with the summary ops so both get executed.
    if auc_op is not None:
      return tf.group(tf.compat.v1.summary.all_v2_summary_ops(), auc_op)
    else:
      return tf.compat.v1.summary.all_v2_summary_ops()

  if self._enable_host_call == True:
    # Tensors must be compressed before they are handed out, because both
    # closures start by calling self.decompress_tensors on their arguments.
    self.compress_tensors()
    if self._enable_caching_with_tpu_var_mode is False:
      return (_host_call, self._tensors)
    else:
      return (_host_call_with_tpu, self._tensors)
  else:
    logging.info("host_call has been disabled")
    return None
class BaseEmbeddingHostCallTest(unittest.TestCase):
  """Graph-mode checks for BaseEmbeddingHostCall._compute_new_value."""

  def _all_equal(self, actual, expected_values):
    """Evaluates `actual` in a fresh session and compares it element-wise."""
    expected = tf.constant(expected_values, dtype=tf.int32)
    matches = tf.reduce_all(tf.math.equal(actual, expected))
    with tf.Session() as sess:
      return sess.run(matches)

  def test_compute_new_value(self):
    # The host call constructor needs a global step to exist in the graph.
    tf.train.get_or_create_global_step()
    host_call_params = {
        "enable_host_call": False,
        "context": None,
        "cpu_test": False,
        "host_call_every_n_steps": 100
    }
    host_call = base_embedding_host_call.BaseEmbeddingHostCall(
        "", False, False, False, False, 10, host_call_params)

    accumulated = tf.zeros([10], dtype=tf.int32)
    delta = tf.ones([2], dtype=tf.int32)

    # Each case applies `delta` on top of the previous result at a new offset;
    # the third offset overlaps the second, so one position reaches 2.
    cases = [
        (lambda: tf.constant(1, dtype=tf.int32),
         [0, 1, 1, 0, 0, 0, 0, 0, 0, 0]),
        (lambda: tf.constant(5), [0, 1, 1, 0, 0, 1, 1, 0, 0, 0]),
        (lambda: tf.constant(6), [0, 1, 1, 0, 0, 1, 2, 1, 0, 0]),
    ]
    for make_offset, expected_values in cases:
      accumulated = host_call._compute_new_value(accumulated, delta,
                                                 make_offset())
      self.assertTrue(self._all_equal(accumulated, expected_values))
class BaseEmbeddingTask(base_task.BaseTask):
  """A embedding task which trains Sail-like model on TPU."""

  @classmethod
  def params(cls):
    """Returns the base task params extended with the embedding-task options."""
    p = super(BaseEmbeddingTask, cls).params()
    p.define(
        'vocab_size_per_slot', None,
        'Fixed vocab_size for all the slots, this mainly '
        'for quick testing purpose.')
    p.define(
        'custom_vocab_size_mapping',
        None,
        'Fixed vocab size for some slots',
    )
    p.define(
        'vocab_size_offset', None,
        'Manually increase the vocab_size of each slot by a constant. '
        'This is used to provide a quick fix for issues/problems in '
        'generating GCP data, e.g., incorrect vocab_id assignment or '
        'incorrect vocab_size calculation.')
    p.define(
        'qr_multi_hashing', False,
        'If True, enable QR multi hashing trick for the slots with vocab_id larger than the threshold.'
    )
    p.define('qr_hashing_threshold', 100000000,
             'Threshold for the the QR slot.')
    p.define('qr_collision_rate', 4,
             'The hashing collision rate for the QR slot.')
    p.define('vocab_file_path', None, 'Path for the vocab file.')
    p.define('enable_deepinsight', False,
             'Whether connect to deepinsight to show the results.')
    p.define(
        'enable_host_call_scalar_metrics', False,
        'If True, enable host call computation of scalar metrics, including \
        average AUC per core, average loss, average label, learning rate, \
        (potentially) weight and gradient norms, etc. These metrics are \
        useful for model development and debugging. If False, only compute \
        basic metrics such as AUC, sample rate, etc.')
    p.define(
        'enable_host_call_norm_metrics', False,
        'Whether to enable host call computation of weight and gradient \
        norms. If enable_host_call_scalar_metrics is False, this param is \
        NOT used.')
    p.define('files_interleave_cycle_length', 4,
             'The number of input files that will be processed concurrently.')
    p.define(
        'deterministic', False,
        'Whether enable deterministic mini-batch training for comparable experiments.'
    )
    p.define("gradient_multiplier", 1.0, "Gradient multiplier for embeddings.")
    p.define('enable_caching_with_tpu_var_mode', False,
             'Whether enable host call with tpu variables mode.')
    # TODO(youlong): Revisit top_k_sampling_num_per_core
    p.define(
        'top_k_sampling_num_per_core', 6,
        'The number of samples to use per core for DeepInsight AUC \
        calculation. A lower number means fewer samples used and faster \
        training, and a bigger number means more samples used and slower \
        training.')
    p.define('use_random_init_embedding_for_oov', False,
             'Whether use random initialized embedding for oov ids')
    p.define('merge_vector', False,
             'If True, enable merging vector tables of the same slot.')
    # If there is file_pattern specified, file_pattern will override file_folder.
    p.train.define('file_folder', None,
                   'Training input data folder before date string.')
    p.train.define('date_and_file_name_format', "*/*/part*",
                   "Training file's date and name pattern.")
    p.train.define(
        'start_date', None,
        'Training input data start date inclusively, for example: 20201201.')
    p.train.define(
        'end_date', None,
        'Training input data end date inclusively, for example: 20201210.')
    p.train.define('vocab_file_folder_prefix', None,
                   'Prefix of hdfs folder to keep per day vocab size file.')
    return p

  def __init__(self, params):
    """Constructs a BaseEmbeddingTask object.

    Caches the host-call related flags from `params`, loads the per-slot
    vocab sizes (see `_create_vocab_dict`) and creates the feature `Env`.
    """
    super(BaseEmbeddingTask, self).__init__(params)
    self.p = params
    self._enable_deepinsight = self.p.enable_deepinsight
    self._enable_host_call_scalar_metrics = self.p.enable_host_call_scalar_metrics
    self._enable_caching_with_tpu_var_mode = self.p.enable_caching_with_tpu_var_mode
    self._top_k_sampling_num_per_core = self.p.top_k_sampling_num_per_core
    if params.vocab_size_per_slot is not None:
      logging.info("Set fixed vocab_size: {} for all the slots.".format(
          params.vocab_size_per_slot))
    if params.custom_vocab_size_mapping is not None:
      logging.info("Set fixed vocab size for some slots: {}".format(
          params.custom_vocab_size_mapping))
    vocab_size_dict = self._create_vocab_dict()
    self._env = Env(vocab_size_dict=vocab_size_dict, params=self.p)
    # Filled in by create_feature_and_table_config_dict.
    self._feature_to_config_dict = {}
    self._table_to_config_dict = {}

  def download_vocab_size_file_from_hdfs(self):
    """Copies the vocab size file for `train.end_date` from HDFS into ./temp.

    On success (command exit code 0 and exactly one file downloaded),
    `self.p.vocab_file_path` is pointed at the downloaded file. Otherwise the
    existing `vocab_file_path` is left untouched and the outcome is logged.
    The local "temp" folder is wiped and recreated on every call.
    """
    tmp_folder = "temp"
    if os.path.exists(tmp_folder):
      shutil.rmtree(tmp_folder)
    os.mkdir(tmp_folder)
    hdfs_vocab_size_file_path = "{}{}/part*.csv".format(
        self.p.train.vocab_file_folder_prefix, self.p.train.end_date)
    cmd = "hadoop fs -copyToLocal {} {}".format(hdfs_vocab_size_file_path,
                                                tmp_folder)
    logging.info(
        "Hdfs path prefix: {}, end_date: {}, download vocab size file from hdfs cmd: {}"
        .format(self.p.train.vocab_file_folder_prefix, self.p.train.end_date,
                cmd))
    ret = subprocess.run(cmd, shell=True)
    downloaded_files = os.listdir(tmp_folder)
    if ret.returncode == 0 and len(downloaded_files) == 1:
      self.p.vocab_file_path = os.path.join(tmp_folder, downloaded_files[0])
      logging.info(
          "Download vocab size file successfully from hdfs: {}, use downloaded vocab size file: {}."
          .format(hdfs_vocab_size_file_path, self.p.vocab_file_path))
    else:
      logging.info("Downloaded files: {}".format(downloaded_files))
      logging.info("Use default vocab size file: {}".format(
          self.p.vocab_file_path))

  def _create_vocab_dict(self):
    """Create vocab dict from a tsv file.

    Each line must hold exactly two tab-separated fields: slot id and distinct
    count. Lines whose first field is not all digits (e.g. a header) are
    skipped. The count is then overridden by `vocab_size_per_slot`, then by
    `custom_vocab_size_mapping`, and finally increased by `vocab_size_offset`.

    Returns:
      A dict mapping int slot id to its vocab size.
    """
    vocab_size_per_slot = self.p.vocab_size_per_slot
    custom_vocab_size_mapping = self.p.custom_vocab_size_mapping
    vocab_size_dict = {}
    if self.p.train.end_date is not None and self.p.train.vocab_file_folder_prefix is not None:
      self.download_vocab_size_file_from_hdfs()
    assert self.p.vocab_file_path is not None, \
      "Either provide vocab_file_path or vocab_file_folder_prefix and end date."
    with open(self.p.vocab_file_path) as f:
      for line in f:
        fields = line.strip().split("\t")
        # Report the file being read (the message template asks for it) plus
        # the offending fields.
        assert len(fields) == 2, \
          "each line in {} must have 2 fields, got: {}".format(
              self.p.vocab_file_path, fields)
        if not fields[0].isdigit():
          continue
        slot_id = int(fields[0])
        if vocab_size_per_slot is not None:
          distinct_count = vocab_size_per_slot
        else:
          distinct_count = int(fields[1])
        if custom_vocab_size_mapping is not None and slot_id in custom_vocab_size_mapping:
          distinct_count = custom_vocab_size_mapping[slot_id]
        if self.p.vocab_size_offset is not None:
          distinct_count += self.p.vocab_size_offset
        vocab_size_dict[slot_id] = distinct_count
    logging.info("Slot and vocab size: {}".format(vocab_size_dict))
    return vocab_size_dict

  @staticmethod
  def _parse_inputs(return_values):
    """Splits an input_fn result into (features, labels).

    A tuple is unpacked as `(features, labels)`; anything else is treated as
    features only and labels is None. Declared static because it takes no
    `self`; without the decorator, calling it through an instance would bind
    the instance to `return_values`.
    """
    if isinstance(return_values, tuple):
      features, labels = return_values
    else:
      features, labels = return_values, None
    return features, labels

  def create_input_fn(self, mode=tf.estimator.ModeKeys.TRAIN):
    """Create input_fn given the mode.

    Args:
      mode: tf.estimator.ModeKeys.TRAIN/EVAL/PREDICT.

    Returns:
      An input fn for Estimator.
    """
    # TODO(youlong.cheng): support eval and predict.
    assert mode == tf.estimator.ModeKeys.TRAIN
    file_pattern = self.p.train.file_pattern

    def tf_example_parser(examples):
      """Parse multiple examples."""
      feature_map = self._get_feature_map()
      example = tf.io.parse_example(serialized=examples, features=feature_map)
      return self._post_process_example(example)

    def insert_stopping_signal(stop, batch_size, stopping_signals_name):
      """Returns a map fn that adds a boolean stopping-signal feature."""

      def _map_fn(features):
        empty_sparse_tensor = tf.sparse.SparseTensor(
            indices=tf.zeros([0, 2], dtype=tf.int64),
            values=tf.zeros([0], dtype=tf.int64),
            dense_shape=(batch_size, 1))
        shape = [batch_size]
        if stop is True:
          for name, tensor in features.items():
            # For sparse tensors, set them to empty.
            if isinstance(tensor, tf.sparse.SparseTensor) is True:
              features[name] = empty_sparse_tensor
          # Added after the loop: inserting a new key while iterating the dict
          # is not allowed.
          features[stopping_signals_name] = tf.ones(shape=shape,
                                                    dtype=tf.dtypes.bool)
        else:
          features[stopping_signals_name] = tf.zeros(shape=shape,
                                                     dtype=tf.dtypes.bool)
        return features

      return _map_fn

    def input_fn(params):
      """Returns training or eval examples, batched as specified in params."""
      logging.info("Model input_fn")
      if params["cpu_test"] is True:
        dataset = tf.data.TFRecordDataset(file_pattern,
                                          compression_type=None,
                                          buffer_size=None)
        dataset = dataset.batch(params["batch_size"], drop_remainder=True).map(
            tf_example_parser,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            deterministic=self.p.deterministic)
        return dataset.repeat()

      # By shuffle=False, list_files will get all files already in time sorted order.
      if file_pattern is not None:
        files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
      else:
        assert self.p.train.file_folder is not None, \
          "p.train.file_folder must be defined if file_pattern is None."
        assert self.p.train.date_and_file_name_format is not None, \
          "p.train.date_and_file_name_format must be defined if file_pattern is None."
        file_pattern_ = os.path.join(self.p.train.file_folder,
                                     self.p.train.date_and_file_name_format)
        logging.info("file_pattern: {}, file_folder: {}".format(
            file_pattern_, self.p.train.file_folder))
        files = tf.data.Dataset.list_files(file_pattern_, shuffle=False)
        assert self.p.train.end_date is not None, \
          "end_date in config or flag must be defined if file_pattern is None"
        assert params["enable_stopping_signals"] is not None, \
          "When using end_date of input data, enable_stopping_signals need to be provided."
        files = util.range_dateset(files,
                                   self.p.train.file_folder,
                                   start_date=self.p.train.start_date,
                                   end_date=self.p.train.end_date)

      # This function will get called once per TPU task. Each task will process the files
      # with indexes which modulo num_calls equals to call_index.
      _, call_index, num_calls, _ = (
          params["context"].current_input_fn_deployment())
      files = files.shard(num_calls, call_index)

      def fetch_dataset(filename):
        dataset = tf.data.TFRecordDataset(filename,
                                          compression_type=None,
                                          buffer_size=None)
        return dataset

      # Read the data from disk in parallel.
      # Files will be processed from the beginning to the end, interleaving
      # several files locally. The number of interleaved files is the cycle length.
      dataset = files.interleave(
          fetch_dataset,
          cycle_length=self.p.files_interleave_cycle_length,
          num_parallel_calls=tf.data.experimental.AUTOTUNE,
          deterministic=self.p.deterministic)
      dataset = dataset.batch(params["batch_size"], drop_remainder=True).map(
          tf_example_parser,
          num_parallel_calls=tf.data.experimental.AUTOTUNE,
          deterministic=self.p.deterministic)
      if self.p.train.repeat:
        assert params["enable_stopping_signals"] is False
        dataset = dataset.repeat()

      enable_stopping_signals = params["enable_stopping_signals"]
      if enable_stopping_signals:
        logging.info("Adding stop signals to original data set.")
        # Add stop signal to help handling end of stream.
        user_provided_dataset = dataset.map(
            insert_stopping_signal(
                stop=False,
                batch_size=params["batch_size"],
                stopping_signals_name=fh._USER_PROVIDED_SIGNAL_NAME),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            deterministic=False)
        final_batch_dataset = dataset.repeat().map(
            insert_stopping_signal(
                stop=True,
                batch_size=params["batch_size"],
                stopping_signals_name=fh._USER_PROVIDED_SIGNAL_NAME),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            deterministic=False)
        # Real batches first, then an endless stream of "stop" batches.
        dataset = user_provided_dataset.concatenate(final_batch_dataset)

      dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
      # Test only
      # dataset = dataset.take(128).cache().repeat()
      return dataset

    return input_fn

  def logits_fn(self):
    """Calculate logits."""
    # Inherited class must implement this function.
    raise NotImplementedError

  def init_slot_to_env(self):
    """Run this in the beginning to initialize the slot and its embedding dims information."""
    logging.info("Model init_slot_to_env")
    logging.info("init_slot_to_env, self: {}".format(self))
    # Running logits_fn once lets the model register its slots before the env
    # is finalized.
    self.logits_fn()
    self._env.finalize()

  def create_model_fn(self,):
    """Create model fn."""
    raise NotImplementedError('Abstract method.')

  def _get_feature_map(self):
    """Returns data format of the serialized tf record file."""
    # Inherited class must implement this function.
    raise NotImplementedError

  def _post_process_example(self, example):
    """Postprocess example.

    Rewrites the parsed `<fc_name>_0` id tensors so that they fit the
    configured tables (clamped at 0, optionally taken modulo the vocab size,
    optionally split into QR remainder/quotient features), aliases the result
    for every additional slice, and adds the uid bucket feature when a uid is
    present.

    NOTE(review): the nesting of this method was reconstructed from a copy of
    the file whose indentation was lost; confirm against the original.
    """
    # build tensors for each embeddings in each slot
    for slot_id, feature_slot in list(
        self._env.slot_id_to_feature_slot.items()):
      # Check if feature_slot has at least one FeatureColumn associated with it
      # If not, it means that the slot only has ZeroFeatureColumn, so we
      # ignore it.
      if len(feature_slot.feature_columns) > 0:
        for feature_column in feature_slot.feature_columns:
          # If the vocab size per slot is set, we need to adjust the
          # vocab_id so that no vocab_id exceed this vocab size per slot
          embedding_tensor = example["{}_0".format(feature_column.fc_name)]
          if isinstance(feature_column, FeatureColumn3D):
            new_embedding_tensor = embedding_tensor.to_sparse()
            new_embedding_tensor = tf.SparseTensor(
                indices=new_embedding_tensor.indices,
                values=tf.maximum(new_embedding_tensor.values, 0),
                dense_shape=new_embedding_tensor.dense_shape)
          else:
            new_embedding_tensor = tf.SparseTensor(
                indices=embedding_tensor.indices,
                values=tf.maximum(embedding_tensor.values, 0),
                dense_shape=embedding_tensor.dense_shape)

          if self.p.vocab_size_per_slot is not None:
            new_embedding_tensor = tf.SparseTensor(
                indices=new_embedding_tensor.indices,
                values=tf.math.mod(new_embedding_tensor.values,
                                   self.p.vocab_size_per_slot),
                dense_shape=new_embedding_tensor.dense_shape)
            vocab_size = self.p.vocab_size_per_slot
          else:
            vocab_size = self._env._vocab_size_dict.get(slot_id, 10)

          if self.p.custom_vocab_size_mapping is not None and slot_id in self.p.custom_vocab_size_mapping:
            new_embedding_tensor = tf.SparseTensor(
                indices=new_embedding_tensor.indices,
                values=tf.math.mod(new_embedding_tensor.values,
                                   self.p.custom_vocab_size_mapping[slot_id]),
                dense_shape=new_embedding_tensor.dense_shape)
            vocab_size = self.p.custom_vocab_size_mapping[slot_id]

          if self.p.qr_multi_hashing and vocab_size > self.p.qr_hashing_threshold:
            # setting remainder vocab size (the quotient table size is only
            # needed when the tables are configured)
            R_vocab_size = vocab_size // self.p.qr_collision_rate + 1
            # NOTE(review): the QR features are derived from the raw parsed
            # tensor, not from the clamped/modded `new_embedding_tensor`.
            embedding_tensor = example["{}_0".format(feature_column.fc_name)]
            del example["{}_0".format(feature_column.fc_name)]
            # creating two features for remainder/quotient
            for feature_slice in feature_column.feature_slice_to_tf_placeholder:
              example["{}_{}_0".format(
                  feature_column.fc_name,
                  feature_slice.slice_index)] = tf.SparseTensor(
                      indices=embedding_tensor.indices,
                      values=tf.math.floormod(embedding_tensor.values,
                                              R_vocab_size),
                      dense_shape=embedding_tensor.dense_shape)
              example["{}_{}_1".format(
                  feature_column.fc_name,
                  feature_slice.slice_index)] = tf.SparseTensor(
                      indices=embedding_tensor.indices,
                      values=tf.math.floordiv(embedding_tensor.values,
                                              R_vocab_size),
                      dense_shape=embedding_tensor.dense_shape)
          else:
            if isinstance(feature_column, FeatureColumn3D):
              # Get row_lengths from embedding_tensor, which is RaggedTensor
              # for FeatureColumn3D
              row_lengths = tf.cast(
                  embedding_tensor.row_lengths(),
                  tf.int32,
              )  # [B] Tensor
              example["{}_0_row_lengths".format(
                  feature_column.fc_name)] = row_lengths
              # seq feature, dims[0][0] is max seq length
              new_embedding_tensor = tf.sparse.slice(
                  new_embedding_tensor, [0, 0], [
                      new_embedding_tensor.dense_shape[0],
                      feature_column.max_seq_length
                  ])
            example["{}_0".format(feature_column.fc_name)] = tf.sparse.reorder(
                new_embedding_tensor)
            # Every other slice of the column reads the same ids as slice 0.
            for feature_slice in feature_column.feature_slice_to_tf_placeholder:
              if feature_slice.slice_index != 0:
                example["{}_{}".format(
                    feature_column.fc_name,
                    feature_slice.slice_index)] = example["{}_0".format(
                        feature_column.fc_name)]

    # This logic is to calculate AUC which follows current DeepInsight sampling logic.
    # Basically we distribute samples by their UIDs in _RATIO_N buckets.
    # Like we get their UID_BUCKET = UID % _RATIO_N.
    # Later we choose only the examples if their UID_BUCKET < _RATIO_N * _UID_SAMPLE_RATE
    if hs._UID in example:
      example[hs._UID_BUCKET] = tf.cast(
          tf.math.mod(example[hs._UID], hs._RATIO_N), tf.int32)
    return example

  def create_feature_and_table_config_dict(self):
    """Prepares the table and feature config given the parameters.

    Requires the env to be finalized.

    Returns:
      (feature_to_config_dict, table_to_config_dict), which are also kept on
      the instance.
    """
    env = self._env
    assert env.is_finalized()
    for slot_id, feature_slot in list(
        self._env.slot_id_to_feature_slot.items()):
      vocab_size = env.vocab_size_dict.get(slot_id, 1)
      # Check if feature_slot has at least one FeatureColumn associated with it
      # If not, it means that the slot only has ZeroFeatureColumn, so we
      # ignore it.
      if len(feature_slot.feature_columns) > 0:
        # Iterate through feature columns to create TableConfig and FeatureConfig
        for feature_column in feature_slot.feature_columns:
          for feature_slice in feature_column.feature_slice_to_tf_placeholder:
            if self.p.qr_multi_hashing and vocab_size > self.p.qr_hashing_threshold:
              # creating quotient/remainder embedding table
              logging.info('Setting QR table for slot {}'.format(slot_id))
              R_vocab_size = vocab_size // self.p.qr_collision_rate + 1
              Q_vocab_size = self.p.qr_collision_rate + 1
              # remainder embedding table
              table_name = "table_{}_{}_0".format(slot_id,
                                                  feature_slice.slice_index)
              if table_name not in self._table_to_config_dict:
                Rtable = tpu_embedding.TableConfig(
                    vocabulary_size=R_vocab_size,
                    dimension=feature_slice.dim,
                    initializer=feature_slice.initializer,
                    combiner="sum",
                    learning_rate_fn=feature_slice.learning_rate_fn,
                    optimization_parameters=feature_slice.optimizer)
                self._table_to_config_dict[table_name] = Rtable
              # remainder feature config
              feature_name = "{}_{}_0".format(feature_column.fc_name,
                                              feature_slice.slice_index)
              self._feature_to_config_dict[
                  feature_name] = tpu_embedding.FeatureConfig(table_name)
              # quotient embedding table
              table_name = "table_{}_{}_1".format(slot_id,
                                                  feature_slice.slice_index)
              if table_name not in self._table_to_config_dict:
                Qtable = tpu_embedding.TableConfig(
                    vocabulary_size=Q_vocab_size,
                    dimension=feature_slice.dim,
                    initializer=feature_slice.initializer,
                    combiner="sum",
                    learning_rate_fn=feature_slice.learning_rate_fn,
                    optimization_parameters=feature_slice.optimizer)
                self._table_to_config_dict[table_name] = Qtable
              # quotient feature config
              feature_name = "{}_{}_1".format(feature_column.fc_name,
                                              feature_slice.slice_index)
              self._feature_to_config_dict[
                  feature_name] = tpu_embedding.FeatureConfig(table_name)
            # NOTE(review): there is no `else` here, so a QR slot also gets the
            # full-size "table_<slot>_<slice>" table and feature config below.
            # Confirm whether that fall-through is intended.
            table_name = "table_{}_{}".format(slot_id,
                                              feature_slice.slice_index)
            if table_name not in self._table_to_config_dict:
              table = tpu_embedding.TableConfig(
                  vocabulary_size=vocab_size,
                  dimension=feature_slice.dim,
                  initializer=feature_slice.initializer,
                  combiner="sum",
                  learning_rate_fn=feature_slice.learning_rate_fn,
                  optimization_parameters=feature_slice.optimizer)
              self._table_to_config_dict[table_name] = table
            feature_name = "{}_{}".format(feature_column.fc_name,
                                          feature_slice.slice_index)
            # Multiple feature configs can share the same table config
            if isinstance(feature_column, FeatureColumn3D):
              self._feature_to_config_dict[
                  feature_name] = tpu_embedding.FeatureConfig(
                      table_name,
                      max_sequence_length=feature_column.max_seq_length)
            else:
              self._feature_to_config_dict[
                  feature_name] = tpu_embedding.FeatureConfig(table_name)
    return self._feature_to_config_dict, self._table_to_config_dict

  def cross_shard_optimizer(self, optimizer, params):
    """Wraps `optimizer` in a CrossShardOptimizer unless running the CPU test."""
    if params["cpu_test"]:
      return optimizer
    else:
      return tf.compat.v1.tpu.CrossShardOptimizer(optimizer)

  def process_features_for_cpu_test(self, features):
    """Replaces sparse id features with locally looked-up embeddings.

    For every SparseTensor feature a randomly initialised variable sized from
    its table config is created and the ids (taken modulo the vocab size) are
    looked up with combiner "sum". Other features pass through unchanged.
    Both config dicts are cleared before returning.

    Returns:
      A new dict with the same keys as `features`.
    """
    processed_features = {}
    for feature_name, feature_value in features.items():
      if isinstance(feature_value, tf.sparse.SparseTensor):
        feature_config = self._feature_to_config_dict[feature_name]
        table_config = self._table_to_config_dict[feature_config.table_id]
        dim = table_config.dimension
        max_sequence_length = feature_config.max_sequence_length
        vocab_size = table_config.vocabulary_size
        if feature_config.max_sequence_length == 0:
          initvalue = (np.random.rand(vocab_size, dim) - 0.5) / (vocab_size *
                                                                 dim)
        else:
          # Sequence features store max_sequence_length * dim values per id and
          # are reshaped after the lookup.
          initvalue = (np.random.rand(vocab_size, max_sequence_length * dim) -
                       0.5) / (vocab_size * max_sequence_length * dim)
        initvalue = tf.cast(initvalue, tf.float32)
        embedding_variable = tf.compat.v1.get_variable(name=feature_name,
                                                       initializer=initvalue,
                                                       dtype=tf.float32)
        # Get new feature ids based on vocab_size. We mod
        # by vocab_size to make sure new feature ids will be
        # within vocab_size. This is only for test purpose.
        new_feature_ids = tf.SparseTensor(indices=feature_value.indices,
                                          values=tf.math.mod(
                                              feature_value.values, vocab_size),
                                          dense_shape=feature_value.dense_shape)
        # Get embeddings.
        embeddings = tf.nn.safe_embedding_lookup_sparse(embedding_variable,
                                                        new_feature_ids,
                                                        sparse_weights=None,
                                                        combiner="sum")
        if max_sequence_length != 0:
          embeddings = tf.reshape(embeddings, [-1, max_sequence_length, dim])
        processed_features[feature_name] = embeddings
      else:
        processed_features[feature_name] = feature_value
    # For CPU test, we clear this state from now on, so that later host_call
    # code does not use it for TPU specific operations.
    self._feature_to_config_dict.clear()
    self._table_to_config_dict.clear()
    return processed_features
class BaseHostCall(object):
  """Collects named summary tensors and packs them for a TPU host call.

  Tensors are registered with `record_summary_tensor`, packed per dtype with
  `compress_tensors`, and unpacked again inside the host call function with
  `decompress_tensors`. Index 0 is always the global step.
  """

  def __init__(self, output_dir, enable_host_call):
    """Registers the global step as the first recorded tensor.

    Args:
      output_dir: stored on the instance as `_output_dir`.
      enable_host_call: when falsy, `record_summary_tensor` is a no-op.
    """
    self._output_dir = output_dir
    self._enable_host_call = enable_host_call
    self._tensor_names = ["global_step"]
    gs = tf.compat.v1.train.get_global_step()
    # Creating batch dimension since host call needs to concat all the cores'
    # results.
    self._tensors = [tf.reshape(gs, [-1])]
    # compressed_tensor_list is a list of compressed tensors, in
    # which each compressed tensor is the concatenation of multiple
    # uncompressed tensors. To decompress, we need to store the original
    # sizes of all uncompressed tensors, organized in list of lists.
    # Each list corresponds to one compressed tensor.
    self._lists_tensor_sizes = []  # A list of lists

  def record_summary_tensor(self, name, tensor):
    """Registers `tensor` under the unique `name`.

    Does nothing when host call is disabled. The tensor must have rank 0 or 1
    and is stored flattened to rank 1.

    Raises:
      AssertionError: if `name` is already registered or the rank is > 1. In
        either case nothing is recorded, so `_tensor_names` and `_tensors`
        stay aligned.
    """
    if not self._enable_host_call:
      return
    if name in self._tensor_names:
      logging.info('{} | {}'.format(name, self._tensor_names))
    assert name not in self._tensor_names
    # Validate before mutating any state; a failed check must not leave a name
    # registered without its tensor.
    assert len(tensor.get_shape()) <= 1, \
      "Now we only support tensor with shape (k, ) or () " \
      "but we met tensor with shape: {}".format(tensor.get_shape())
    self._tensor_names.append(name)
    # Creating batch dimension since host call needs to concat all the cores'
    # results.
    reshaped_tensor = tf.reshape(tensor, [-1])
    self._tensors.append(reshaped_tensor)

  def compress_tensors(self):
    """Concatenates the recorded tensors per data type.

    For n tensors with shapes (k_0, ), ..., (k_{n-1}, ) and the same data
    type, the compressed tensor has shape (1, k_0 + k_1 + ... + k_{n-1}).
    `_tensor_names` is reordered to follow the grouped order, `_tensors` is
    replaced by the compressed tensors, and the original sizes are appended to
    `_lists_tensor_sizes` for `decompress_tensors`.
    """
    assert len(self._tensor_names) == len(self._tensors), "tensor_names and tensors must have same length," \
      "tensor_names length: {}, tensors length: {}".format(len(self._tensor_names), len(self._tensors))
    # key is tensor data type, value is a list of tensor names with this data type.
    data_type_to_tensor_names = {}
    # key is tensor data type, value is a list of tensors with this data type.
    data_type_to_tensors = {}
    # Group tensor names and tensors by same data type.
    for tensor_name, tensor in zip(self._tensor_names, self._tensors):
      data_type_to_tensor_names.setdefault(tensor.dtype,
                                           []).append(tensor_name)
      data_type_to_tensors.setdefault(tensor.dtype, []).append(tensor)

    # Compress n tensors tensor_0, tensor_1, ... , tensor_{n-1} with shape
    # (k_0, ), (k_1, )... (k_{n-1}, ) of same data
    # type to one tensor with shape (1, k_0 + ... + k_{n-1})
    compressed_tensor_name_list = []
    compressed_tensor_list = []
    for data_type, tensor_list in data_type_to_tensors.items():
      compressed_tensor_name_list.extend(data_type_to_tensor_names[data_type])
      # tf.compat.dimension_value yields the static size whether shape[0] is a
      # `Dimension` object or already a plain int.
      tensor_sizes = [
          tf.compat.dimension_value(tensor.shape[0]) for tensor in tensor_list
      ]
      self._lists_tensor_sizes.append(tensor_sizes)
      # concat a list of tensors with shapes
      # (k_0, ), (k_1, )... (k_{n-1}, ) to a tensor with shape
      # (k_0 + k_1 + ... + k_{n-1}, )
      compressed_tensor = tf.concat(tensor_list, axis=0)
      # expand dimension at 0 to make it have the batch dimension
      # tensor with shape (k_0 + k_1 + ... + k_{n-1}, )
      # => tensor with shape (1, k_0 + k_1 + ... + k_{n-1})
      compressed_tensor = tf.expand_dims(compressed_tensor, axis=0)
      compressed_tensor_list.append(compressed_tensor)
      logging.info(
          "Host call compressed tensors, data type: {}, compressed tensor shape: {}"
          .format(data_type, compressed_tensor.shape))
    self._tensor_names = compressed_tensor_name_list
    self._tensors = compressed_tensor_list

  def decompress_tensors(self, tensors):
    """Decompress the compressed tensors into a list of decompressed tensors.

    Args:
      tensors: the compressed tensors, in the order produced by
        `compress_tensors`. Each has shape (num_cores, k_0 + ... + k_{n-1}).

    Returns:
      `(global_step, decompressed)` where `global_step` is
      `decompressed[0][0]` and `decompressed` holds one tensor per recorded
      name, in `_tensor_names` order.

    Raises:
      AssertionError: if a compressed tensor is not rank 2, or if the first
        recorded name is not "global_step".
    """
    # Need decompress tensors
    decompressed_tensor_list = []
    for index, compressed_tensor in enumerate(tensors):
      # For each tensor, its shape is (num_cores, k_0 + ... + k_{n-1})
      assert len(compressed_tensor.get_shape(
      )) == 2, "Compressed tensors shape must be (n, m), met shape: {}".format(
          compressed_tensor.shape)
      logging.info("Decompressed tensors shape: {}, dtype: {}.".format(
          compressed_tensor.shape, compressed_tensor.dtype))
      # tensor with shape (num_cores, k_0 + ... + k_{n-1})
      # => list of tensors with shape (num_cores, k_i)
      split_tensors = tf.split(compressed_tensor,
                               self._lists_tensor_sizes[index],
                               axis=1)
      for tensor in split_tensors:
        # tf.squeeze drops every size-1 dimension: (num_cores, 1) becomes
        # (num_cores, ), while (num_cores, k_i) with k_i > 1 keeps its shape.
        # NOTE(review): with a single core the leading dimension is dropped
        # too — confirm callers never run this with num_cores == 1.
        tensor = tf.squeeze(tensor)
        decompressed_tensor_list.append(tensor)
    assert self._tensor_names[
        0] == "global_step", "The first tensor name must be global_step, met value: {}".format(
            self._tensor_names[0])
    return decompressed_tensor_list[0][0], decompressed_tensor_list

  def generate_host_call_hook(self):
    """Returns the host call hook; the base implementation provides none."""
    # Children should implement this API and implement it with model specific host call logic.
    return None
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Base class for all layers.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf from monolith.core.hyperparams import InstantiableParams from monolith.core.py_utils import NestedMap from collections import defaultdict _layer_loss = defaultdict(dict) # _layer_loss[graph][name] _name_inuse = defaultdict(int) class BaseLayer(object): @classmethod def params(cls): """Returns the layer params.""" p = InstantiableParams(cls) p.define('name', get_uname(cls.__name__), 'Name of this layer object.') return p def __init__(self, params): """Layer constructor. Args: params: A params used to construct this layer. """ assert params.name, ('Layer params for %s must have a "name"' % self.__class__.__name__) # Child layers created by this layer through CreateChild/CreateChildren. 
self._private_children = NestedMap() @property def children(self): """Returns children layers of this layer in a `.NestedMap`.""" return self._private_children def __getattr__(self, name): """Returns the child layer of the given name.""" if name == '_private_children': raise AttributeError( 'pre-mature access to __getattr__ before _private_children ' 'is created.') if name in self._private_children: return self._private_children[name] elif (hasattr(type(self), name) and isinstance(getattr(type(self), name), property)): # There was an AttributeError raised by a property getter. # Call property getter again directly to raise the same error. return getattr(type(self), name).fget(self) else: raise AttributeError('%s is not a sub-layer of %s.' % (name, self)) def __call__(self, *args, **kwargs): """Forwards call to FProp.""" return self.fprop(*args, **kwargs) def fprop(self, *args, **kwargs): """Forward propagation. The central interface that subclasses should implement. The caller calls `FProp`. Args: *args: List args. **kwargs: Keyward args. """ del args del kwargs raise NotImplementedError('Abstract method of %s' % self) def create_child(self, name, params): """Create a sub layer. The created sub layer can be accessed by `name`. E.g.:: self.create_child('foo', ...) self.foo.fprop... or:: self.children['foo'].fprop... self.children.foo.fprop... Args: name: Sub layer name used as key to access it as attribute params: `Hyperparams` object to instantiate a layer. """ # self._check_name(name) if not params.name: params.name = self.p.name # p = copy_params_to(self.p, params.copy()) # params = copy_params_to(self.p, params.copy()) child = params.instantiate() self._private_children[name] = child def create_children(self, name, params): """Create a list or dict of sub layers. The created sub layer list can be accessed by `name`. E.g.: self.create_children('foo', ...) self.foo[10].FProp... or: self.children['foo'][10].Fprop... self.children.foo[10].Fprop... 
Args: name: The name for the sub layers, which is used as the key into vars/theta. params: a list of `Hyperparams` objects to create. """ self._private_children[name] = [] for index, param in enumerate(params): if not param.name: param.name = '%s_%d' % (name, index) child = param.instantiate() self._private_children[name].append(child) def get_uname(name): if name in _name_inuse: _name_inuse[name] += 1 return "{name}_{idx}".format(name=name, idx=_name_inuse[name]) else: return name def add_layer_loss(name, loss): graph_layer_loss = _layer_loss[tf.compat.v1.get_default_graph()] if name in graph_layer_loss: graph_layer_loss[name] += loss else: graph_layer_loss[name] = loss def get_layer_loss(): return _layer_loss[tf.compat.v1.get_default_graph()] ================================================ FILE: monolith/core/base_layer_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from monolith.core import base_layer from monolith.core import hyperparams import unittest class BaseLayerTest(unittest.TestCase): def test_create_child(self): layer_p = base_layer.BaseLayer.params() layer_p.name = 'test' layer = layer_p.instantiate() layer._disable_create_child = False # pylint: disable=protected-access layer.create_child(name='a', params=layer_p) self.assertTrue('a' in layer.children) def test_create_children(self): layer_p = base_layer.BaseLayer.params() layer_p.name = 'test' layer = layer_p.instantiate() layer._disable_create_child = False # pylint: disable=protected-access layer.create_children(name='a', params=[layer_p, layer_p]) self.assertTrue('a' in layer.children) self.assertEqual(len(layer.a), 2) if __name__ == '__main__': unittest.main() ================================================ FILE: monolith/core/base_model_params.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function class SingleTaskModelParams(object): """Model Params for a `.SingleTaskModel`.""" def task(self): """Returns task params.""" raise NotImplementedError('Abstract method') ================================================ FILE: monolith/core/base_task.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base task."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from monolith.core import base_layer
from monolith.core import hyperparams


class BaseTask(base_layer.BaseLayer):
  """A single training task."""

  @classmethod
  def params(cls):
    """Returns the task params.

    Extends the layer params with 'accelerator' and three nested param
    groups: 'input', 'eval' and 'train'.
    """
    p = super(BaseTask, cls).params()
    p.define('accelerator', None,
             'Accelerator to use. One of [None, "tpu", "horovod"].')

    # Dataset sizes.
    p.define('input', hyperparams.Params(), 'Input Params.')
    p.input.define('eval_examples', None,
                   'Number of total examples for evaluation.')
    p.input.define('train_examples', None,
                   'Number of total examples for training.')

    # Evaluation settings.
    p.define('eval', hyperparams.Params(),
             'Params to control how this task should be evaled.')
    p.eval.define('per_replica_batch_size', None,
                  'Number of per replica batch size')
    p.eval.define('steps_per_eval', 10000,
                  'Number of training steps between two evluations.')
    p.eval.define('steps', None, 'Number of steps for which to eval model.')

    # Training settings.
    p.define('train', hyperparams.Params(),
             'Params to control how this task should be trained.')
    p.train.define('steps', None, 'Number of steps for which to train model.')
    p.train.define('max_steps', None,
                   'Number of total steps for which to train model.')
    p.train.define('per_replica_batch_size', None,
                   'Number of per replica batch size')
    p.train.define(
        'file_pattern', None,
        'Training input data. If file_pattern and file_folder are both' \
        ' provided, use file pattern firstly.')
    p.train.define('repeat', False,
                   'Whether repeat in the training job or not.')
    p.train.define('label_key', 'label',
                   'The key of the label field in the data.')
    p.train.define(
        'save_checkpoints_steps', None,
        'Save checkpoint every save_checkpoints_steps. If None, overwrite by runner.'
    )
    p.train.define(
        'save_checkpoints_secs', None,
        'Save checkpoint every save_checkpoints_secs. If None, overwrite by runner.'
    )
    p.train.define('dense_only_save_checkpoints_secs', None,
                   'Save dense checkpoint every save_checkpoints_secs')
    p.train.define('dense_only_save_checkpoints_steps', None,
                   'Save dense checkpoint every save_checkpoints_steps')
    return p

  def __init__(self, params):
    """Constructs a BaseTask object."""
    super(BaseTask, self).__init__(params)

  def create_input_fn(self, mode):
    """Create input_fn given the mode.

    Args:
      mode: tf.estimator.ModeKeys.TRAIN/EVAL/PREDICT.

    Returns:
      An input fn for Estimator.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError('Abstract method.')

  def create_model_fn(self,):
    """Create model fn.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError('Abstract method.')

================================================
FILE: monolith/core/base_tpu_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf from monolith.core import model_registry from monolith.tpu_runner import TPURunner class BaseTPUTest(tf.test.TestCase): """Base class for tpu test.""" def runWithCPU(self, task_name): task_param = model_registry.GetParams(task_name) runner = TPURunner(task_param) runner._cpu_test = True runner._host_call_every_n_steps = 0 runner.run() def runMergeVectorTestOnCPU(self, task_name): task_param = model_registry.GetParams(task_name) task_param.merge_vector = True runner = TPURunner(task_param) runner._cpu_test = True runner._host_call_every_n_steps = 0 runner.run() env = runner._task._env # Verify slot number should be same before and after merged vectors. self.assertEqual(len(env._slot_to_dims.keys()), len(env._slot_to_merged_dims.keys())) # Verify all slots are same, merged dims are expected. for slot_id, original_dims in env._slot_to_dims.items(): #Verify slot_id are all the same before and after merged vecotrs. self.assertIn(slot_id, env._slot_to_merged_dims) merged_dims = env._slot_to_merged_dims[slot_id] if original_dims[0] == 1: # Veirfy bias dim is same. self.assertEqual(merged_dims[0], 1) # Verify merged vector dim is same. if len(original_dims) > 1: expect_merged_dim = sum(original_dims[1:]) self.assertEqual(len(merged_dims), 2) self.assertEqual(expect_merged_dim, merged_dims[1]) else: self.assertEqual(len(merged_dims), 1) else: # Verify merged vector dim is same. expect_merged_dim = sum(original_dims) self.assertEqual(len(merged_dims), 1) self.assertEqual(expect_merged_dim, merged_dims[0]) # Verify all split features are as expected. 
for name, embedding in env._tpu_features.items(): if "slot_" in name: slot_id = int(name.split("_")[1]) index = int(name.split("_")[2]) expect_dim = env._slot_to_dims[slot_id][index] actual_dim = embedding.get_shape().as_list()[1] self.assertEqual(actual_dim, expect_dim) ================================================ FILE: monolith/core/core_test_suite.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import unittest from monolith.core.hyperparams_test import ParamsTest from monolith.core.base_layer_test import BaseLayerTest from monolith.core.base_embedding_host_call_test import BaseEmbeddingHostCallTest from monolith.core.util_test import UtilTest def suite(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(ParamsTest)) suite.addTest(unittest.makeSuite(BaseLayerTest)) suite.addTest(unittest.makeSuite(BaseEmbeddingHostCallTest)) suite.addTest(unittest.makeSuite(UtilTest)) return suite if __name__ == '__main__': runner = unittest.TextTestRunner(verbosity=2) runner.run(suite()) ================================================ FILE: monolith/core/dense.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"Code to implement custom Sail-like layer using TensorFlow Keras API."

from __future__ import absolute_import, division, print_function

import functools
import sys

import numpy as np
import scipy.stats as stats
import tensorflow as tf
from absl import logging
from tensorflow.python.framework import dtypes, tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints, initializers, regularizers
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.ops import gen_math_ops, math_ops, nn

from monolith.core.base_layer import BaseLayer
from monolith.core.variance_scaling import VarianceScaling


class Dense(tf.keras.layers.Dense, BaseLayer):
  """Keras Dense layer configured through params, with optional kernel norm.

  The kernel is created with `tf.compat.v1.get_variable` (optionally
  partitioned). When `allow_kernel_norm` is set, the kernel is L2-normalized
  along axis 0 and, if `kernel_norm_trainable`, rescaled by a trainable norm
  variable.
  """

  @classmethod
  def params(cls):
    """Returns the layer params extended with the Dense-specific entries."""
    p = super(Dense, cls).params()
    p.define('units', 512, 'Positive integer, dimensionality of the ' \
             'output space.')
    p.define('activation', None, 'Activation function to use.')
    p.define('use_bias', True, 'Boolean, whether the layer uses a bias ' \
             'vector.')
    p.define('kernel_initializer',
             VarianceScaling(mode='fan_avg', distribution='uniform'),
             'Initializer for the `kernel` weights matrix. Currently only '\
             'supporting variance scaling initializer.')
    p.define('bias_initializer', 'zeros', 'Initializer for the bias vector.')
    p.define('allow_kernel_norm', True,
             'Boolean, kernel normalization is only applicable when TRAINING.')
    p.define('kernel_norm_trainable', True,
             'Boolean, whether a trainable weight norm variable is allocated')
    p.define('partitioner', None,
             'VariablePartitioner, if we will use partitioned variable')
    return p

  def __init__(self, params, **kwargs):
    """Constructs the layer from `params`.

    Args:
      params: Params created by `Dense.params()`.
      **kwargs: Forwarded to `tf.keras.layers.Dense.__init__`. 'input_dim' is
        converted to 'input_shape' when the latter is absent.
    """
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
      kwargs['input_shape'] = (kwargs.pop('input_dim'),)
    # Call the __init__() function for BaseLayer
    BaseLayer.__init__(self, params)
    # Call the __init__() function for tf.keras.layers.Dense
    super(Dense, self).__init__(
        units=params.units,
        activation=params.activation,
        use_bias=params.use_bias,
        kernel_initializer=params.kernel_initializer,
        bias_initializer=params.bias_initializer,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    )

    # Change/Add some class properties to the tf.keras.layers.Dense
    # properties. Note that this Dense layer does not support regularizers
    # and constraints.
    self.p = params
    self.units = int(
        params.units) if not isinstance(params.units, int) else params.units
    self.activation = activations.get(params.activation)
    self.use_bias = params.use_bias
    # The params initializer object is stored as is and called directly in
    # build().
    self.kernel_initializer = params.kernel_initializer
    self.bias_initializer = initializers.get(params.bias_initializer)

    self.supports_masking = True
    self.input_spec = InputSpec(min_ndim=2)
    self.allow_kernel_norm = params.allow_kernel_norm
    self.kernel_norm_trainable = params.kernel_norm_trainable
    # Variables are named '<params.name>/kernel', '<params.name>/bias', ...
    self.var_name_prefix = params.name
    self.partitioner = params.partitioner

  def build(self, input_shape):
    """Creates the kernel, the optional kernel norm and the optional bias.

    Args:
      input_shape: Shape of the input; its last dimension must be defined.

    Raises:
      TypeError: If the layer dtype is neither floating point nor complex.
      ValueError: If the last dimension of `input_shape` is None.
    """
    dtype = tf.dtypes.as_dtype(self.dtype or K.floatx())
    if not (dtype.is_floating or dtype.is_complex):
      raise TypeError('Unable to build `Dense` layer with non-floating point '
                      'dtype %s' % (dtype,))
    input_shape = tensor_shape.TensorShape(input_shape)
    if tensor_shape.dimension_value(input_shape[-1]) is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})

    kernel_shape = [last_dim, self.units]
    # The initial kernel value is computed once up front so that the same
    # value can also seed the trainable kernel norm below.
    init_kernel = self.kernel_initializer(shape=kernel_shape, dtype=self.dtype)
    # Without a partitioner the value is wrapped in a callable initializer;
    # with a partitioner the value itself is passed as the initializer.
    if self.partitioner is None:
      kernel_initializer = lambda shape, dtype: init_kernel
    else:
      kernel_initializer = init_kernel
    self.kernel = tf.compat.v1.get_variable(initializer=kernel_initializer,
                                            trainable=True,
                                            name="{}/kernel".format(
                                                self.var_name_prefix),
                                            shape=kernel_shape,
                                            dtype=dtype,
                                            partitioner=self.partitioner)

    # Add the option for allow_kernel_norm
    if self.allow_kernel_norm:
      # From here on self.kernel is the normalized tensor, not the variable.
      self.kernel = tf.nn.l2_normalize(self.kernel,
                                       axis=0,
                                       epsilon=1e-6,
                                       name='normalized_kernel')
      if self.kernel_norm_trainable:
        # Use np to mitigate the error thrown by tensorflow due to the variable
        # initializer inside a conditional.
        init_trainable_kernel_norm = np.linalg.norm(
            init_kernel,
            axis=0,
        )
        if self.partitioner is None:
          norm_initializer = lambda shape, dtype: init_trainable_kernel_norm
        else:
          norm_initializer = init_trainable_kernel_norm
        self.trainable_kernel_norm = tf.compat.v1.get_variable(
            initializer=norm_initializer,
            shape=init_trainable_kernel_norm.shape,
            trainable=True,
            name='{}/trainable_kernel_norm'.format(self.var_name_prefix),
            dtype=dtype,
            partitioner=self.partitioner)
        self.kernel = tf.multiply(self.kernel,
                                  self.trainable_kernel_norm,
                                  name='mul_of_kernel_and_trainable_norm')

    if self.use_bias:
      self.bias = self.add_weight(name='{}/bias'.format(self.var_name_prefix),
                                  shape=[
                                      self.units,
                                  ],
                                  initializer=self.bias_initializer,
                                  dtype=dtype,
                                  trainable=True)
    else:
      self.bias = None
    self.built = True

  def get_config(self):
    """Returns the base layer config updated with this layer's settings."""
    config = {
        'units': self.units,
        'activation': activations.serialize(self.activation),
        'use_bias': self.use_bias,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'allow_kernel_norm': self.allow_kernel_norm,
        'kernel_norm_trainable': self.kernel_norm_trainable,
        'partitioner': self.partitioner,
    }
    base_config = super(Dense, self).get_config()
    # Entries of `config` override same-named entries of the base config.
    return dict(list(base_config.items()) + list(config.items()))

  def fprop(self, inputs, **kwargs):
    """Applies the layer by calling `self.call(inputs)`; kwargs are ignored."""
    return self.call(inputs)

================================================
FILE: monolith/core/dense_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Dense layer."""

from __future__ import absolute_import, division, print_function

import textwrap

import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes, ops
from tensorflow.python.keras import backend

from monolith.core import testing_utils
from monolith.core.dense import Dense


class DenseTest(tf.test.TestCase):
  """Tests construction, dtype handling and execution of the Dense layer."""

  def test_dense_instantiate(self):
    # Each case uses its own layer name and a different input rank.
    dense_layer_template = Dense.params()

    test_params0 = dense_layer_template.copy()
    test_params0.name = 'test_dense0'
    test_params0.units = 3
    testing_utils.layer_test(Dense,
                             kwargs={'params': test_params0},
                             input_shape=(3, 2))

    test_params1 = dense_layer_template.copy()
    test_params1.name = 'test_dense1'
    test_params1.units = 3
    testing_utils.layer_test(Dense,
                             kwargs={'params': test_params1},
                             input_shape=(3, 4, 2))

    test_params2 = dense_layer_template.copy()
    test_params2.name = 'test_dense2'
    test_params2.units = 3
    testing_utils.layer_test(Dense,
                             kwargs={'params': test_params2},
                             input_shape=(None, None, 2))

    test_params3 = dense_layer_template.copy()
    test_params3.name = 'test_dense3'
    test_params3.units = 3
    testing_utils.layer_test(Dense,
                             kwargs={'params': test_params3},
                             input_shape=(3, 4, 5, 2))

  def test_dense_dtype(self):
    # Integer inputs fed to a float32 layer must produce float32 outputs.
    dense_layer_template = Dense.params()
    test_params0 = dense_layer_template.copy()
    test_params0.name = 'test_dense0'
    test_params0.units = 3
    inputs = ops.convert_to_tensor_v2(
        np.random.randint(low=0, high=7, size=(2, 2)))
    layer = Dense(test_params0, dtype='float32')
    outputs = layer(inputs)
    self.assertEqual(outputs.dtype, 'float32')

  def test_dense(self):
    dense_layer_template = Dense.params()
    test_params0 = dense_layer_template.copy()
    test_params0.name = 'test_dense0'
    test_params0.units = 3
    layer = Dense(test_params0)
    output = layer(keras.backend.variable(np.ones((2, 4))))
    self.assertAllEqual((2, 3), output.shape)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(output)

  def test_dense_with_partitioner(self):
    param = Dense.params()
    param.name = "test_dense_with_partitioner"
    param.units = 5
    param.partitioner = tf.compat.v1.variable_axis_size_partitioner(1024)
    layer = Dense(param)
    output = layer(keras.backend.variable(np.ones((2, 4096))))
    self.assertAllEqual((2, 5), output.shape)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(output)


if __name__ == '__main__':
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

================================================
FILE: monolith/core/feature.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sail like API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
import tensorflow as tf

from collections import namedtuple


class FeatureSlice(object):
  """Sail-like FeatureSlice implementation."""

  def __init__(
      self,
      feature_slot,
      dim,
      slice_index,
      optimizer=None,
      initializer=None,
      learning_rate_fn=None,
  ):
    """Initialize a FeatureSlice object.

    Args:
      feature_slot: A FeatureSlot object that this FeatureSlice object belongs to.
      dim: The dim of the FeatureSlice.
      slice_index: The index of this FeatureSlice in the FeatureSlot list of FeatureSlice.
      optimizer: TensorFlow optimization parameters (e.g., tf.compat.v1.tpu.experimental.FtrlParameters).
      initializer: TensorFlow initializer (e.g., tf.random_uniform_initializer).
      learning_rate_fn: A function that changes the learning rate over the training period
        (e.g., tf.compat.v1.train.polynomial_decay).
    """
    self._feature_slot = feature_slot
    self._dim = dim
    self._slice_index = slice_index
    self._optimizer = optimizer
    self._initializer = initializer
    self._learning_rate_fn = learning_rate_fn

  def __repr__(self):
    """Returns '[FeatureSlice][slot_<slot_id>][<slice_index>]'."""
    return '[FeatureSlice][slot_{}][{}]'.format(self._feature_slot.slot_id(),
                                                self._slice_index)

  def __hash__(self):
    """The default __hash__() method is overwritten to enable FeatureSlice as dict key."""
    # __eq__ is not overridden, so two slices with the same
    # (slot_id, slice_index) hash equally but still compare by identity.
    return hash((self._feature_slot.slot_id(), self._slice_index))

  @property
  def dim(self):
    return self._dim

  @property
  def slice_index(self):
    return self._slice_index

  @property
  def optimizer(self):
    return self._optimizer

  @property
  def initializer(self):
    return self._initializer

  @property
  def learning_rate_fn(self):
    return self._learning_rate_fn


class FeatureSlot(object):
  """Sail like FeatureSlot implementation."""

  # The optimizer and initializer defaults below are created once, when the
  # function is defined, and are shared by every FeatureSlot that does not
  # override them.
  def __init__(
      self,
      env,
      slot_id,
      has_bias=False,
      bias_optimizer=tf.compat.v1.tpu.experimental.FtrlParameters(
          learning_rate=0.01),
      bias_initializer=tf.zeros_initializer(),
      bias_learning_rate_fn=None,
      default_vec_optimizer=tf.compat.v1.tpu.experimental.AdagradParameters(
          learning_rate=0.01),
      default_vec_initializer=tf.random_uniform_initializer(minval=-0.001,
                                                            maxval=0.001),
      default_vec_learning_rate_fn=None,
      occurrence_threshold=None,
      expire_time=None,
  ):
    """Initialize a FeatureSlot object.

    Registers the slot with `env` and, when `has_bias` is set, creates the
    bias FeatureSlice (dim 1) as slice 0.

    Args:
      env: An Env object that this FeatureSlot belongs to.
      slot_id: The slot_id of the FeatureSlot.
      has_bias: A Boolean on whether this FeatureSlot has a bias FeatureSlice.
      bias_optimizer: TensorFlow optimization parameters for bias slice
        (e.g., tf.compat.v1.tpu.experimental.FtrlParameters).
      bias_initializer: TensorFlow initializer for bias slice (e.g., tf.random_uniform_initializer).
      bias_learning_rate_fn: A function that changes the learning rate of the bias over the
        training period (e.g., tf.compat.v1.train.polynomial_decay).
      default_vec_optimizer: The default TensorFlow optimization parameters for vector slices.
      default_vec_initializer: The default TensorFlow initializer for vector slices.
      default_vec_learning_rate_fn: The default TensorFlow learning rate function for vector slices.
      occurrence_threshold: The number of occurrences that an FID of the slot_id must occur
        in order to be recorded into the training data.
      expire_time: FID may be evicted from the model if not been updated for expire_time days.
    """
    self._env = env
    self._slot_id = slot_id
    self._has_bias = has_bias
    self._bias_optimizer = bias_optimizer
    self._bias_initializer = bias_initializer
    self._bias_learning_rate_fn = bias_learning_rate_fn
    self._default_vec_optimizer = default_vec_optimizer
    self._default_vec_initializer = default_vec_initializer
    self._default_vec_learning_rate_fn = default_vec_learning_rate_fn
    self._occurrence_threshold = occurrence_threshold
    self._expire_time = expire_time

    self._feature_slices = []
    self._merged_feature_slices = []
    self._feature_columns = []

    self._env.set_feature_slot(slot_id, self)

    if self._has_bias:
      feature_slice = FeatureSlice(
          feature_slot=self,
          dim=1,
          slice_index=0,
          optimizer=self._bias_optimizer,
          initializer=self._bias_initializer,
          learning_rate_fn=self._bias_learning_rate_fn,
      )
      self._feature_slices.append(feature_slice)

  def get_env(self):
    return self._env

  def slot_id(self):
    return self._slot_id

  def has_bias(self):
    return self._has_bias

  def add_feature_slice(
      self,
      dim,
      optimizer=None,
      initializer=None,
      learning_rate_fn=None,
  ):
    """Create a FeatureSlice and add to _feature_slices.

    Arguments left as None fall back to the slot's default_vec_* settings.

    Returns:
      The new FeatureSlice, whose slice_index is its position in
      _feature_slices.
    """
    optimizer = optimizer if optimizer is not None else self._default_vec_optimizer
    initializer = initializer if initializer is not None else self._default_vec_initializer
    learning_rate_fn = learning_rate_fn if learning_rate_fn is not None else self._default_vec_learning_rate_fn
    feature_slice = FeatureSlice(
        feature_slot=self,
        dim=dim,
        slice_index=len(self._feature_slices),
        optimizer=optimizer,
        initializer=initializer,
        learning_rate_fn=learning_rate_fn,
    )
    self._feature_slices.append(feature_slice)
    return feature_slice

  def add_merged_feature_slice(
      self,
      dim,
      optimizer=None,
      initializer=None,
      learning_rate_fn=None,
  ):
    """Create a FeatureSlice for merged embedding and add to _merged_feature_slices.

    Arguments left as None fall back to the slot's default_vec_* settings.

    Returns:
      The new FeatureSlice, whose slice_index is its position in
      _merged_feature_slices.
    """
    optimizer = optimizer if optimizer is not None else self._default_vec_optimizer
    initializer = initializer if initializer is not None else self._default_vec_initializer
    learning_rate_fn = learning_rate_fn if learning_rate_fn is not None else self._default_vec_learning_rate_fn
    feature_slice = FeatureSlice(
        feature_slot=self,
        dim=dim,
        slice_index=len(self._merged_feature_slices),
        optimizer=optimizer,
        initializer=initializer,
        learning_rate_fn=learning_rate_fn,
    )
    self._merged_feature_slices.append(feature_slice)
    return feature_slice

  def _add_feature_column(self, feature_column):
    """Add a FeatureColumn object to the FeatureSlot."""
    self._feature_columns.append(feature_column)
    # If the FeatureSlot has bias, add the bias FeatureSlice to all
    # FeatureColumn objects
    # NOTE(review): the loop variable shadows the parameter, and the bias
    # lookup is repeated for every previously added column on each call.
    # Whether the repeated lookups are harmless depends on
    # Env._embedding_lookup, which is not visible here - confirm.
    if self._has_bias:
      for feature_column in self._feature_columns:
        feature_column._bias = feature_column.embedding_lookup(
            self._feature_slices[0])

  @property
  def bias_optimizer(self):
    return self._bias_optimizer

  @property
  def bias_initializer(self):
    return self._bias_initializer

  @property
  def bias_learning_rate_fn(self):
    return self._bias_learning_rate_fn

  @property
  def default_vec_optimizer(self):
    return self._default_vec_optimizer

  @property
  def default_vec_initializer(self):
    return self._default_vec_initializer

  @property
  def default_vec_learning_rate_fn(self):
    return self._default_vec_learning_rate_fn

  @property
  def feature_slices(self):
    return self._feature_slices

  @property
  def merged_feature_slices(self):
    return self._merged_feature_slices

  @property
  def feature_columns(self):
    return self._feature_columns


class FeatureColumnV1(object):
  """Sail like class FeatureColumnV1 implementation."""

  def __init__(self, feature_slot, fc_name):
    """Initialize a FeatureColumnV1 object and register it with its slot.

    Args:
      feature_slot: A FeatureSlot object that this FeatureColumnV1 belongs to.
      fc_name: The name of the feature column (e.g, "slot_1_0").
    """
    self._feature_slot = feature_slot
    self._fc_name = fc_name
    self._feature_slice_to_tf_placeholder = {}
    self._merged_feature_slice_to_tf_placeholder = {}
    self._bias = None

    # Registration assigns self._bias when the slot has a bias slice.
    self._feature_slot._add_feature_column(self)

  def get_env(self):
    return self._feature_slot.get_env()

  def embedding_lookup(
      self,
      feature_slice,
      init_minval_for_oov=None,
      init_maxval_for_oov=None,
  ):
    """Delegates the lookup of `feature_slice` for this column to the Env."""
    return self.get_env()._embedding_lookup(self, feature_slice,
                                            init_minval_for_oov,
                                            init_maxval_for_oov)

  def get_bias(self):
    """Returns the bias lookup; asserts that one has been set."""
    assert self._bias is not None
    return self._bias

  @property
  def feature_slot(self):
    return self._feature_slot

  @property
  def fc_name(self):
    return self._fc_name

  @property
  def feature_slice_to_tf_placeholder(self):
    # Returns the merged mapping when the Env merges vectors, otherwise the
    # per-slice mapping. Only valid once the Env is finalized.
    env = self.get_env()
    assert env.is_finalized, "is_finalized must be true which means \
      _feature_slice_to_tf_placeholder must be initialized before using this \
      _feature_slice_to_tf_placeholder"

    if env._merge_vector:
      return self._merged_feature_slice_to_tf_placeholder
    else:
      return self._feature_slice_to_tf_placeholder


class FeatureColumn3D(object):
  """Sail like class FeatureColumn3D implementation."""

  def __init__(self, feature_slot, max_seq_length, fc_name):
    self._feature_slot = feature_slot
    self._fc_name = fc_name
    self._feature_slice_to_tf_placeholder = {}
    self._bias = None
    logging.info("max_seq_length {}".format(max_seq_length))
    self._max_seq_length = max_seq_length

    self._feature_slot._add_feature_column(self)

  def get_env(self):
    return
class Env(object):
  """Environment which holds important information and tracks the embedding tables.

  The Env maps slot ids to FeatureSlot objects and serves embedding lookups for
  feature columns. While `_tpu_features` is unset (or empty), lookups return
  cached `tf.compat.v1.placeholder` tensors stored on the feature column; once
  features are set, lookups read the tensors from the features dict, keyed by
  "<fc_name>_<slice_index>".
  """

  def __init__(self, vocab_size_dict, params):
    """Creates the Env.

    Args:
      vocab_size_dict: Mapping indexed by slot id; its values are compared
        against the QR hashing threshold in `_embedding_lookup`.
      params: Object exposing `qr_multi_hashing`, `qr_hashing_threshold`,
        `qr_collision_rate`, `use_random_init_embedding_for_oov` and
        `merge_vector` attributes (see `set_params`).
    """
    self._vocab_size_dict = vocab_size_dict
    self._slot_id_to_feature_slot = {}  # {1: FeatureSlot}
    self._tpu_features = None
    self._is_finalized = False
    self.set_params(params)

  def set_tpu_features(self, tpu_features):
    """Stores the features dict and, when vectors are merged, splits them back per slice."""
    self._tpu_features = tpu_features
    if self._merge_vector:
      for feature_slot in self._slot_id_to_feature_slot.values():
        # Split the merged embeddings
        self._split_merged_embedding(feature_slot)

  def set_feature_slot(self, slot_id, feature_slot):
    """Registers `feature_slot` under `slot_id`; a no-op once the env is finalized.

    Raises:
      AssertionError: If `slot_id` is already registered (before finalization).
    """
    # Feature slots are only registered during the first round of calling
    # logits, in which the env is not finalized yet.
    if self._is_finalized:
      return
    assert slot_id not in self._slot_id_to_feature_slot, "Feature slot with id: {} can not be created more than once.".format(
        slot_id)
    self._slot_id_to_feature_slot[slot_id] = feature_slot

  def set_params(self, params):
    """Copies the hashing / OOV / merge settings from `params` onto the env."""
    self._QR_multi_hashing = params.qr_multi_hashing
    self._QR_hashing_threshold = params.qr_hashing_threshold
    self._QR_collision_rate = params.qr_collision_rate
    self._use_random_init_embedding_for_oov = params.use_random_init_embedding_for_oov
    self._merge_vector = params.merge_vector

  def is_finalized(self):
    """Returns True once `finalize()` has been called.

    NOTE: this is a method, not a property. `env.is_finalized` without the call
    parentheses evaluates to the bound method, which is always truthy.
    """
    # Fix: the original returned `self.is_finalized` (the bound method itself)
    # instead of the `_is_finalized` flag.
    return self._is_finalized

  def _embedding_lookup(self,
                        feature_column,
                        feature_slice,
                        init_minval_for_oov=None,
                        init_maxval_for_oov=None):
    """Returns the embedding tensor of `feature_slice` for `feature_column`.

    Args:
      feature_column: Column whose `fc_name` is used to build the feature key.
      feature_slice: Slice to look up; must belong to the same slot as the column.
      init_minval_for_oov: Lower bound of the uniform replacement values used
        for near-zero rows; only used when random OOV init is enabled.
      init_maxval_for_oov: Upper bound of the uniform replacement values.

    Returns:
      With features set: the feature tensor (or the sum of the remainder and
      quotient features when QR hashing applies to the slot). Otherwise a
      placeholder of shape [None, feature_slice.dim], cached on the column.
    """
    # Fix: `slot_id` was used below without ever being defined, raising
    # NameError as soon as QR multi hashing was enabled.
    slot_id = feature_column._feature_slot.slot_id()
    assert slot_id == feature_slice._feature_slot.slot_id()

    if self._tpu_features:
      logging.vlog(2, "__embedding_loopup with features exist.")
      if self._QR_multi_hashing and self._vocab_size_dict[
          slot_id] > self._QR_hashing_threshold:
        logging.vlog(
            2, "__embedding_lookup of QR hashing for slot {}.".format(slot_id))
        # taking quotient feature and remainder feature
        keyR = "{}_{}_0".format(feature_column.fc_name,
                                feature_slice.slice_index)
        keyQ = "{}_{}_1".format(feature_column.fc_name,
                                feature_slice.slice_index)
        assert keyR in self._tpu_features and keyQ in self._tpu_features, \
          "keyR: {} or keyQ: {} not in tpu features, probably need to check core.base_embedding_task._post_process_example()".format(keyR, keyQ)
        # combining quotient feature and remainder feature
        # element-wise addition performs better than element-wise multiplication
        return self._tpu_features[keyR] + self._tpu_features[keyQ]
      else:
        key = "{}_{}".format(feature_column.fc_name, feature_slice.slice_index)
        if not self._use_random_init_embedding_for_oov or init_minval_for_oov is None:
          return self._tpu_features[key]
        # Rows whose norm is (almost) zero are replaced by uniform random values.
        norm = tf.norm(self._tpu_features[key], axis=1)
        random = tf.random.uniform(
            tf.shape(self._tpu_features[key]),
            minval=init_minval_for_oov,
            maxval=init_maxval_for_oov,
        )
        cond = tf.expand_dims(tf.less(norm, 1e-10), -1)
        return tf.where(cond, random, self._tpu_features[key])
    else:
      logging.vlog(2, "__embedding_lookup with no features exist.")
      if feature_slice not in feature_column._feature_slice_to_tf_placeholder:
        feature_column._feature_slice_to_tf_placeholder[
            feature_slice] = tf.compat.v1.placeholder(tf.float32,
                                                      [None, feature_slice.dim])
      return feature_column._feature_slice_to_tf_placeholder[feature_slice]

  def _seq_embedding_lookup(self,
                            feature_column,
                            feature_slice,
                            max_seq_length,
                            init_minval_for_oov=None,
                            init_maxval_for_oov=None):
    """Sequence variant of `_embedding_lookup`.

    Without features it returns a placeholder of shape
    [None, max_seq_length, feature_slice.dim], cached on the column.
    """
    assert feature_column._feature_slot.slot_id(
    ) == feature_slice._feature_slot.slot_id()
    if self._tpu_features:
      logging.vlog(2, "__embedding_loopup with features exist.")
      key = "{}_{}".format(feature_column.fc_name, feature_slice.slice_index)
      if not self._use_random_init_embedding_for_oov or init_minval_for_oov is None:
        return self._tpu_features[key]
      # NOTE(review): the norm is taken over axis=1, as in the 2-D lookup. If
      # the feature here is [batch, seq, dim], that axis is the sequence
      # axis -- confirm this is intended.
      norm = tf.norm(self._tpu_features[key], axis=1)
      # Fix: the bounds come from this call's arguments. The original read
      # `feature_slice.init_minval_for_oov` / `init_maxval_for_oov`, which
      # FeatureSlice does not define.
      random = tf.random.uniform(
          tf.shape(self._tpu_features[key]),
          minval=init_minval_for_oov,
          maxval=init_maxval_for_oov,
      )
      cond = tf.expand_dims(tf.less(norm, 1e-10), -1)
      return tf.where(cond, random, self._tpu_features[key])
    else:
      logging.vlog(2, "__embedding_lookup with no features exist.")
      if feature_slice not in feature_column._feature_slice_to_tf_placeholder:
        feature_column._feature_slice_to_tf_placeholder[
            feature_slice] = tf.compat.v1.placeholder(
                tf.float32, [None, max_seq_length, feature_slice.dim])
      return feature_column._feature_slice_to_tf_placeholder[feature_slice]

  def _size_tensor_lookup(self, feature_column):
    """Returns the size tensor of a sequence feature column.

    With features set, the "<fc_name>_0_row_lengths" feature is converted to an
    int32 [B, max_seq_length] mask; otherwise a float32 placeholder named
    "<fc_name>_size" is returned.
    """
    # NOTE(review): the two branches return different dtypes (int32 vs
    # float32) -- confirm callers tolerate both.
    if self._tpu_features:
      key = "{}_0_row_lengths".format(feature_column.fc_name)
      row_lengths = self._tpu_features[key]
      # Convert row_lengths to [B, max_seq_length] Tensor, in which
      # the first row_length elements of each row are 1, and the rest are
      # 0. This is used as the size_tensor
      batch_size = tf.size(row_lengths)  # 0-D Tensor
      boolean_mask = tf.less(
          tf.reshape(
              tf.tile(tf.range(0, feature_column.max_seq_length), [batch_size]),
              [batch_size, -1],
          ), tf.expand_dims(row_lengths, 1))  # [B, max_seq_length] Tensor
      return tf.cast(boolean_mask, tf.int32)
    else:
      return tf.compat.v1.placeholder(
          tf.float32,
          [None, feature_column.max_seq_length],
          '{}_size'.format(feature_column.fc_name),
      )

  def finalize(self):
    """Finalize the env after slot to dims dict has been initialized.

    Raises:
      AssertionError: If the env was already finalized.
    """
    assert not self._is_finalized, "Env can't be finalized more than once"
    self._is_finalized = True
    if self._merge_vector:
      self._merge_vector_in_same_slot()

  def _split_merged_embedding(self, feature_slot):
    """Split merged embedding into embedding splits of corresponding dim.

    Currently, this assumes all vector FeatureSlice embeddings are shared among
    all the FeatureColumns, so all vector FeatureSlice embeddings can be merged
    into one embedding.

    The splits are written back into the features dict under
    "<fc_name>_<slice_index>", which may overwrite the merged entry itself.
    """
    # Iterate through feature columns
    for feature_column in feature_slot.feature_columns:
      merged_embedding = None
      for merged_feature_slice in feature_column._merged_feature_slice_to_tf_placeholder:
        if merged_feature_slice.slice_index == 0 and feature_slot.has_bias():
          # For bias, keep it as bias will not be merged. Nothing to do.
          assert merged_feature_slice.dim == 1, "Bias in {} must have dim equal to 1, but actual dim is {}.".format(
              feature_column.fc_name, merged_feature_slice.dim)
        else:
          merged_feature_name = "{}_{}".format(feature_column.fc_name,
                                               merged_feature_slice.slice_index)
          merged_embedding = self._tpu_features[merged_feature_name]
          # del self._tpu_features[merged_feature_name]

      if merged_embedding is not None:
        # Split embeddings will be written to the position starting from the
        # previous merged embedding position.
        dim_splits = [
            feature_slice.dim for feature_slice in feature_slot.feature_slices
        ]
        if feature_slot.has_bias():
          # The bias (index 0) is not part of the merged embedding.
          dim_splits = dim_splits[1:]
        embedding_splits = tf.split(merged_embedding, dim_splits, axis=1)

      for feature_slice in feature_column._feature_slice_to_tf_placeholder:
        if feature_slice.slice_index == 0 and feature_slot.has_bias():
          # For bias, keep it as bias will not be merged. Nothing to do.
          assert feature_slice.dim == 1, "Bias in {} must have dim equal to 1, but actual dim is {}.".format(
              feature_column.fc_name, feature_slice.dim)
        else:
          if merged_embedding is not None:
            if feature_slot.has_bias():
              split_index = feature_slice.slice_index - 1
            else:
              split_index = feature_slice.slice_index
            split = embedding_splits[split_index]
            self._tpu_features["{}_{}".format(
                feature_column.fc_name, feature_slice.slice_index)] = split

  def _merge_vector_in_same_slot(self):
    """Merge vectors in the same slot.

    Currently, this assumes all vector FeatureSlice embeddings are shared among
    all the FeatureColumns, so all vector FeatureSlice embeddings can be merged
    into one embedding.
    """
    # TODO (long): Support the case where only a subset of FeatureSlices are
    # shared, so some vector FeatureSlices cannot be merged
    for slot_id, feature_slot in self._slot_id_to_feature_slot.items():
      merged_vector_dim = 0
      # Iterate through all FeatureSlices in FeatureSlot to calculate the merged
      # embedding dimension
      for feature_slice in feature_slot.feature_slices:
        # Bias will not be merged
        if feature_slot.has_bias() and feature_slice.slice_index == 0:
          assert feature_slice.dim == 1, "Bias in {} must have dim equal to 1, but actual dim is {}.".format(
              slot_id, feature_slice.dim)
          feature_slot._merged_feature_slices.append(feature_slice)
          for feature_column in feature_slot.feature_columns:
            feature_column._merged_feature_slice_to_tf_placeholder[
                feature_slice] = feature_column._feature_slice_to_tf_placeholder[
                    feature_slice]
        else:
          merged_vector_dim += feature_slice.dim

      # Created a merged slice whose dim is the sum of the dims of all vector
      # FeatureSlices
      if merged_vector_dim > 0:
        # Create the merged FeatureSlice with the merged embedding dimension
        merged_feature_slice = feature_slot.add_merged_feature_slice(
            merged_vector_dim)
        # Add the merged FeatureSlice and its corresponding tf.placeholder to
        # each FeatureColumn in the FeatureSlot
        for feature_column in feature_slot.feature_columns:
          feature_column._merged_feature_slice_to_tf_placeholder[
              merged_feature_slice] = tf.compat.v1.placeholder(
                  tf.float32, [None, merged_feature_slice.dim])

  @property
  def vocab_size_dict(self):
    return self._vocab_size_dict

  @property
  def slot_id_to_feature_slot(self):
    """The slot id -> FeatureSlot mapping; only available after `finalize()`."""
    # Fix: the original asserted on the bound method `self.is_finalized`, which
    # is always truthy, so the check could never fire.
    assert self._is_finalized, (
        "is_finalized must be true which means _slot_id_to_feature_slot "
        "must be initialized before using this _slot_id_to_feature_slot")
    return self._slot_id_to_feature_slot

  @property
  def features(self):
    return self._tpu_features
class FeatureColumnV1Test(tf.test.TestCase):
  """Tests attaching FeatureColumnV1 objects and the merge / split of vector slices."""

  def _new_env(self, merge_vector):
    """Builds an Env with the standard test params and the given `merge_vector` flag."""
    params = _params.Params()
    for name, value in (
        ('qr_multi_hashing', False),
        ('qr_hashing_threshold', 100000000),
        ('qr_collision_rate', 4),
        ('use_random_init_embedding_for_oov', False),
        ('merge_vector', merge_vector),
    ):
      params.define(name, value, '')
    return Env({}, params)

  def test_add_feature_column(self):
    env = self._new_env(merge_vector=False)
    slot = FeatureSlot(env=env, slot_id=1, has_bias=True)
    slot.add_feature_slice(dim=10)
    column = FeatureColumnV1(slot, 'fc_name_1')
    self.assertEqual(len(slot._feature_columns), 1)

  def test_merge_split_vector_in_same_slot(self):
    env = self._new_env(merge_vector=True)

    # Test merge logic.
    slot_1 = FeatureSlot(env=env, slot_id=1, has_bias=True)
    slices_1 = [slot_1.add_feature_slice(dim=2)]
    slot_2 = FeatureSlot(env=env, slot_id=2, has_bias=True)
    slices_2 = []
    slot_3 = FeatureSlot(env=env, slot_id=3, has_bias=False)
    slices_3 = [slot_3.add_feature_slice(dim=2)]
    slices_3.append(slot_3.add_feature_slice(dim=3))
    slot_4 = FeatureSlot(env=env, slot_id=4, has_bias=True)
    slices_4 = [slot_4.add_feature_slice(dim=2)]
    slices_4.append(slot_4.add_feature_slice(dim=3))
    slices_4.append(slot_4.add_feature_slice(dim=4))

    # Each column looks up every vector slice of its slot, in slice order.
    columns = {}
    for fc_name, slot, vec_slices in (
        ('fc_name_1', slot_1, slices_1),
        ('fc_name_2', slot_2, slices_2),
        ('fc_name_3', slot_3, slices_3),
        ('fc_name_4', slot_4, slices_4),
        ('fc_name_5', slot_4, slices_4),
    ):
      column = FeatureColumnV1(slot, fc_name)
      for vec_slice in vec_slices:
        column.embedding_lookup(vec_slice)
      columns[fc_name] = column

    env._merge_vector_in_same_slot()

    expected_merged_dims = (
        (slot_1, (1, 2)),
        (slot_2, (1,)),
        (slot_3, (5,)),
        (slot_4, (1, 9)),
    )

    # Check the length of merged feature slices in FeatureSlot
    for slot, dims in expected_merged_dims:
      self.assertEqual(len(slot._merged_feature_slices), len(dims))

    # Check the dim of each merged feature slice in FeatureSlot
    for slot, dims in expected_merged_dims:
      for index, dim in enumerate(dims):
        self.assertEqual(slot._merged_feature_slices[index].dim, dim)

    # Check that each merged feature slice is registered in the FeatureColumn
    for slot, fc_name, merged_count in (
        (slot_1, 'fc_name_1', 2),
        (slot_2, 'fc_name_2', 1),
        (slot_3, 'fc_name_3', 1),
        (slot_4, 'fc_name_4', 2),
        (slot_4, 'fc_name_5', 2),
    ):
      placeholders = columns[fc_name]._merged_feature_slice_to_tf_placeholder
      for index in range(merged_count):
        self.assertTrue(slot._merged_feature_slices[index] in placeholders)

    # Test split logic
    merged_row = [[13, 14, 15, 16, 17, 18, 19, 20, 21]]
    env._tpu_features = {}
    env._tpu_features["fc_name_1_0"] = tf.constant([[1]])
    env._tpu_features["fc_name_1_1"] = tf.constant([[2, 3]])
    env._tpu_features["fc_name_2_0"] = tf.constant([[4]])
    env._tpu_features["fc_name_3_0"] = tf.constant([[7, 8, 9, 10, 11]])
    env._tpu_features["fc_name_4_0"] = tf.constant([[12]])
    env._tpu_features["fc_name_4_1"] = tf.constant(merged_row)
    env._tpu_features["fc_name_5_0"] = tf.constant([[12]])
    env._tpu_features["fc_name_5_1"] = tf.constant(merged_row)

    with tf.compat.v1.Session() as sess:
      for slot in (slot_1, slot_2, slot_3, slot_4):
        env._split_merged_embedding(slot)
      features = sess.run(env._tpu_features)

      for key, expected in (
          ("fc_name_1_0", [[1]]),
          ("fc_name_1_1", [[2, 3]]),
          ("fc_name_2_0", [[4]]),
          ("fc_name_3_0", [[7, 8]]),
          ("fc_name_3_1", [[9, 10, 11]]),
          ("fc_name_4_0", [[12]]),
          ("fc_name_4_1", [[13, 14]]),
          ("fc_name_4_2", [[15, 16, 17]]),
          ("fc_name_4_3", [[18, 19, 20, 21]]),
          ("fc_name_5_0", [[12]]),
          ("fc_name_5_1", [[13, 14]]),
          ("fc_name_5_2", [[15, 16, 17]]),
          ("fc_name_5_3", [[18, 19, 20, 21]]),
      ):
        self.assertAllEqual(features[key], expected)
class HostCall():
  """Collects per-core summary tensors and builds the host call that writes them.

  Tensors are registered with `record_summary_tensor`, grouped by dtype and
  concatenated by `compress_tensors`, and split back on the host by
  `decompress_tensors`. The first registered tensor is always the global step.
  """

  def __init__(self, output_dir, enable_host_call, enable_deepinsight):
    """Creates the HostCall.

    Args:
      output_dir: Directory under which the "/host_call" summaries are written.
      enable_host_call: When equal to True, `generate_host_call_hook` returns a
        host call; otherwise it returns None.
      enable_deepinsight: When it is True, labels / predictions / sample rates /
        request times are additionally serialized as text summaries.
    """
    self._output_dir = output_dir
    self._enable_host_call = enable_host_call
    self._enable_deepinsight = enable_deepinsight
    self._tensor_names = ["global_step"]

    gs = tf.compat.v1.train.get_global_step()
    # Creating batch dimension since host call needs to concat all the cores'
    # results.
    self._tensors = [tf.reshape(gs, [-1])]
    # compressed_tensor_list is a list of compressed tensors, in
    # which each compressed tensor is the concatenation of multiple
    # uncompressed tensors. To decompress, we need to store the original
    # sizes of all uncompressed tensors, organized in list of lists.
    # Each list corresponds to a compressed tensors.
    self._lists_tensor_sizes = []  # A list of lists

  def record_summary_tensor(self, name, tensor):
    """Registers `tensor` under `name`; only rank-0 and rank-1 tensors are accepted.

    Raises:
      AssertionError: If `name` is already registered or the tensor has rank > 1.
    """
    assert name not in self._tensor_names
    # Fix: validate before mutating state. The original appended the name
    # first, so a rejected tensor left `_tensor_names` and `_tensors` with
    # different lengths. The message was also missing a space before "but".
    assert len(tensor.get_shape()) <= 1, "Now we only support tensor with shape (k, ) or () " \
        "but we met tensor with shape: {}".format(tensor.get_shape())
    self._tensor_names.append(name)
    # Creating batch dimension since host call needs to concat all the cores'
    # results.
    reshaped_tensor = tf.reshape(tensor, [-1])
    self._tensors.append(reshaped_tensor)

  def compress_tensors(self):
    """For n tensors with shape (k_i, ) and same data type, concat them on axis=0.

    After concatenation each compressed tensor is stored with shape
    (1, k_0 + k_1 + ... + k_{n-1}). The sizes k_i are recorded in
    `_lists_tensor_sizes` so that `decompress_tensors` can undo the
    concatenation. `_tensor_names` is reordered to match the grouped layout.
    """
    assert len(self._tensor_names) == len(self._tensors), "tensor_names and tensors must have same length, " \
        "tensor_names length: {}, tensors length: {}".format(len(self._tensor_names), len(self._tensors))

    # key is tensor data type, value is a list of tensor names with this data type.
    data_type_to_tensor_names = {}
    # key is tensor data type, value is a list of tensors with this data type.
    data_type_to_tensors = {}
    # Group tensor names and tensors by same data type.
    for tensor_name, tensor in zip(self._tensor_names, self._tensors):
      data_type_to_tensor_names.setdefault(tensor.dtype, []).append(tensor_name)
      data_type_to_tensors.setdefault(tensor.dtype, []).append(tensor)

    # Compress n tensors tensor_0, tensor_1, ... , tensor_{n-1} with shape
    # (k_0, ), (k_1, )... (k_{n-1}, ) of same data
    # type to one tensor with shape (1, k_0 + ... + k_{n-1})
    compressed_tensor_name_list = []
    compressed_tensor_list = []
    for data_type, tensor_list in data_type_to_tensors.items():
      compressed_tensor_name_list.extend(data_type_to_tensor_names[data_type])
      # concat a list of tensors with shapes
      # (k_0, ), (k_1, )... (k_{n-1}, ) to a tensor with shape
      # (k_0 + k_1 + ... + k_{n-1}, )
      # Fix: `shape[0].value` only works while dimensions are TF1 `Dimension`
      # objects; `as_list()` yields the same int (or None) in both behaviors.
      tensor_sizes = [tensor.shape.as_list()[0] for tensor in tensor_list]
      self._lists_tensor_sizes.append(tensor_sizes)
      compressed_tensor = tf.concat(tensor_list, axis=0)
      # expand dimension at 0 to make it have the batch dimension
      # tensor with shape (k_0 + k_1 + ... + k_{n-1}, )
      # => tensor with shape (1, k_0 + k_1 + ... + k_{n-1})
      compressed_tensor = tf.expand_dims(compressed_tensor, axis=0)
      compressed_tensor_list.append(compressed_tensor)
      logging.info(
          "Host call compressed tensors, data type: {}, compressed tensor shape: {}"
          .format(data_type, compressed_tensor.shape))

    self._tensor_names = compressed_tensor_name_list
    self._tensors = compressed_tensor_list

  def decompress_tensors(self, tensors):
    """Decompress the compressed tensors into list of decompressed tensors.

    Given a list of compressed tensors from *args. Each tensor has shape
    (num_cores, k_0 + k_1 + ... + k_{n-1}), in which the second dimension is
    the sum of lengths of uncompressed tensors (k_i, ) from the same core, with
    same data type. Each compressed tensor is split back along axis 1 using the
    sizes recorded by `compress_tensors`, and every split is squeezed.

    Args:
      tensors: The compressed tensors, in the order produced by
        `compress_tensors`.

    Returns:
      A tuple (global_step, decompressed_tensor_list), where global_step is the
      first element of the first decompressed tensor.
    """
    # Need decompress tensors
    decompressed_tensor_list = []
    for index, compressed_tensor in enumerate(tensors):
      # For each tensor, its shape is (num_cores, k_0 + ... + k_{n-1})
      assert len(compressed_tensor.get_shape(
      )) == 2, "Compressed tensors shape must be (n, m), met shape: {}".format(
          compressed_tensor.shape)
      logging.info("Decompressed tensors shape: {}, dtype: {}.".format(
          compressed_tensor.shape, compressed_tensor.dtype))
      # tensor with shape (num_cores, k_0 + ... + k_{n-1})
      # => list of tensors with shape (num_cores, k_i)
      split_tensors = tf.split(compressed_tensor,
                               self._lists_tensor_sizes[index],
                               axis=1)
      for tensor in split_tensors:
        # tf.squeeze only removes size-1 dimensions: (num_cores, 1) becomes
        # (num_cores, ), while (num_cores, k_i) with both sizes > 1 stays 2-D.
        tensor = tf.squeeze(tensor)
        decompressed_tensor_list.append(tensor)

    # Fix: report the whole offending name; the original formatted
    # `self._tensor_names[0][0]`, i.e. only its first character.
    assert self._tensor_names[
        0] == "global_step", "The first tensor name must be global_step, met value: {}".format(
            self._tensor_names[0])
    return decompressed_tensor_list[0][0], decompressed_tensor_list

  def _verify_shape_and_dtype(self, tensor, shape_list, dtype):
    """Asserts that `tensor` is not None and has exactly the given shape and dtype."""
    assert tensor is not None
    assert tensor.shape.as_list(
    ) == shape_list, "Expect shape: {}, but actual shape: {}".format(
        shape_list, tensor.shape.as_list())
    assert tensor.dtype == dtype, "Expect dtype {}, but actual dtype: {}".format(
        dtype, tensor.dtype)

  def _serialize_messages(self, labels, y_preds, sample_rates, req_times, gs):
    """Writes labels / predictions / sample rates / request times as serialized text summaries.

    `labels` must have rank 2; the other three tensors must have the same shape
    as `labels` (float32, float32 and int64 respectively).
    """
    assert labels is not None
    labels_shape = labels.shape.as_list()
    # Fix: the check is for rank 2, but the original message said "1".
    assert len(labels_shape
              ) == 2, "Expect labels to have rank 2, but its shape is {}".format(
                  labels_shape)
    self._verify_shape_and_dtype(y_preds, labels_shape, tf.float32)
    self._verify_shape_and_dtype(sample_rates, labels_shape, tf.float32)
    self._verify_shape_and_dtype(req_times, labels_shape, tf.int64)

    # reshape is low cost without real data copy.
    # flatten the tensor here and simplify the data format before serializing to string.
    # Each tensor has shape (n, ), n equals to core_number * batch_size_per_core
    labels = tf.reshape(labels, [-1])
    y_preds = tf.reshape(y_preds, [-1])
    sample_rates = tf.reshape(sample_rates, [-1])
    req_times = tf.reshape(req_times, [-1])

    # The first two model names and di sample rates can be get from host_call folder suffix
    tf.compat.v1.summary.text(_DEEPINSIGHT_SAMPLE_RATES,
                              data=tf.io.serialize_tensor(sample_rates),
                              step=gs)
    tf.compat.v1.summary.text(_DEEPINSIGHT_LABELS,
                              data=tf.io.serialize_tensor(labels),
                              step=gs)
    tf.compat.v1.summary.text(_DEEPINSIGHT_PREDS,
                              data=tf.io.serialize_tensor(y_preds),
                              step=gs)
    tf.compat.v1.summary.text(_DEEPINSIGHT_REQ_TIMES,
                              data=tf.io.serialize_tensor(req_times),
                              step=gs)

  def generate_host_call_hook(self):
    """Returns `(host_call_fn, tensors)` when host call is enabled, else None.

    The tensors are compressed as a side effect, so this changes
    `_tensor_names` / `_tensors`.
    NOTE(review): calling this more than once would compress the already
    compressed tensors again -- confirm it is only called once per instance.
    """

    def _host_call(*args):
      gs, tensors = self.decompress_tensors(args)
      # NOTE(review): the `data=` / `step=` keywords and `create_file_writer`
      # follow the TF2 summary API; confirm the TF build in use exposes them
      # under `tf.compat.v1.summary`.
      summary_writer = tf.compat.v1.summary.create_file_writer(
          self._output_dir + "/host_call", flush_millis=10000, max_queue=5000)
      with summary_writer.as_default():
        labels = None
        y_preds = None
        req_times = None
        sample_rates = None
        for i, t in enumerate(tensors):
          # Index 0 is the global step, which is not summarized itself.
          if i == 0:
            continue
          name = self._tensor_names[i]
          data = None
          if "_avg" in name:
            data = tf.reduce_mean(t)
          elif "_max" in name:
            data = tf.reduce_max(t)
          elif _LABLES_FOR_AUC_CALCULATION in name:
            labels = t
          elif _Y_PRED_FOR_AUC_CALCULATION in name:
            y_preds = t
          elif _REQ_TIME in name:
            req_times = t
          elif _SAMPLE_RATE in name:
            sample_rates = t
          else:
            data = t[0]
          if data is not None:
            tf.compat.v1.summary.scalar(name, data=data, step=gs)

        if labels is not None and y_preds is not None:
          auc, auc_op = tf.compat.v1.metrics.auc(labels=labels,
                                                 predictions=y_preds)
          tf.compat.v1.summary.scalar("auc", data=auc, step=gs)
        else:
          auc_op = None

        if self._enable_deepinsight is True and labels is not None:
          # _serialize_messages only emits summaries; it has no return value.
          self._serialize_messages(labels, y_preds, sample_rates, req_times, gs)

        if auc_op is not None:
          return tf.group(tf.compat.v1.summary.all_v2_summary_ops(), auc_op)
        else:
          return tf.compat.v1.summary.all_v2_summary_ops()

    if self._enable_host_call == True:
      self.compress_tensors()
      return (_host_call, self._tensors)
    else:
      logging.info("host_call has been disabled")
      return None
Examples:: _is_named_tuple((42, 'hi')) ==> False Foo = collections.namedtuple('Foo', ['a', 'b']) _is_named_tuple(Foo(a=42, b='hi')) ==> True Args: x: The object to check. """ return isinstance(x, tuple) and hasattr(x, '_fields') class _SortedDict(dict): """A dict with a __repr__ that is always sorted by key.""" def __repr__(self): return '{' + ', '.join( '%r: %r' % item for item in sorted(self.items())) + '}' class _Param(object): """Stores data for a single parameter.""" def __init__(self, name, default_value, description): self._name = name self._value = default_value self._description = description def __eq__(self, other): # pylint: disable=protected-access return self._name == other._name and self._value == other._value # Deep copy the value only if it is supported. def __deepcopy__(self, memo): if isinstance(self._value, (tf.Tensor)): # In case self._value is a tensor, let's just make a reference. value = self._value else: value = copy.deepcopy(self._value, memo) p = _Param(self._name, value, self._description) # Q(yonghui): Is this the right use of memo. memo[id(self)] = p return p def to_string(self, nested_depth): """Prints the parameter as a string.""" def GetRepr(val): """Get the representation of `val`.""" if isinstance(val, Params): return _SortedDict({k: GetRepr(v) for k, v in val.iter_params()}) if isinstance(val, dict): return _SortedDict({k: GetRepr(v) for k, v in six.iteritems(val)}) if isinstance(val, (list, tuple)) and not _is_named_tuple(val): # NB: this constructor signature works for tuples, but not namedtuples. return type(val)([GetRepr(v) for v in val]) # NOTE(markmurphy): I introduced Repr() because it's impossible (afaik) to # overwrite the __str__ or __repr__ method of a types.FunctionType object. 
class Params(object):
  """Stores data for a set of parameters.

  Provides attribute-based API, e.g. "params.foo = 5". Parameter data lives in
  an internal {name: _Param} dict; only names that were registered with
  define() can be read or assigned.
  """

  def __init__(self):
    # Write through __dict__ directly: __setattr__ reads _immutable before it
    # does anything else, so the flag must exist before the first assignment.
    self.__dict__['_immutable'] = False
    self._params = {}  # name => _Param

  def __setattr__(self, name, value):
    """Assigns `value` to parameter `name` (or to an internal attribute).

    Raises:
      TypeError: If this instance has been frozen.
      AttributeError: If `name` is not a defined parameter.
    """
    if self._immutable:
      raise TypeError('This Params instance is immutable.')
    if name == '_params' or name == '_immutable':
      self.__dict__[name] = value
    else:
      try:
        self._params[name].set(value)
      except KeyError:
        raise AttributeError(self._key_error_string(name))

  def __getattr__(self, name):
    """Returns the value of parameter `name`.

    Python only calls this after normal attribute lookup has failed.

    Raises:
      AttributeError: If `name` is not a defined parameter, or if the internal
        state has not been set on this instance yet.
    """
    if name == '_params' or name == '_immutable':
      # Getting here means the attribute is absent from __dict__ (e.g. an
      # instance made with __new__ whose state has not been restored yet).
      # Report that as AttributeError rather than leaking KeyError, so that
      # hasattr()/getattr(obj, name, default) behave as documented.
      try:
        return self.__dict__[name]
      except KeyError:
        raise AttributeError(name)
    try:
      return self._params[name].get()
    except KeyError:
      # cPickle expects __getattr__ to raise AttributeError, not KeyError.
      raise AttributeError(self._key_error_string(name))

  def __setitem__(self, name, value):
    self.__setattr__(name, value)

  def __getitem__(self, key):
    return self.__getattr__(key)

  def __dir__(self):
    return sorted(self._params.keys())

  def __contains__(self, name):
    return name in self._params

  def __len__(self):
    return len(self._params)

  # Note: This gets called by Params.__eq__() on nested Params objects.
  def __eq__(self, other):
    return isinstance(other, Params) and self._params == other._params  # pylint: disable=protected-access

  def __ne__(self, other):
    return not self == other

  def __str__(self):
    return self._to_string(0)

  def _to_string(self, nested_depth):
    """Renders all parameters, sorted by name, one per line inside braces."""
    sorted_param_strs = [
        v.to_string(nested_depth + 1)
        for (_, v) in sorted(self._params.items())
    ]
    # NOTE(review): one copy of this literal is emitted per nesting level;
    # confirm its width against the checked-in file, since
    # ParamsTest.test_to_string pins the rendered layout.
    nested_indent = ' ' * nested_depth
    return '{\n%s\n%s}' % ('\n'.join(sorted_param_strs), nested_indent)

  # Override __deepcopy__ so that copy.deepcopy(self._params) properly
  # deep-copies nested Params objects.
  # TODO(sadovsky): Is it okay not to touch memo?
  def __deepcopy__(self, unused_memo):
    return self.copy()

  def _similar_keys(self, name):
    """Return a list of params keys that are similar to name."""

    def _overlaps(name, key):
      """The fraction of 3-char substrings in `name` that appear in `key`."""
      matches = 0
      trials = 0
      # A string of length n has n - 2 substrings of length 3; the upper
      # bound must be len(name) - 2 so the final one is counted as well.
      for i in range(len(name) - 2):
        trials += 1
        if name[i:i + 3] in key:
          matches += 1
      if trials:
        return float(matches) / trials
      return 0

    # Guard against instances whose _params has not been set yet.
    if '_params' in self.__dict__:
      return [key for key in self._params if _overlaps(name, key) > 0.5]
    return []

  def _key_error_string(self, name):
    """Builds the error text for an unknown `name`, with suggestions if any."""
    similar = self._similar_keys(name)
    if similar:
      return name + ' (did you mean: [%s])' % (','.join(sorted(similar)))
    return name

  def copy(self):
    """Returns a deep copy of this object (same type, same frozen state)."""
    return self._copy_to(type(self)())

  def _copy_to(self, res):
    # pylint: disable=protected-access
    # _params must be assigned before _immutable: once the frozen flag has
    # been copied over, __setattr__ on `res` would refuse further writes.
    res._params = copy.deepcopy(self._params)
    res._immutable = self._immutable
    # pylint: enable=protected-access
    return res

  # TODO(sadovsky):
  # - Maybe let users specify whether this parameter is allowed to have
  #   value=None, and if not, assert on get(), like required proto field.
  # - Maybe enforce that value is one of
  #     {number, string, bool, list, dict, Params}.
  def define(self, name, default_value, description):
    """Defines a parameter.

    Args:
      name: The parameter name. Must only contain lowercase letters, numbers,
        and underscores. Must start with lowercase letter.
      default_value: Default value for this parameter. May be None.
      description: String description of this parameter.

    Raises:
      TypeError: If this instance has been frozen.
      AssertionError: If `name` is not a string of the form described above.
      AttributeError: If parameter 'name' is already defined.
    """
    if self._immutable:
      raise TypeError('This Params instance is immutable.')
    assert name is not None and isinstance(name, str) and (re.match(
        '^[a-z][a-z0-9_]*$', name) is not None)
    if name in self._params:
      raise AttributeError('Parameter %s is already defined' % name)
    self._params[name] = _Param(name, default_value, description)

  def contain(self, name):
    """Returns whether a parameter called `name` is defined."""
    return name in self._params

  def freeze(self):
    """Marks this Params as immutable."""
    self._immutable = True

  def is_immutable(self):
    """Return whether this Params is immutable."""
    return self._immutable

  def _get_nested(self, name):
    """Returns nested param by its name.

    Args:
      name: Dotted path such as 'a.b.c'. Every component except the last is
        resolved to a nested Params object; a component written as
        'part[idx]' first looks up `part` and then indexes the result with
        int(idx).

    Returns:
      A (params, key) tuple: the Params object that should own the final path
      component, and that component itself (not looked up here).

    Raises:
      AttributeError: If an intermediate component is not defined.
      AssertionError: If an intermediate component is not a Params object.
    """
    parts = name.split('.')
    curr = self
    for i, part in enumerate(parts[:-1]):
      # get the value (nested Params object) associated with name 'part'.
      try:
        is_list = re.match(r'^(.+)\[(.+)\]$', part)
        if is_list:
          part = is_list.group(1)
          list_index = int(is_list.group(2))
        # pylint: disable=protected-access
        curr = curr._params[part].get()
        if is_list:
          curr = curr[list_index]
      except KeyError:
        raise AttributeError('.'.join(parts[:i + 1]))
      assert isinstance(curr, Params), (
          'Cannot introspect %s for %s' % (type(curr), '.'.join(parts[:i + 1])))
    return curr, parts[-1]

  def set(self, **kwargs):
    """Sets multiple parameters.

    Dots in names indicate navigation into nested Params objects. We do not
    allow navigation into lists or dicts, and may ban these types altogether in
    favor of string representations.

    Args:
      **kwargs: Name-value pairs to set.

    Returns:
      self

    Raises:
      TypeError: If this instance has been frozen.
      AttributeError: If a name does not refer to a defined parameter.
    """
    if self._immutable:
      raise TypeError('This Params instance is immutable: %s' % self)
    for name, value in kwargs.items():
      # get nested param.
      param, key = self._get_nested(name)
      # Update the value associated with key.
      try:
        # pylint: disable=protected-access
        param._params[key].set(value)
      except KeyError:
        raise AttributeError(self._key_error_string(name))
    return self

  def get(self, name):
    """get parameter.

    Dots in names indicate navigation into nested Params objects. We do not
    allow navigation into lists or dicts, and may ban these types altogether in
    favor of string representations.

    Args:
      name: (str) Name.

    Returns:
      value.

    Raises:
      AttributeError: if parameter is not found
    """
    param, key = self._get_nested(name)
    # get the value associated with key.
    try:
      # pylint: disable=protected-access
      return param._params[key].get()
    except KeyError:
      raise AttributeError(self._key_error_string(name))

  def delete(self, *args):
    """Deletes multiple parameters.

    Dots in names indicate navigation into nested Params objects. We do not
    allow navigation into lists or dicts, and may ban these types altogether in
    favor of string representations.

    Args:
      *args: List of names.

    Returns:
      self

    Raises:
      TypeError: If this instance has been frozen.
      AttributeError: If a name does not refer to a defined parameter.
    """
    if self._immutable:
      raise TypeError('This Params instance is immutable.')
    for name in args:
      # get nested param.
      param, key = self._get_nested(name)
      # delete the key.
      try:
        # pylint: disable=protected-access
        del param._params[key]
      except KeyError:
        raise AttributeError(self._key_error_string(name))
    return self

  def iter_params(self):
    """Pythonic dict-like iteration: yields (name, value) pairs."""
    for name, param in self._params.items():
      yield (name, param.get())
def update_params(ips: Params, args):
  """Overwrites leaf parameters of `ips` with matching entries from `args`.

  Nested Params values are descended into recursively. Every other parameter
  whose name is a key of `args` is assigned the corresponding value, and that
  entry is removed from `args`; a name is therefore applied at most once, to
  the first match in iteration order. Keys that match nothing stay in `args`.

  Args:
    ips: Params object to update in place.
    args: Mutable mapping of parameter name -> new value; consumed entries are
      popped from it.
  """
  for key, value in ips.iter_params():
    if isinstance(value, Params):
      update_params(value, args)
      continue
    if key not in args:
      continue
    ips[key] = args.pop(key)
from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import enum import functools import re import sys from absl import app as absl_app from absl import flags from absl import logging import tensorflow as tf import unittest import monolith.core.hyperparams as _params class TestEnum(enum.Enum): """Test enum class.""" A = 1 B = 2 class ParamsTest(unittest.TestCase): def test_equals(self): params1 = _params.Params() params2 = _params.Params() self.assertTrue(params1 == params2) params1.define('first', 'firstvalue', '') self.assertFalse(params1 == params2) params2.define('first', 'firstvalue', '') self.assertTrue(params1 == params2) some_object = object() other_object = object() params1.define('second', some_object, '') params2.define('second', other_object, '') self.assertFalse(params1 == params2) params2.second = some_object self.assertTrue(params1 == params2) params1.define('third', _params.Params(), '') params2.define('third', _params.Params(), '') self.assertTrue(params1 == params2) params1.third.define('fourth', 'x', '') params2.third.define('fourth', 'y', '') self.assertFalse(params1 == params2) params2.third.fourth = 'x' self.assertTrue(params1 == params2) # Comparing params to non-param instances. 
def test_legal_param_names(self):
  """define() rejects malformed names and accepts well-formed ones."""
  params = _params.Params()
  # None, empty, leading underscore, uppercase, leading digit, bad character.
  for bad_name in (None, '', '_foo', 'Foo', '1foo', 'foo$'):
    with self.assertRaises(AssertionError):
      params.define(bad_name, 1, '')
  for good_name in ('foo_bar', 'foo9'):
    params.define(good_name, 1, '')
self.assertRaisesRegex(AttributeError, 'foo', lambda: setattr(p, 'foo', 4)) p.define('foo', 1, '') self.assertEqual(p.foo, 1) self.assertEqual(p.get('foo'), 1) self.assertIn('foo', p) self.assertNotIn('bar', p) p.set(foo=2) self.assertEqual(p.foo, 2) self.assertEqual(p.get('foo'), 2) p.foo = 3 self.assertEqual(p.foo, 3) self.assertEqual(p.get('foo'), 3) p.delete('foo') self.assertNotIn('foo', p) self.assertNotIn('bar', p) self.assertRaisesRegex(AttributeError, 'foo', lambda: p.foo) self.assertRaisesRegex(AttributeError, 'foo', p.get, 'foo') def test_set_and_get_nested_param(self): innermost = _params.Params() innermost.define('delta', 22, '') innermost.define('zeta', 5, '') inner = _params.Params() inner.define('alpha', 2, '') inner.define('innermost', innermost, '') outer = _params.Params() outer.define('beta', 1, '') outer.define('inner', inner, '') outer.define('d', dict(foo='bar'), '') self.assertEqual(inner.alpha, 2) self.assertEqual(outer.beta, 1) self.assertEqual(outer.d['foo'], 'bar') self.assertEqual(outer.inner.alpha, 2) self.assertEqual(outer.inner.innermost.delta, 22) self.assertEqual(outer.inner.innermost.zeta, 5) self.assertEqual(inner.get('alpha'), 2) self.assertEqual(outer.get('beta'), 1) self.assertEqual(outer.get('d')['foo'], 'bar') self.assertEqual(outer.get('inner.alpha'), 2) self.assertEqual(outer.get('inner.innermost.delta'), 22) self.assertEqual(outer.get('inner.innermost.zeta'), 5) outer.set(**{'inner.alpha': 3}) outer.set(d=dict(foo='baq')) outer.delete('beta') outer.delete('inner.innermost.zeta') self.assertEqual(inner.alpha, 3) self.assertRaisesRegex(AttributeError, 'beta', lambda: outer.beta) self.assertEqual(outer.d['foo'], 'baq') self.assertEqual(outer.inner.alpha, 3) self.assertEqual(outer.inner.innermost.delta, 22) self.assertRaisesRegex(AttributeError, 'zeta', lambda: outer.inner.innermost.zeta) self.assertEqual(inner.get('alpha'), 3) self.assertRaisesRegex(AttributeError, 'beta', outer.get, 'beta') 
self.assertEqual(outer.get('d')['foo'], 'baq') self.assertEqual(outer.get('inner.alpha'), 3) self.assertEqual(outer.get('inner.innermost.delta'), 22) self.assertRaisesRegex(AttributeError, 'inner.innermost.zeta', outer.get, 'inner.innermost.zeta') # NOTE(igushev): Finding nested Param object is shared between Get, Set and # Delete, so we test only Set. self.assertRaisesRegex(AttributeError, r'inner\.gamma', lambda: outer.set(**{'inner.gamma': 5})) self.assertRaisesRegex(AttributeError, r'inner\.innermost\.bad', lambda: outer.set(**{'inner.innermost.bad': 5})) self.assertRaisesRegex(AssertionError, '^Cannot introspect', lambda: outer.set(**{'d.foo': 'baz'})) def test_freeze(self): p = _params.Params() self.assertRaises(AssertionError, lambda: p.define('_immutable', 1, '')) self.assertRaisesRegex(AttributeError, 'foo', lambda: p.set(foo=4)) # We use setattr() because lambda cannot contain explicit assignment. self.assertRaisesRegex(AttributeError, 'foo', lambda: setattr(p, 'foo', 4)) p.define('foo', 1, '') p.define('nested', p.copy(), '') self.assertEqual(p.foo, 1) self.assertEqual(p.get('foo'), 1) self.assertEqual(p.nested.foo, 1) p.freeze() self.assertRaises(TypeError, lambda: p.set(foo=2)) self.assertEqual(p.get('foo'), 1) self.assertRaises(TypeError, lambda: setattr(p, 'foo', 3)) self.assertEqual(p.foo, 1) self.assertRaises(TypeError, lambda: p.delete('foo')) self.assertEqual(p.foo, 1) self.assertRaises(TypeError, lambda: p.define('bar', 1, '')) self.assertRaisesRegex(AttributeError, 'bar', p.get, 'bar') p.nested.foo = 2 self.assertEqual(p.foo, 1) self.assertEqual(p.nested.foo, 2) self.assertRaises(TypeError, lambda: setattr(p, '_immutable', False)) # Copies are still immutable. 
def test_similar_keys(self):
  """A misspelled parameter name is reported together with close matches."""
  p = _params.Params()
  p.define('activation', 'RELU', 'Can be a string or a list of strings.')
  p.define('activations', 'RELU', 'Many activations.')
  p.define('cheesecake', None, 'dessert')
  p.define('tofu', None, 'not dessert')

  def set_param():
    p.actuvation = 1

  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
  # assertRaisesRegex is the spelling the other tests in this class use.
  self.assertRaisesRegex(
      AttributeError,
      re.escape('actuvation (did you mean: [activation,activations])'),
      set_param)
class TeacherEmbeddingTransform(Layer):
  """Per-embedding linear transform of concatenated teacher embeddings.

  The 2-D input is treated as a concatenation of teacher embeddings whose
  widths are `teacher_embedding_sizes_list`. Slice i is multiplied by its own
  [teacher_embedding_sizes_list[i], max_choice_per_embedding[i]] matrix, the
  products are concatenated along axis 1 and a bias is added.

  Args:
    max_choice_per_embedding: Sequence of output widths, one per teacher
      embedding.
    teacher_embedding_sizes_list: Sequence of input widths, one per teacher
      embedding. Must have the same length as `max_choice_per_embedding`.
    **kwargs: Forwarded to the `Layer` constructor.

  Input shapes:
    2D tensor whose last dimension equals sum(teacher_embedding_sizes_list)
    (asserted in `build`).

  Output shapes:
    2D tensor whose last dimension is sum(max_choice_per_embedding).
  """

  def __init__(self, max_choice_per_embedding, teacher_embedding_sizes_list,
               **kwargs):
    super(TeacherEmbeddingTransform, self).__init__(**kwargs)
    assert len(max_choice_per_embedding) == len(teacher_embedding_sizes_list)
    # Stored as numpy arrays so build() can multiply them element-wise.
    self._max_choice_per_embedding = np.array(max_choice_per_embedding)
    self._teacher_embedding_sizes_list = np.array(teacher_embedding_sizes_list)
    self.input_spec = InputSpec(ndim=2)

  def build(self, input_shape):
    """Creates the flat transform weight and the bias.

    Args:
      input_shape: Shape of the input; must have length 2 and a last dimension
        equal to the sum of the teacher embedding sizes.
    """
    assert len(input_shape) == 2
    input_dim = input_shape[-1]
    assert input_dim == np.sum(self._teacher_embedding_sizes_list)
    # All per-embedding matrices are packed back to back into one
    # [sum_i(size_i * max_choice_i), 1] weight; call() slices and reshapes it.
    total_teacher_embedding_transform_weight_size = np.sum(
        self._max_choice_per_embedding * self._teacher_embedding_sizes_list)
    # NOTE(review): `initial_value=` does not look like a keyword of the stock
    # tf.keras `Layer.add_weight`, and `_snapshot_for_serving` is not defined
    # in this class -- confirm the Layer base in use provides both.
    self._teacher_embedding_transform_weight = self.add_weight(
        initial_value=tf.keras.initializers.TruncatedNormal(stddev=0.15)(
            [total_teacher_embedding_transform_weight_size, 1], self.dtype),
        name='teacher_embedding_transform_weight')
    self._snapshot_for_serving(self._teacher_embedding_transform_weight,
                               'teacher_embedding_transform_weight')
    self._teacher_embedding_transform_bias = self.add_weight(
        initial_value=tf.keras.initializers.Zeros()(
            [np.sum(self._max_choice_per_embedding)], self.dtype),
        name='teacher_embedding_transform_bias')
    self._snapshot_for_serving(self._teacher_embedding_transform_bias,
                               'teacher_embedding_transform_bias')
    self.built = True

  def call(self, inputs):
    """Applies the per-embedding transforms and adds the bias.

    Args:
      inputs: The concatenated teacher embeddings.

    Returns:
      The transformed slices concatenated along axis 1, plus the bias.
    """
    teacher_embedding = inputs
    # Running offsets into the flat weight and into the input's last axis.
    current_weight_idx = 0
    current_teacher_idx = 0
    teacher_transformed = []
    for i in range(self._max_choice_per_embedding.shape[0]):
      # Columns of the input that belong to teacher embedding i.
      teacher_embedding_slice = teacher_embedding[:, current_teacher_idx:
                                                  current_teacher_idx +
                                                  self._teacher_embedding_sizes_list[i]]
      # size_i * max_choice_i consecutive rows of the flat weight.
      transform_weight_slice = self._teacher_embedding_transform_weight[
          current_weight_idx:current_weight_idx +
          self._teacher_embedding_sizes_list[i] *
          self._max_choice_per_embedding[i]]
      teacher_transformed.append(
          tf.matmul(
              teacher_embedding_slice,
              tf.reshape(transform_weight_slice, [
                  self._teacher_embedding_sizes_list[i],
                  self._max_choice_per_embedding[i]
              ])))
      current_weight_idx += self._teacher_embedding_sizes_list[
          i] * self._max_choice_per_embedding[i]
      current_teacher_idx += self._teacher_embedding_sizes_list[i]
    return tf.concat(teacher_transformed,
                     axis=1) + self._teacher_embedding_transform_bias

  def compute_output_shape(self, input_shape):
    """Not supported; always raises NotImplementedError."""
    raise NotImplementedError("I don't think I need to implement this one.")

  def get_config(self):
    """Returns the base layer config extended with this layer's two arguments.

    Both values are returned as the numpy arrays built in `__init__`.
    """
    config = {
        'max_choice_per_embedding': self._max_choice_per_embedding,
        'teacher_embedding_sizes_list': self._teacher_embedding_sizes_list
    }
    base_config = super(TeacherEmbeddingTransform, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
  """Creates the architecture weights and the per-dimension mask multiplier.

  For every slot this takes the slot's slice of `arch_embedding_weights`,
  applies a softmax, records the entropy and expected embedding size of that
  distribution, samples one size choice with `tf.random.categorical`, and
  appends the 0/1 mask covering that choice's dimension range to the list that
  is finally concatenated into `self._arch_embedding_masks_multipler`.

  Also sets `self._arch_entropy`, `self._expected_emb_dims`,
  `self._expected_zero_embedding_size_weights` and
  `self._arch_embedding_weights_after_softmax_list`.

  Args:
    input_shape: Shape of the layer input. Must have length 2, and the
      embedding input's last dimension must equal `self._total_emb_size`.
  """
  assert len(input_shape) == 2
  if self._teacher_embedding_sizes_list is not None:
    # Two inputs: [embedding, teacher_embedding].
    assert len(input_shape[0]) == 2 and len(input_shape[1]) == 2
    input_dim = input_shape[0][-1]
    assert input_shape[1][-1] == sum(self._teacher_embedding_sizes_list)
  else:
    input_dim = input_shape[-1]
  print(input_dim)
  print(self._total_emb_size)
  assert input_dim == self._total_emb_size
  # One trainable logit per (slot, size choice), stored flat.
  self._arch_embedding_weights = self.add_weight(
      shape=(sum(self._num_choices_per_embedding),),
      initializer=tf.random_uniform_initializer(minval=-1e-3, maxval=1e-3),
      trainable=True,
      name='arch_embedding_weights')
  print("arch embedding weights: {}".format(self._arch_embedding_weights))
  # Offset of the current slot's logits inside the flat weight.
  current_idx = 0
  arch_embedding_masks_list = []
  # Only referenced by the commented-out "method 2" code below.
  arch_embedding_weights_multiplier_list = []
  arch_entropy_list = []
  expected_emb_dims_list = []
  expected_zero_embedding_size_weights_list = []
  arch_embedding_weights_after_softmax_list = []
  for i in range(len(self._slot_names)):
    num_choices = self._num_choices_per_embedding[i]
    max_emb_choice = sum(self._embedding_size_choices_list[i])
    arch_embedding_weights_slice = self._arch_embedding_weights[
        current_idx:current_idx + num_choices]
    arch_embedding_weights_after_softmax = tf.nn.softmax(
        arch_embedding_weights_slice)  #softmax selection
    #arch_embedding_weights_after_softmax = tf.math.sigmoid(arch_embedding_weights_slice) #sigmoid selection like FairNAS
    # Cross-entropy of the softmax distribution against its own logits, i.e.
    # the entropy of this slot's choice distribution.
    arch_entropy_list.append(
        tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
            labels=arch_embedding_weights_after_softmax,
            logits=arch_embedding_weights_slice))
    expected_emb_dims_list.append(
        tf.reduce_sum(arch_embedding_weights_after_softmax *
                      self._embedding_size_choices_list[i]))
    if self._embedding_size_choices_list[i][0] == 0:
      expected_zero_embedding_size_weights_list.append(
          arch_embedding_weights_after_softmax[0])
      # Scale the weight of the size-0 choice by clip(step - warmup, 0, 1), so
      # it is zeroed until the global step exceeds warmup_steps.
      # NOTE(review): get_global_step() is documented to return None when no
      # global step exists, and the integer step is combined with float
      # literals here -- confirm both are fine in the training setup used.
      arch_embedding_weights_after_softmax = tf.concat([
          arch_embedding_weights_after_softmax[0:1] * tf.cast(
              tf.minimum(
                  tf.maximum(
                      tf.compat.v1.train.get_global_step() -
                      self._warmup_steps, 0.0), 1.0), self.dtype),
          arch_embedding_weights_after_softmax[1:]
      ],
                                                       axis=0)
    embedding_masks = []
    # [lower, upper) is the dimension range owned by the current size choice.
    lower = 0
    upper = 0
    for j, embedding_size_choice in enumerate(
        self._embedding_size_choices_list[i]):
      name = 'arch_embedding_weights_after_softmax/{}_{}'.format(
          self._slot_names[i], embedding_size_choice)
      data = arch_embedding_weights_after_softmax[j]
      arch_embedding_weights_after_softmax_list.append((name, data))
      upper += embedding_size_choice
      mask = tf.constant([
          1.0 if jj < upper and jj >= lower else 0.0
          for jj in range(max_emb_choice)
      ],
                         dtype=self.dtype)
      #mask = tf.constant([
      #    1.0 / embedding_size_choice if jj < upper and jj >= lower else 0.0
      #    for jj in range(max_emb_choice)
      #],
      #                   dtype=self.dtype) # Balance the gradient of each slot choices
      lower += embedding_size_choice
      embedding_masks.append(mask)
    # [self._max_num_choices, max_emb_choice]
    # Only consumed by the commented-out "method 2" code below.
    embedding_mask = tf.pad(
        tf.stack(embedding_masks, 0),
        [[0, self._max_num_choices - num_choices], [0, 0]])
    #print('embedding_mask: {}'.format(tf.keras.backend.eval(embedding_mask)))
    # [self._max_num_choices]
    arch_embedding_weights_after_softmax_per_slot_padded = tf.pad(
        arch_embedding_weights_after_softmax,
        [[0, self._max_num_choices - num_choices]])
    # Uniform choice before pretraining_steps, learned distribution after.
    # NOTE(review): the literal [0.5, 0.5] has two entries, so this looks like
    # it assumes self._max_num_choices == 2 -- confirm.
    probability = tf.where(
        tf.cast(tf.compat.v1.train.get_or_create_global_step(), tf.float32) <
        tf.cast(self._pretraining_steps, tf.float32),
        [0.5, 0.5],  #[0.25, 0.25, 0.25, 0.25]
        [
            arch_embedding_weights_after_softmax_per_slot_padded[i]
            for i in range(self._max_num_choices)
        ])
    # [max_emb_choice]
    # method 1: Sampling-based nws
    indices = tf.random.categorical(
        tf.math.log(tf.expand_dims(probability, 0)), 1)
    index = tf.reduce_sum(indices)
    index_one_hot = tf.one_hot(index, self._max_num_choices)
    # NOTE(review): this multiplies the unpadded Python list `embedding_masks`
    # (num_choices entries), not the padded `embedding_mask` tensor;
    # presumably every slot has exactly _max_num_choices choices -- confirm.
    embedding_masks_selected = tf.reduce_sum(
        embedding_masks * tf.expand_dims(index_one_hot, -1), 0)
    arch_embedding_weights_after_softmax_per_slot_padded_chosen = tf.reduce_sum(
        arch_embedding_weights_after_softmax_per_slot_padded * index_one_hot,
        0)
    # (1 + w - stop_gradient(w)) equals 1 in the forward pass, so the value is
    # the plain mask while the expression still depends on the chosen weight.
    arch_embedding_masks_list.append(
        embedding_masks_selected *
        (1 + arch_embedding_weights_after_softmax_per_slot_padded_chosen -
         tf.stop_gradient(
             arch_embedding_weights_after_softmax_per_slot_padded_chosen)))
    # [max_emb_choice]
    # method 2: MixedOp-based nws
    #arch_embedding_masks_list.append(embedding_mask)
    #arch_embedding_weights_multiplier_list.append(
    #    tf.broadcast_to(
    #        tf.expand_dims(
    #            probability,
    #            -1), tf.shape(embedding_mask)
    #    )
    #)
    current_idx += num_choices
  # [total_emb_dims]
  self._arch_embedding_masks_multipler = tf.concat(
      arch_embedding_masks_list, 0)  # method 1: Sampling-based nws
  #self._arch_embedding_masks_multipler = tf.concat(
  #    arch_embedding_masks_list, 1) #method 2: MixedOp-based nws
  #self._arch_embedding_weights_multipler = tf.concat(
  #    arch_embedding_weights_multiplier_list, 1)
  self._arch_entropy = tf.add_n(arch_entropy_list)
  self._expected_emb_dims = tf.add_n(expected_emb_dims_list)
  # Falls back to the plain integer 0 when no slot offers a size-0 choice.
  self._expected_zero_embedding_size_weights = tf.add_n(
      expected_zero_embedding_size_weights_list
  ) if expected_zero_embedding_size_weights_list else 0
  self._arch_embedding_weights_after_softmax_list = arch_embedding_weights_after_softmax_list
  self.built = True
distillation_loss = tf.losses.mean_squared_error( masked_teacher_embedding_transformed, masked_embedding) return mixed_embedding, distillation_loss, teacher_embedding_transform.name else: return masked_embedding # method 1: Sampling-based nws #return mixed_embedding # method 2: MixedOp-based nws def compute_output_shape(self, input_shape): raise NotImplementedError("I don't think I need to implement this one.") def get_config(self): config = { 'slot_names': self._slot_names, 'embedding_size_choices_list': self._embedding_size_choices_list, 'warmup_steps': self._warmup_steps, 'teacher_embedding_sizes_list': self._teacher_embedding_sizes_list, } base_config = super(MixedEmbedOpComb, self).get_config() return dict(list(base_config.items()) + list(config.items())) def get_arch_embedding_weights(self): return self._arch_embedding_weights def get_summaries(self): return { 'arch_entropy': self._arch_entropy, 'expected_emb_dims': self._expected_emb_dims, 'arch_weights_after_softmax_list': self._arch_embedding_weights_after_softmax_list, } ================================================ FILE: monolith/core/model.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Check in TPU embedding feature from TensorFlow.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools import math from absl import logging import tensorflow as tf from tensorflow.python.tpu import tpu_embedding from monolith.core.feature import FeatureSlot, FeatureColumnV1, Env _FLOAT32_BYTES = 4 ### This code will be deprecated. Please update and usebase_embedding_task.py. class Model(object): """Sail-like TPU Model.""" def __init__(self, params): self._params = params self._vocab_size_per_slot = params["vocab_size_per_slot"] if self._vocab_size_per_slot is not None: logging.info("Set fixed vocab_size: {} for all the slots.".format( self._vocab_size_per_slot)) vocab_size_dict = self._create_vocab_dict(params["vocab_file_path"], self._vocab_size_per_slot) self._env = Env(vocab_size_dict=vocab_size_dict) # Run this to initialize slot, embedding dim information self.init_slot_to_dims() def _create_vocab_dict(self, file_path, vocab_size_per_slot=None): """Create vocab dict from a tsv file. Args: file_path: the path to the vocab dict vocab_size_per_slot: If None, this is set to the number of unique FIDs for each slot, obtained from the vocab_size file. Otherwise, this value is used to manually set the vocab sise per slot (this option is to speed up testing and modeling iteration). """ vocab_size_dict = {} with open(file_path) as f: for line in f: fields = line.strip().split("\t") assert len(fields) == 2, "each line in {} must have 2 fields".format( fields) if fields[0].isdigit() == False: continue slot_id = int(fields[0]) distinct_count = vocab_size_per_slot if vocab_size_per_slot is None: distinct_count = int(fields[1]) vocab_size_dict[slot_id] = distinct_count return vocab_size_dict def _get_feature_map(self): """Returns data format of the serialized tf record file.""" # Inherated class must implement this function. 
raise NotImplementedError def _post_process_example(self, example): """Postprocess example.""" # build tensors for each embeddings in each slot for slot_id, dims in self._env.slot_to_dims.items(): # If the vocab size per slot is set, we need to adjust the # vocab_id so that no vocab_id exceed this vocab size per slot if self._vocab_size_per_slot: embedding_tensor = example["slot_{}_0".format(slot_id)] new_embedding_tensor = tf.SparseTensor( indices=embedding_tensor.indices, values=tf.math.mod(embedding_tensor.values, self._vocab_size_per_slot), dense_shape=embedding_tensor.dense_shape) example["slot_{}_0".format(slot_id)] = new_embedding_tensor for i in range(1, len(dims)): example["slot_{}_{}".format(slot_id, i)] = example["slot_{}_0".format(slot_id)] return example def create_input_fn(self, file_pattern, repeat=True): def tf_example_parser(examples): """Parse multiple examples.""" feature_map = self._get_feature_map() example = tf.io.parse_example(serialized=examples, features=feature_map) return self._post_process_example(example) def input_fn(params): """Returns training or eval examples, batched as specified in params.""" logging.info("Model input_fn") # By shuffle=False, list_files will get all files already in time sorted order. files = tf.data.Dataset.list_files(file_pattern, shuffle=False) # This function will get called once per TPU task. Each task will process the files # with indexs which modulo num_calls equals to call_index. 
_, call_index, num_calls, _ = ( params["context"].current_input_fn_deployment()) files = files.shard(num_calls, call_index) skip_files_number = 0 if params["shard_skip_file_number"] is not None: skip_files_number = params["shard_skip_file_number"][call_index] logging.info("Shard {} skipped {} files.".format(call_index, skip_files_number)) files = files.skip(skip_files_number) def fetch_dataset(filename): dataset = tf.data.TFRecordDataset( filename, compression_type=params["compression_type"], buffer_size=None) return dataset # Read the data from disk in parallel. # Files will be process from the beginning to the end. With a local parallel of interleaving # multiple files currently. Number of interleaving files are defined by the cycle. dataset = files.interleave( fetch_dataset, cycle_length=params["cycle_length"], num_parallel_calls=params["num_parallel_calls"], deterministic=False) dataset = dataset.batch(params["batch_size"], drop_remainder=True).map( tf_example_parser, num_parallel_calls=tf.data.experimental.AUTOTUNE, deterministic=False) # The tensors returned from this dataset will be directly used as the ids # for the embedding lookup. If you want to have a separate vocab, apply a # '.map' here to the dataset which contains you vocab lookup. 
if repeat: dataset = dataset.repeat() dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) return dataset return input_fn def _padding_8(self, dim): return math.ceil(dim / 8) * 8 def _get_slot_number(self, optimizer, use_gradient_accumulation): slot_num = 0 if isinstance(optimizer, tf.compat.v1.tpu.experimental.FtrlParameters): slot_num = 3 elif isinstance(optimizer, tf.compat.v1.tpu.experimental.AdagradParameters): slot_num = 2 elif isinstance(optimizer, tf.compat.v1.tpu.experimental.AdamParameters): slot_num = 3 elif isinstance( optimizer, tf.compat.v1.tpu.experimental.StochasticGradientDescentParameters): slot_num = 1 else: assert ("We don't support this optimizer type yet: {}".format( type(optimizer))) if use_gradient_accumulation == True: slot_num += 1 return slot_num def _get_max_slot_number(self): max_slot_number = 0 for slot_id, dims in self._env.slot_to_dims.items(): feature_slot = self._env.slot_to_config[slot_id] for index, dim in enumerate(dims): optimizer = None # If index is 0 and feature slot uses bias, then we will use bias optimizer and table initializer. # Also please note if slot has bias, bias will always use index = 0. 
if index == 0 and feature_slot.use_bias: optimizer = feature_slot.bias_optimizer else: optimizer = feature_slot.vec_optimizer max_slot_number = max( max_slot_number, self._get_slot_number(optimizer, self._params["use_gradient_accumulation"])) return max_slot_number def create_feature_and_table_config_dict(self): """Prepares the table and feature config given the parameters.""" env = self._env assert env.is_finalized() feature_to_config_dict = {} table_to_config_dict = {} embedding_table_size = 0 embedding_table_size_after_padding_8 = 0 embedding_table_size_after_padding_8_and_use_max_auxiliary_parameters = 0 max_slot_number = self._get_max_slot_number() for slot_id, dims in env.slot_to_dims.items(): assert slot_id in env.vocab_size_dict, "slot_id {} must be in vocab file".format( slot_id) vocab_size = env.vocab_size_dict[slot_id] assert slot_id in env.slot_to_config, "slot_id {} must be in slot_to_config".format( slot_id) feature_slot = env.slot_to_config[slot_id] for index, dim in enumerate(dims): optimizer = None table_initializer = None # If index is 0 and feature slot uses bias, then we will use bias optimizer and table initializer. # Also please note if slot has bias, bias will always use index = 0. 
if index == 0 and feature_slot.use_bias(): optimizer = feature_slot.bias_optimizer table_initializer = feature_slot.bias_initializer else: optimizer = feature_slot.vec_optimizer table_initializer = feature_slot.vec_initializer table = tpu_embedding.TableConfig(vocabulary_size=vocab_size, dimension=dim, initializer=table_initializer, combiner="sum", optimization_parameters=optimizer) table_name = "table_{}_{}".format(slot_id, index) table_to_config_dict[table_name] = table feature_name = "slot_{}_{}".format(slot_id, index) feature_to_config_dict[feature_name] = tpu_embedding.FeatureConfig( table_name) slot_num = self._get_slot_number( optimizer, self._params["use_gradient_accumulation"]) embedding_table_size += vocab_size * dim * _FLOAT32_BYTES * slot_num embedding_table_size_after_padding_8 += vocab_size * self._padding_8( dim) * _FLOAT32_BYTES * slot_num embedding_table_size_after_padding_8_and_use_max_auxiliary_parameters += vocab_size * self._padding_8( dim) * _FLOAT32_BYTES * max_slot_number logging.info("Size of all embedding tables in bytes: {}".format( embedding_table_size)) logging.info( "Size after padding the width of all tables to 8 float multiples in bytes: {}" .format(embedding_table_size_after_padding_8)) logging.info( "Size after padding to 8 float multiples and using max auxiliary parameters: {}" .format( embedding_table_size_after_padding_8_and_use_max_auxiliary_parameters )) return feature_to_config_dict, table_to_config_dict def sum_pooling(self, fc_dict, input_map, features, dim, total_embeddings, add_into_embeddings=True): slot_embeddings = [] dims = 0 for slot in features: #allocate embedding embedding = fc_dict[slot].add_vector(dim) dims += dim if add_into_embeddings: total_embeddings.append((embedding, dim)) slot_embeddings.append(embedding) if slot in input_map: input_slots = input_map.keys() c = 0 for item in input_slots: if isinstance(item, str): if str(slot) + '_' in item: c += 1 if isinstance(item, int): if item == slot: c += 1 
input_map[str(slot) + '_' + str(c)] = embedding else: input_map[slot] = embedding if len(features) == 1: #单特征无需sum return slot_embeddings[0] return tf.add_n(slot_embeddings) def logits_fn(self): """Calculate logits.""" # Inherated class must implement this function. raise NotImplementedError def init_slot_to_dims(self): """Run this in the beginning to initialize the slot and its embedding dims information.""" logging.info("Model init_slot_to_dims") self.logits_fn() self._env.finalize() logging.info("_slot_to_dims: {}".format(self._env.slot_to_dims)) def create_model_fn(self): """Creates the model_fn to be used by the TPUEstimator.""" # Inherated class must implement this function. raise NotImplementedError ================================================ FILE: monolith/core/model_imports.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import importlib import sys from absl import logging def _Import(name): """Imports the python module of the given name.""" logging.info('Attempt to import {} ...'.format(name)) try: importlib.import_module(name) logging.info('Imported {}'.format(name)) return True except ImportError as e: # It is expected that some imports may be missing. logging.error('Could not import: {}\n'.format(e)) return False _ROOT = 'monolith.tasks' _DIRS = () def ImportAllParams(task_root=_ROOT, task_dirs=_DIRS, require_success=False): """Import all ModelParams to add to the global registry.""" success = False for task in task_dirs: # By our code repository convention, there is a params.py under the task's # params directory. params.py imports _all_ modules that may registers a # model param. module_str = '{}.{}.params.{}'.format(task_root, task, path) success = _Import('{}.{}.params.params'.format(task_root, task)) or success if require_success and not success: raise LookupError('Could not import any task params. Make sure task params ' 'are linked into the binary.') return success def ImportParams(model_name, task_root=_ROOT, task_dirs=_DIRS, require_success=True): """Attempts to only import the files that may contain the model.""" # 'model_name' follows .. if '.' not in model_name: raise ValueError('Invalid model name %s' % model_name) model_module = model_name.rpartition('.')[0] logging.info("model_module:{}".format(model_module)) # Try importing the module directly, in case it's a local import. logging.info("Searching local import ...") success = _Import(model_module) # Try built-in tasks imports. 
logging.info("Searching built-in tasks ...") for task in sorted(task_dirs): logging.info('{} || {}'.format(task, model_module)) if model_module.startswith(task + '.'): logging.info("Found built-in task: {}".format(task)) path = model_module[len(task) + 1:] module_str = '{}.{}.params.{}'.format(task_root, task, path) success = _Import(module_str) or success if require_success and not success: raise LookupError( 'Could not find any valid import paths for module %s. Check the logs ' 'above to see if there were errors importing the module, and make sure ' 'the relevant params files are linked into the binary.' % model_module) return success ================================================ FILE: monolith/core/model_registry.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import traceback import inspect import sys import tensorflow as tf from absl import logging from monolith.core import model_imports from monolith.core.base_model_params import SingleTaskModelParams class _ModelRegistryHelper(object): # Global dictionary mapping subclass name to registered ModelParam subclass. _MODEL_PARAMS = {} # Global set of modules from which ModelParam subclasses have been registered. _REGISTERED_MODULES = set() @classmethod def _ClassPathPrefix(cls): return 'monolith.tasks.' @classmethod def _ModelParamsClassKey(cls, src_cls, shortcut=False): """Returns a string key used for `src_cls` in the model registry. Args: src_cls: A subclass of `BaseModel`. shortcut: (Deprecated) generate shortcut version of given task. """ path = src_cls.__module__ if shortcut: # Removes the prefix. path_prefix = cls._ClassPathPrefix() path = path.replace(path_prefix, '') # Removes 'params.' if exists. if 'params.' in path: path = path.replace('params.', '') return '{}.{}'.format(path, src_cls.__name__) @classmethod def _GetSourceInfo(cls, src_cls): """Gets a source info string given a source class.""" return '%s@%s:%d' % (cls._ModelParamsClassKey(src_cls), inspect.getsourcefile(src_cls), inspect.getsourcelines(src_cls)[-1]) @classmethod def _RegisterModel(cls, src_cls): """Registers a ModelParams subclass in the global registry.""" for key in set([ cls._ModelParamsClassKey(src_cls, shortcut=False), cls._ModelParamsClassKey(src_cls, shortcut=True) ]): module = src_cls.__module__ if key in cls._MODEL_PARAMS: raise ValueError('Duplicate model registered for key {}: {}.{}'.format( key, module, src_cls.__name__)) logging.debug('Registering model %s', key) # Log less frequently (once per module) but at a higher verbosity level. 
if module not in cls._REGISTERED_MODULES: logging.info('Registering models from module: %s', module) cls._REGISTERED_MODULES.add(module) # Decorate param methods to add source info metadata. cls._MODEL_PARAMS[key] = src_cls return cls._ModelParamsClassKey(src_cls, shortcut=False) @classmethod def RegisterSingleTaskModel(cls, src_cls): """Class decorator that registers a `.SingleTaskModelParams` subclass.""" logging.info("Register {} Start".format(src_cls.__name__)) if not issubclass(src_cls, SingleTaskModelParams): raise TypeError('src_cls %s is not a SingleTaskModelParams!' % src_cls.__name__) cls._RegisterModel(src_cls) all_params = _ModelRegistryHelper._MODEL_PARAMS logging.info("Register {} successfully".format(src_cls.__name__)) return src_cls @staticmethod def GetAllRegisteredClasses(): """Returns global registry map from model names to their param classes.""" all_params = _ModelRegistryHelper._MODEL_PARAMS if not all_params: logging.warning('No classes registered.') return all_params @classmethod def GetClass(cls, class_key): """Returns a ModelParams subclass with the given `class_key`. Args: class_key: string key of the ModelParams subclass to return. Returns: A subclass of `SingleTaskModelParams`. Raises: LookupError: If no class with the given key has been registered. """ all_params = cls.GetAllRegisteredClasses() if class_key not in all_params: for k in sorted(all_params): logging.info('Known model: %s', k) raise LookupError('Model %s not found from list of above known models.' % class_key) return all_params[class_key] @classmethod def GetParams(cls, class_key): """Constructs a `Params` object for given model. Args: class_key: String class key. Returns: Full `~.hyperparams.Params` for the model class. 
""" model_params_cls = cls.GetClass(class_key) model_params = model_params_cls() cfg = model_params.task() return cfg RegisterSingleTaskModel = _ModelRegistryHelper.RegisterSingleTaskModel def GetAllRegisteredClasses(): model_imports.ImportAllParams() return _ModelRegistryHelper.GetAllRegisteredClasses() def GetClass(class_key): model_imports.ImportParams(class_key) return _ModelRegistryHelper.GetClass(class_key) def GetParams(class_key): model_imports.ImportParams(class_key) return _ModelRegistryHelper.GetParams(class_key) ================================================ FILE: monolith/core/optimizers.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from tensorflow.compat.v1.train import AdagradOptimizer from tensorflow.compat.v1.train import MomentumOptimizer from tensorflow.compat.v1.train import RMSPropOptimizer from tensorflow.compat.v1.train import AdamOptimizer optimizers = { 'adagrad': AdagradOptimizer, 'momentum': MomentumOptimizer, 'rmsprop': RMSPropOptimizer, 'adam': AdamOptimizer, } ================================================ FILE: monolith/core/py_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Common utilities.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import re import six _NAME_PATTERN = re.compile('[A-Za-z_][A-Za-z0-9_]*') class NestedMap(dict): """A simple helper to maintain a dict. It is a sub-class of dict with the following extensions/restrictions: - It supports attr access to its members (see examples below). - Member keys have to be valid identifiers. E.g.:: >>> foo = NestedMap() >>> foo['x'] = 10 >>> foo.y = 20 >>> assert foo.x * 2 == foo.y """ # Disable pytype attribute checking. _HAS_DYNAMIC_ATTRIBUTES = True # keys in this list are not allowed in a NestedMap. _RESERVED_KEYS = set(dir(dict)) # sentinel value for deleting keys used in Filter. 
_DELETE = object() def __init__(self, *args, **kwargs): super(NestedMap, self).__init__(*args, **kwargs) for key in self.keys(): assert isinstance(key, six.string_types), ( 'Key in a NestedMap has to be a six.string_types. Currently type: %s,' ' value: %s' % (str(type(key)), str(key))) NestedMap.CheckKey(key) assert key not in NestedMap._RESERVED_KEYS, ('%s is a reserved key' % key) def __setitem__(self, key, value): # Make sure key is a valid expression and is not one of the reserved # attributes. assert isinstance(key, six.string_types), ( 'Key in a NestedMap has to be a six.string_types. Currently type: %s, ' 'value: %s' % (str(type(key)), str(key))) NestedMap.CheckKey(key) assert key not in NestedMap._RESERVED_KEYS, ('%s is a reserved key' % key) super(NestedMap, self).__setitem__(key, value) def __setattr__(self, name, value): self.__setitem__(name, value) def __getattr__(self, name): try: return self[name] except KeyError as e: raise AttributeError('%s; available attributes: %s' % (e, sorted(list(self.keys())))) def __delattr__(self, name): try: del self[name] except KeyError as e: raise AttributeError('%s; available attributes: %s' % (e, sorted(list(self.keys())))) def copy(self): # Don't delegate w/ super: dict.copy() -> dict. 
return NestedMap(self) def __deepcopy__(self, unused_memo): """Deep-copies the structure but not the leaf objects.""" return self.DeepCopy() def DeepCopy(self): """Deep-copies the structure but not the leaf objects.""" return self.Pack(self.Flatten()) @staticmethod def FromNestedDict(x): """Converts every dict in nested structure 'x' to a NestedMap.""" if isinstance(x, dict): res = NestedMap() for k, v in six.iteritems(x): res[k] = NestedMap.FromNestedDict(v) return res elif isinstance(x, (list, tuple)): return type(x)(NestedMap.FromNestedDict(v) for v in x) else: return x @staticmethod def CheckKey(key): """Asserts that key is valid NestedMap key.""" if not (isinstance(key, six.string_types) and _NAME_PATTERN.match(key)): raise ValueError('Invalid NestedMap key \'{}\''.format(key)) def GetItem(self, key): """Gets the value for the nested `key`. Note that indexing lists is not supported, names with underscores will be considered as one key. Args: key: str of the form `([A-Za-z_][A-Za-z0-9_]*)(.[A-Za-z_][A-Za-z0-9_]*)*.`. Returns: The value for the given nested key. Raises: KeyError if a key is not present. """ current = self # Note: This can't support lists. List keys are ambiguous as underscore is # not reserved for list indexing but also allowed to be used in keys. # E.g., this is a valid nested map where the key 'a_0' is not well defined # {'a_0': 3, 'a': [4]}. for k in key.split('.'): current = current[k] return current def Get(self, key, default=None): """Gets the value for nested `key`, returns `default` if key does not exist. Note that indexing lists is not supported, names with underscores will be considered as one key. Args: key: str of the form `([A-Za-z_][A-Za-z0-9_]*)(.[A-Za-z_][A-Za-z0-9_]*)*.`. default: Optional default value, defaults to None. Returns: The value for the given nested key or `default` if the key does not exist. 
""" try: return self.GetItem(key) # TypeError is raised when an intermediate item is a list and we try to # access an element of it with a string. except (KeyError, TypeError): return default def Set(self, key, value): """Sets the value for a nested key. Note that indexing lists is not supported, names with underscores will be considered as one key. Args: key: str of the form `([A-Za-z_][A-Za-z0-9_]*)(.[A-Za-z_][A-Za-z0-9_]*)*.`. value: The value to insert. Raises: ValueError if a sub key is not a NestedMap or dict. """ current = self sub_keys = key.split('.') for i, k in enumerate(sub_keys): self.CheckKey(k) # We have reached the terminal node, set the value. if i == (len(sub_keys) - 1): current[k] = value else: if k not in current: current[k] = NestedMap() if not isinstance(current[k], (dict, NestedMap)): raise ValueError('Error while setting key {}. Sub key "{}" is of type' ' {} but must be a dict or NestedMap.' ''.format(key, k, type(current[k]))) current = current[k] def _RecursiveMap(self, fn, flatten=False): """Traverse recursively into lists and NestedMaps applying `fn`. Args: fn: The function to apply to each item (leaf node). flatten: If true, the result should be a single flat list. Otherwise the result will have the same structure as this NestedMap. Returns: The result of applying fn. """ def Recurse(v, key=''): """Helper function for _RecursiveMap.""" if isinstance(v, NestedMap): ret = [] if flatten else NestedMap() deleted = False for k in sorted(v.keys()): res = Recurse(v[k], key + '.' 
+ k if key else k) if res is self._DELETE: deleted = True continue elif flatten: ret += res else: ret[k] = res if not ret and deleted: return self._DELETE return ret elif isinstance(v, list): ret = [] deleted = False for i, x in enumerate(v): res = Recurse(x, '%s[%d]' % (key, i)) if res is self._DELETE: deleted = True continue elif flatten: ret += res else: ret.append(res) if not ret and deleted: return self._DELETE return ret else: ret = fn(key, v) if flatten: ret = [ret] return ret res = Recurse(self) if res is self._DELETE: return [] if flatten else NestedMap() return res def Flatten(self): """Returns a list containing the flattened values in the `.NestedMap`. Unlike py_utils.Flatten(), this will only descend into lists and NestedMaps and not dicts, tuples, or namedtuples. """ return self._RecursiveMap(lambda _, v: v, flatten=True) def FlattenItems(self): """Flatten the `.NestedMap` and returns pairs in a list. Returns: A list of pairs, where keys for nested entries will be represented in the form of `foo.bar[10].baz`. """ return self._RecursiveMap(lambda k, v: (k, v), flatten=True) def Pack(self, lst): """Returns a copy of this with each value replaced by a value in lst.""" assert len(self.FlattenItems()) == len(lst) v_iter = iter(lst) return self._RecursiveMap(lambda unused_k, unused_v: next(v_iter)) def Transform(self, fn): """Returns a copy of this `.NestedMap` with fn applied on each value.""" return self._RecursiveMap(lambda _, v: fn(v)) def IsCompatible(self, other): """Returns true if self and other are compatible. If x and y are two compatible `.NestedMap`, `x.Pack(y.Flatten())` produces y and vice versa. Args: other: Another `.NestedMap`. 
""" items = self._RecursiveMap(lambda k, _: k, flatten=True) other_items = other._RecursiveMap(lambda k, _: k, flatten=True) # pylint: disable=protected-access return items == other_items def Filter(self, fn): """Returns a copy with entries where fn(entry) is True.""" return self.FilterKeyVal(lambda _, v: fn(v)) def FilterKeyVal(self, fn): """Returns a copy of this `.NestedMap` filtered by fn. If fn(key, entry) is True, the entry is copied into the returned NestedMap. Otherwise, it is not copied. Args: fn: a callable of (string, entry)->boolean. Returns: A `.NestedMap` contains copied entries from this `'.NestedMap`. """ return self._RecursiveMap(lambda k, v: v if fn(k, v) else self._DELETE) def _ToStrings(self): """Returns debug strings in a list for this `.NestedMap`.""" kv = self.FlattenItems() maxlen = max([len(k) for k, _ in kv]) if kv else 0 return sorted([k + ' ' * (4 + maxlen - len(k)) + str(v) for k, v in kv]) def DebugString(self): """Returns a debug string for this `.NestedMap`.""" return '\n'.join(self._ToStrings()) def VLog(self, level=None, prefix=None): """Logs the debug string at the level.""" if level is None: level = 0 if prefix is None: prefix = 'nmap: ' for l in self._ToStrings(): tf.logging.vlog(level, '%s %s', prefix, l) ================================================ FILE: monolith/core/testing_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@test_util.disable_cudnn_autotune
def layer_test(layer_cls,
               kwargs=None,
               input_shape=None,
               input_dtype=None,
               input_data=None,
               expected_output=None,
               expected_output_dtype=None,
               expected_output_shape=None,
               validate_training=True,
               adapt_data=None,
               custom_objects=None,
               test_harness=None):
  """Test routine for a BaseLayer with a single input and single output.

  The layer is instantiated, optionally adapted, round-tripped through
  `get_weights`/`set_weights`, wired into a functional model and run on
  `input_data`. Output dtype, static shape, shape inference and (optionally)
  the output values are checked.

  Args:
    layer_cls: BaseLayer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer.
    input_shape: Input shape tuple.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data. When None, random data of shape
      `input_shape` is generated (None dimensions get a random size in [1, 4)).
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output. Defaults to
      `input_dtype`.
    expected_output_shape: Shape tuple for the expected shape of the output.
    validate_training: Accepted for signature compatibility; this
      implementation does not read it.
    adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
      be tested for this layer.
    custom_objects: Accepted for signature compatibility; this implementation
      does not read it.
    test_harness: The Tensorflow test, if any, that this function is being
      called in. When given, its `assertAllEqual`/`assertAllClose` are used to
      compare outputs.

  Returns:
    None. All checks are reported by raising.

  Raises:
    ValueError: if both `input_data` and `input_shape` are None.
    AssertionError: if the output dtype or shape differs from the expectation.
  """
  if input_data is None:
    if input_shape is None:
      raise ValueError('input_shape is None')
    if not input_dtype:
      input_dtype = 'float32'
    # Replace unknown dimensions by a small random size so data can be drawn.
    input_data_shape = list(input_shape)
    for i, e in enumerate(input_data_shape):
      if e is None:
        input_data_shape[i] = np.random.randint(1, 4)
    input_data = 10 * np.random.random(input_data_shape)
    if input_dtype[:5] == 'float':
      input_data -= 0.5
    input_data = input_data.astype(input_dtype)
  elif input_shape is None:
    input_shape = input_data.shape
  if input_dtype is None:
    input_dtype = input_data.dtype
  if expected_output_dtype is None:
    expected_output_dtype = input_dtype

  # Pick the comparison function: exact match for strings, tolerance for
  # numeric outputs. Without a harness fall back to the Keras helpers.
  if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
    if test_harness:
      assert_equal = test_harness.assertAllEqual
    else:
      # `string_test` is not defined in this module; it lives next to
      # `numeric_test` in the Keras testing utilities.
      assert_equal = testing_utils.string_test
  else:
    if test_harness:
      assert_equal = test_harness.assertAllClose
    else:
      assert_equal = testing_utils.numeric_test

  # instantiation
  kwargs = kwargs or {}
  layer = layer_cls(**kwargs)

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  # test get_weights , set_weights at layer level
  weights = layer.get_weights()
  layer.set_weights(weights)

  # test and instantiation from weights
  # The ArgSpec itself is a tuple of (args, varargs, keywords, defaults), so
  # the parameter names have to be looked up in `.args`.
  if 'weights' in tf_inspect.getargspec(layer_cls.__init__).args:
    kwargs['weights'] = weights
    layer = layer_cls(**kwargs)

  # test in functional API
  x = layers.Input(shape=input_shape[1:], dtype=input_dtype)
  y = layer(x)
  if backend.dtype(y) != expected_output_dtype:
    raise AssertionError('When testing layer %s, for input %s, found output '
                         'dtype=%s but expected to find %s.\nFull kwargs: %s' %
                         (layer_cls.__name__, x, backend.dtype(y),
                          expected_output_dtype, kwargs))

  def assert_shapes_equal(expected, actual):
    """Asserts that the output shape from the layer matches the actual shape."""
    if len(expected) != len(actual):
      raise AssertionError(
          'When testing layer %s, for input %s, found output_shape='
          '%s but expected to find %s.\nFull kwargs: %s' %
          (layer_cls.__name__, x, actual, expected, kwargs))

    for expected_dim, actual_dim in zip(expected, actual):
      if isinstance(expected_dim, tensor_shape.Dimension):
        expected_dim = expected_dim.value
      if isinstance(actual_dim, tensor_shape.Dimension):
        actual_dim = actual_dim.value
      # A None expected dimension matches anything.
      if expected_dim is not None and expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_shape='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual, expected, kwargs))

  if expected_output_shape is not None:
    assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape),
                        y.shape)

  # check shape inference
  model = models.Model(x, y)
  computed_output_shape = tuple(
      layer.compute_output_shape(
          tensor_shape.TensorShape(input_shape)).as_list())
  computed_output_signature = layer.compute_output_signature(
      tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype))
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  assert_shapes_equal(computed_output_shape, actual_output_shape)
  assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
  if computed_output_signature.dtype != actual_output.dtype:
    raise AssertionError(
        'When testing layer %s, for input %s, found output_dtype='
        '%s but expected to find %s.\nFull kwargs: %s' %
        (layer_cls.__name__, x, actual_output.dtype,
         computed_output_signature.dtype, kwargs))
  if expected_output is not None:
    assert_equal(actual_output, expected_output)
try:
  from tensorflow.python.types import core  # pylint:disable=g-import-not-at-top,g-direct-tensorflow-import
  TF_23 = True
except ImportError:
  TF_23 = False

# Base class of ReplicatedVariable: `core.Tensor` when that module can be
# imported, plain `object` otherwise.
if TF_23:
  VariableBase = core.Tensor
else:
  VariableBase = object


@contextlib.contextmanager
def _handle_graph(handle):
  """Context manager that makes the graph owning `handle` the default graph."""
  with handle.graph.as_default():
    yield


def _enclosing_tpu_context():
  """Returns the innermost enclosing XLAControlFlowContext, or None.

  Walks outwards from the default graph's current control flow context until
  an `XLAControlFlowContext` is found or the chain ends.
  """
  # pylint: disable=protected-access
  context = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while context is not None and not isinstance(
      context, control_flow_ops.XLAControlFlowContext):
    context = context.outer_context
  return context


class ReplicatedVariable(VariableBase):
  """A replicated variable for use on TPUs.

  When accessed inside a tpu.replicate() context, this variable acts as if it
  is a single variable whose handle is a replicated input to the computation.

  Outside a tpu.replicate() context currently this object has pretty murky
  semantics, especially with respect to things such as
  * initialization
  * colocation.
  """

  def __init__(self, name, variables):
    """Wraps per-replica `variables` under a single `name`.

    Args:
      name: Name reported by the `name` property and passed to the TPU context
        when the replicated handle is requested.
      variables: Non-empty sequence of variables; the first one is the primary
        variable that most properties delegate to.
    """
    self._name = name
    self._primary_var = variables[0]
    # `_shared_name` reads this attribute, so it has to exist. It is the
    # primary variable's name without the ":<output index>" suffix.
    self._common_name = self._primary_var.name.split(":")[0]
    self._vars = variables
    self._cached_value = None
    self._dtype = variables[0].dtype

  @property
  def handle(self):
    """Primary handle outside a TPU context, replicated handle inside one."""
    tpu_context = _enclosing_tpu_context()
    if tpu_context is None:
      return self._primary_var.handle

    return tpu_context.get_replicated_var_handle(self._name, self._vars)

  @contextlib.contextmanager
  def _assign_dependencies(self):
    """Makes assignments depend on the cached value, if any.

    This prevents undefined behavior with reads not ordered wrt writes.

    Yields:
      None.
    """
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  @property
  def initializer(self):
    """A grouped op running the initializer of every wrapped variable."""
    return control_flow_ops.group([v.initializer for v in self._vars])

  @property
  def graph(self):
    return self._primary_var.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self._primary_var._unique_id  # pylint: disable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._primary_var.dtype

  @property
  def shape(self):
    return self._primary_var.shape

  def get_shape(self):
    return self._primary_var.get_shape()

  def to_proto(self, export_scope=None):
    return self._primary_var.to_proto(export_scope=export_scope)

  @property
  def constraint(self):
    return None

  @property
  def op(self):
    return self.get().op

  def _read_variable_op(self):
    """Reads the value: primary variable outside TPU, replicated handle inside."""
    if _enclosing_tpu_context() is None:
      return self._primary_var.read_value()
    v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
    return v

  def read_value(self):
    return self._read_variable_op()

  def assign(self, value, use_locking=None, name=None, read_value=False):
    """Assigns `value`; returns the assign op, or a read if `read_value`.

    Note that `read_value` defaults to False here, unlike `assign_add` and
    `assign_sub`.
    """
    del use_locking
    with _handle_graph(self.handle), self._assign_dependencies():
      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
      assign_op = gen_resource_variable_ops.assign_variable_op(self.handle,
                                                               value_tensor,
                                                               name=name)
    if read_value:
      return self._read_variable_op()
    return assign_op

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Adds `delta`; returns a read of the variable unless `read_value` is False."""
    del use_locking
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
          self.handle,
          ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
    if read_value:
      return self._read_variable_op()
    return assign_add_op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Subtracts `delta`; returns a read of the variable unless `read_value` is False."""
    del use_locking
    with _handle_graph(self.handle), self._assign_dependencies():
      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
          self.handle,
          ops.convert_to_tensor(delta, dtype=self.dtype),
          name=name)
    if read_value:
      return self._read_variable_op()
    return assign_sub_op

  def get(self):
    """Returns the primary (first) wrapped variable."""
    return self._primary_var

  @property
  def _in_graph_mode(self):
    return self._primary_var._in_graph_mode  # pylint: disable=protected-access

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      if hasattr(self._primary_var, '_dense_var_to_tensor'):
        return self._primary_var._dense_var_to_tensor(dtype, name, as_ref)
      else:
        return ops.convert_to_tensor(self._primary_var)
    # pylint: enable=protected-access
    # Inside a TPU context no dtype conversion is attempted.
    if dtype is not None and dtype != self.dtype:
      return NotImplemented
    if as_ref:
      return self.handle
    else:
      return self.read_value()


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)

if not TF_23:
  ops.register_dense_tensor_like_type(ReplicatedVariable)
_GS_PREFIX = "gs://"
_CORE_NUMBER_PER_HOST = 8
_DATE_FORMAT_LEN = 8
_MIN_DATE = "00000000"
_MAX_DATE = "99999999"


def get_bucket_name_and_relavite_path(gs_file_path):
  """Splits a gs path into its bucket name and bucket-relative path.

  Args:
    gs_file_path: A path of the form "gs://<bucket>/<relative path>". A path
      that names only the bucket ("gs://<bucket>") yields an empty relative
      path.

  Returns:
    A `(bucket_name, relative_path)` tuple of strings.

  Raises:
    AssertionError: if `gs_file_path` does not start with "gs://".
  """
  assert gs_file_path.startswith(_GS_PREFIX), "File name: {}".format(
      gs_file_path)
  # partition() keeps the bucket name intact when there is no "/" after it.
  bucket_name, _, relavite_blob_path = gs_file_path[len(_GS_PREFIX):].partition(
      "/")
  return bucket_name, relavite_blob_path


def download_gcs_file(gs_file_path, local_file_name):
  """Downloads the file at `gs_file_path` to `local_file_name`."""
  logging.info("Start downloading {} => {} ...".format(gs_file_path,
                                                       local_file_name))
  bucket_name, relavite_blob_path = get_bucket_name_and_relavite_path(
      gs_file_path)
  download_gcs_file_with_relative_path(bucket_name, relavite_blob_path,
                                       local_file_name)


def download_gcs_file_with_relative_path(bucket_name, gs_file_relative_path,
                                         local_file_name):
  """Downloads blob `gs_file_relative_path` of `bucket_name` to a local file."""
  storage_client = storage.Client()
  bucket = storage_client.bucket(bucket_name)
  blob = bucket.blob(gs_file_relative_path)
  blob.download_to_filename(local_file_name)


def list_gcs_files_with_prefix(gs_path_prefix):
  """Lists the blobs whose path starts with `gs_path_prefix`.

  Returns:
    A `(bucket_name, blobs)` tuple, where `blobs` is whatever
    `storage.Client.list_blobs` returns for the bucket and relative prefix.
  """
  storage_client = storage.Client()
  bucket_name, relavite_blob_path_prefix = get_bucket_name_and_relavite_path(
      gs_path_prefix)
  blob_relative_paths = storage_client.list_blobs(
      bucket_name, prefix=relavite_blob_path_prefix)
  return bucket_name, blob_relative_paths


def parse_example_number_meta_file(meta_file, seperator):
  """Parses a meta file listing file names and their record counts.

  Each useful line has the form "<file name><seperator><count>". Lines that
  do not contain `seperator` are skipped.

  Args:
    meta_file: Path of the local meta file.
    seperator: String separating the file name from the count on each line.

  Returns:
    A list of `(file_name, count)` tuples in file order.

  Raises:
    AssertionError: if file names are not in strictly ascending order.
  """
  file_example_number_list = []
  previous_file_name = ""
  with open(meta_file) as f:
    for line in f:
      if seperator not in line:
        continue
      split_str = line.split(seperator)
      file_name = split_str[0]
      assert previous_file_name < file_name, "File name must be in dictionary ascending order. Previous file name: {}, current file file name: {}".format(
          previous_file_name, file_name)
      previous_file_name = file_name
      file_example_number_list.append((file_name, int(split_str[1])))
  return file_example_number_list


def calculate_shard_skip_file_number(file_example_number, shard_num,
                                     completed_steps_number,
                                     batch_size_per_core):
  """Computes how many files each shard (host) had fully consumed.

  Files are assigned to shards round-robin in list order. A shard's count is
  the number of its leading files whose cumulative example count fits into
  the examples the host has processed so far.

  Args:
    file_example_number: Per-file example counts, in file order.
    shard_num: Number of shards (hosts).
    completed_steps_number: Steps completed in the last checkpoint.
    batch_size_per_core: Batch size of one core.

  Returns:
    A list of length `shard_num` with the number of files to skip per shard.
  """
  processed_example_number_per_host = (batch_size_per_core *
                                       completed_steps_number *
                                       _CORE_NUMBER_PER_HOST)
  # Keep number of completed files for each shard (host) in last checkpoint.
  shard_skip_file_number = [0] * shard_num
  # Keep number of completed examples for each shard (host) in last checkpoint.
  shard_accumulated_example_count = [0] * shard_num
  # Once a shard meets a file it has not finished, none of its later files
  # can count as completed, however small they are.
  shard_stopped = [False] * shard_num
  for file_index, example_number in enumerate(file_example_number):
    shard_index = file_index % shard_num
    if shard_stopped[shard_index]:
      continue
    accumulated = shard_accumulated_example_count[shard_index] + example_number
    if accumulated <= processed_example_number_per_host:
      shard_accumulated_example_count[shard_index] = accumulated
      shard_skip_file_number[shard_index] += 1
    else:
      shard_stopped[shard_index] = True
  return shard_skip_file_number


def get_checkpoint_completed_step_number(checkpoint_path):
  """Returns the largest step number among "model.ckpt-<step>.meta" blobs.

  Returns 0 when no ".meta" blob is found under `checkpoint_path`.
  """
  completed_steps_number = 0
  _, blob_relative_paths = list_gcs_files_with_prefix(
      path.join(checkpoint_path, "model.ckpt"))
  for blob in blob_relative_paths:
    blob_relative_path = blob.name
    if not blob_relative_path.endswith(".meta"):
      continue
    blob_name = blob_relative_path[blob_relative_path.rfind("/") + 1:]
    logging.info("Found checkpoint file {} under path {}".format(
        blob_name, checkpoint_path))
    # The step number sits between the first "-" and the ".meta" suffix.
    checkpoint_processed_steps = int(blob_name[blob_name.find("-") +
                                               1:blob_name.rfind(".meta")])
    completed_steps_number = max(completed_steps_number,
                                 checkpoint_processed_steps)
  return completed_steps_number


def update_params(params, tpu_cluster_resolver):
  """Fills in batch sizes and checkpoint-resume information in `params`.

  Derives whichever of "batch_size_per_core" / "global_batch_size" is missing
  from the other one, and, when a checkpoint with completed steps exists under
  params["model_dir"], stores the per-shard number of files to skip in
  params["shard_skip_file_number"].

  Args:
    params: Mutable mapping of run parameters; updated in place.
    tpu_cluster_resolver: Resolver whose cluster spec provides the number of
      "worker" tasks, used as the shard (host) number.

  Raises:
    AssertionError: if both batch sizes are missing, if they disagree, or if
      the global batch size is not a multiple of the total core number.
  """
  shard_num = tpu_cluster_resolver.cluster_spec().num_tasks("worker")
  total_core_number = shard_num * _CORE_NUMBER_PER_HOST
  has_batch_size_per_core = ("batch_size_per_core" in params and
                             params["batch_size_per_core"] is not None)
  has_global_batch_size = ("global_batch_size" in params and
                           params["global_batch_size"] is not None)
  assert has_batch_size_per_core or has_global_batch_size, \
      "batch_size_per_core and global_batch_size can't be both None."
  if not has_batch_size_per_core:
    # True division would silently produce a float batch size.
    batch_size_per_core, remainder = divmod(params["global_batch_size"],
                                            total_core_number)
    assert remainder == 0, \
        "Global batch size: {} is not a multiple of the total core number: {}.".format(
            params["global_batch_size"], total_core_number)
    params["batch_size_per_core"] = batch_size_per_core
  elif not has_global_batch_size:
    params["global_batch_size"] = (params["batch_size_per_core"] *
                                   total_core_number)
  else:
    assert params["batch_size_per_core"] * total_core_number == params["global_batch_size"], \
        "Batch size per core: {} and global batch size:{} doesn't align.".format(
            params["batch_size_per_core"], params["global_batch_size"])
  logging.info("Batch size per core: {}, global batch size: {}".format(
      params["batch_size_per_core"], params["global_batch_size"]))

  # Get the completed steps number from the latest checkpoint.
  completed_step_number = get_checkpoint_completed_step_number(
      params["model_dir"])
  logging.info(
      "Completed steps from last checkpoint: {}".format(completed_step_number))

  if completed_step_number > 0:
    file_example_number = get_per_file_example_numbers_for_checkpoint_reload(
        params["train_dataset_path"], params["gcs_file_example_number"], ",")
    shard_skip_file_number = calculate_shard_skip_file_number(
        file_example_number, shard_num, completed_step_number,
        params["batch_size_per_core"])
    params["shard_skip_file_number"] = shard_skip_file_number
    logging.info(
        "Set shard skip file number, shard number: {}, batch size per core: {}, "
        "completed steps of last chckpoint: {}, "
        "processed file number of each shard: {}".format(
            shard_num, params["batch_size_per_core"], completed_step_number,
            shard_skip_file_number))


def get_per_file_example_numbers_for_checkpoint_reload(
    train_dataset_path, file_example_number_meta,
    file_example_number_meta_seperator):
  """Returns the example count of every training file, in training order.

  Firstly, we need verify whether checkpoint can be reloaded. To reload
  checkpoint, we need make sure the training data is contained as a continuous
  subset of the file example number meta.

  Args:
    train_dataset_path: gs path (may contain wildcards) handed to `gsutil ls`.
    file_example_number_meta: Local meta file mapping file names to counts.
    file_example_number_meta_seperator: Separator used in the meta file.

  Returns:
    A list with one example count per training file.

  Raises:
    AssertionError: if the training files are empty, unordered, or not a
      contiguous run of the files listed in the meta file.
  """
  # Currently only gsutil supports query gcs with regex path. Google storage
  # client looks like doesn't support regex. We will use gsutil tool directly
  # here as a workaround to work with regex path. Later we will switch back to
  # google storage client once it supports regex path.
  logging.info("Querying train data set to validate checkpoint reload...")
  train_file_path_list = []
  previous_relative_path = ""
  # The context manager closes the pipe and waits for gsutil to exit.
  with subprocess.Popen(["gsutil", "ls", train_dataset_path],
                        stdout=subprocess.PIPE) as proc:
    for line in proc.stdout:
      train_file_path = line.decode("utf-8").strip()
      _, relative_path = get_bucket_name_and_relavite_path(train_file_path)
      train_file_path_list.append(relative_path)
      assert previous_relative_path < relative_path, (
          "train file path must be in ascend order. "
          "previous file path: {}, current file path: {}".format(
              previous_relative_path, relative_path))
      previous_relative_path = relative_path

  file_example_number_list = parse_example_number_meta_file(
      file_example_number_meta, file_example_number_meta_seperator)

  # Skip the files which are not in trained data set.
  assert len(
      train_file_path_list) > 0, "Train data set size must be greater than 0."
  # Find the first train file in file example meta
  for file_example_number_index, (file_path,
                                  _) in enumerate(file_example_number_list):
    if train_file_path_list[0] <= file_path:
      break
  else:
    # Empty meta list, or every meta entry sorts before the first train file:
    # nothing remains to match against.
    file_example_number_index = len(file_example_number_list)

  assert len(train_file_path_list) <= len(file_example_number_list) - file_example_number_index, \
      "Train file path list length {} can't be greater than the remaining length of file example number list length {} starting at index {}".format(
          len(train_file_path_list),
          len(file_example_number_list) - file_example_number_index,
          file_example_number_index)

  example_number_list = []
  for train_file_path in train_file_path_list:
    meta_file_path, example_number = file_example_number_list[
        file_example_number_index]
    assert train_file_path == meta_file_path, \
        "File {} in train data can not be found in file example meta {}".format(
            train_file_path, file_example_number_meta)
    example_number_list.append(example_number)
    file_example_number_index += 1

  logging.info("Checkpoint reload verification done.")
  return example_number_list


def range_dateset(dataset: "tf.data.Dataset",
                  root_path: str,
                  start_date: str = None,
                  end_date: str = None):
  """Keeps the paths whose date directory lies in [start_date, end_date].

  The date is read as the `_DATE_FORMAT_LEN` characters that follow
  `root_path` in every element and compared numerically.

  Args:
    dataset: Dataset of string file paths that start with `root_path`.
    root_path: Common path prefix preceding the date component.
    start_date: Inclusive lower bound as a digit string; None means no lower
      bound.
    end_date: Inclusive upper bound as a digit string; None means no upper
      bound.

  Returns:
    The filtered dataset.
  """
  if start_date is None:
    start_date = _MIN_DATE
  if end_date is None:
    end_date = _MAX_DATE
  logging.info("start_date: {}, end_date: {}.".format(start_date, end_date))

  path_prefix_len = len(root_path)
  first_date = int(start_date)
  last_date = int(end_date)

  def filter_fn(x):
    date = tf.strings.to_number(tf.strings.substr(x, path_prefix_len,
                                                  _DATE_FORMAT_LEN),
                                out_type=tf.int32)
    return tf.math.logical_and(tf.math.greater_equal(date, first_date),
                               tf.math.less_equal(date, last_date))

  return dataset.filter(filter_fn)
class UtilTest(tf.test.TestCase):
  """Tests for the date-range filtering done by `util.range_dateset`."""
  root_path = "gs://test_folder/unzipped_tf_records_corrected_repartitioned/"

  def _paths(self, suffixes):
    """Prefixes every "<date>/<hour>/part" suffix with `root_path`."""
    return [self.root_path + suffix for suffix in suffixes]

  def _assert_range(self, input_suffixes, expected_suffixes, start_date,
                    end_date):
    """Filters the input paths and checks that exactly the expected remain.

    All produced elements are collected first and compared as a list, so both
    missing and surplus elements show up as an assertion failure.
    """
    with self.session() as sess:
      input_dataset = tf.data.Dataset.from_tensor_slices(
          self._paths(input_suffixes))
      dataset = util.range_dateset(input_dataset,
                                   self.root_path,
                                   start_date=start_date,
                                   end_date=end_date)
      iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
      next_element = iterator.get_next()
      actual_results = []
      try:
        while True:
          actual_results.append(sess.run(next_element).decode())
      except tf.errors.OutOfRangeError:
        pass
    self.assertEqual(actual_results, self._paths(expected_suffixes))

  def test_range_dataset_single(self):
    self._assert_range(
        ["20200501/00/part", "20200502/00/part", "20200503/00/part"],
        ["20200502/00/part"], "20200502", "20200502")

  def test_range_dataset_multiple(self):
    self._assert_range([
        "20200501/00/part",
        "20200502/00/part",
        "20200503/00/part",
        "20200503/01/part",
        "20200504/01/part",
    ], [
        "20200502/00/part",
        "20200503/00/part",
        "20200503/01/part",
    ], "20200502", "20200503")

  def test_range_dataset_out_of_boundary(self):
    self._assert_range(["20200501/00/part", "20200502/00/part"],
                       ["20200501/00/part", "20200502/00/part"], "20200401",
                       "20200505")

  def test_range_dataset_no_start_date(self):
    self._assert_range(["20200501/00/part", "20200502/00/part"],
                       ["20200501/00/part", "20200502/00/part"], None,
                       "20200505")

  def test_range_dataset_no_end_date(self):
    self._assert_range(["20200501/00/part", "20200502/00/part"],
                       ["20200502/00/part"], "20200502", None)


if __name__ == "__main__":
  tf.disable_eager_execution()
  tf.test.main()
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== "Code to implement a custom Variance Scaling initializer that returns numpy arrays." from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import scipy.stats as stats def _compute_fans(shape, data_format='channels_last'): """Computes the number of input and output units for a weight shape. Args: shape: Integer shape tuple. data_format: Image data format to use for convolution kernels. Note that all kernels in Keras are standardized on the `channels_last` ordering (even when inputs are set to `channels_first`). Returns: A tuple of scalars, `(fan_in, fan_out)`. # Raises ValueError: in case of invalid `data_format` argument. """ if len(shape) == 2: fan_in = shape[0] fan_out = shape[1] elif len(shape) in {3, 4, 5}: # Assuming convolution kernels (1D, 2D or 3D). # TH kernel shape: (depth, input_depth, ...) 
# TF kernel shape: (..., input_depth, depth) if data_format == 'channels_first': receptive_field_size = np.prod(shape[2:]) fan_in = shape[1] * receptive_field_size fan_out = shape[0] * receptive_field_size elif data_format == 'channels_last': receptive_field_size = np.prod(shape[:-2]) fan_in = shape[-2] * receptive_field_size fan_out = shape[-1] * receptive_field_size else: raise ValueError('Invalid data_format: ' + data_format) else: # No specific assumptions. fan_in = np.sqrt(np.prod(shape)) fan_out = np.sqrt(np.prod(shape)) return fan_in, fan_out class VarianceScaling(): """Initializer capable of adapting its scale to the shape of weights. With `distribution="truncated_normal"`, samples are drawn from a truncated normal distribution centered on zero, with `stddev = sqrt(scale / n)` where n is: - number of input units in the weight tensor, if mode = "fan_in" - number of output units, if mode = "fan_out" - average of the numbers of input and output units, if mode = "fan_avg" With `distribution="uniform"`, samples are drawn from a uniform distribution within [-limit, limit], with `limit = sqrt(3 * scale / n)`. With `distribution="untrucated_normal"`, samples are drawn from a truncated normal distribution centered on zero, with `stddev = sqrt(scale / n)`. When called, this initializer produces a numpy array, instead of Tensorflow tensors. Args: scale (float, optional): Scaling factor (positive float). mode (str, optional): One of "fan_in", "fan_out", "fan_avg". distribution (str, optional): Random distribution to use. One of "truncated_normal", "untruncated_normal", and "uniform". seed (int, optional): A Python integer. Used to seed the random generator. Raises: ValueError: In case of an invalid value for the "scale", mode" or "distribution" arguments. """ def __init__(self, scale=1.0, mode='fan_in', distribution='truncated_normal', seed=None): if scale <= 0.: raise ValueError('`scale` must be a positive float. 
Got:', scale) mode = mode.lower() if mode not in {'fan_in', 'fan_out', 'fan_avg'}: raise ValueError( 'Invalid `mode` argument: ' 'expected on of {"fan_in", "fan_out", "fan_avg"} ' 'but got', mode) distribution = distribution.lower() if distribution not in { 'truncated_normal', 'untruncated_normal', 'uniform' }: raise ValueError( 'Invalid `distribution` argument: ' 'expected one of {"truncated_normal", "untruncated_normal", "uniform"} ' 'but got', distribution) self.scale = scale self.mode = mode self.distribution = distribution self.seed = seed def __call__(self, shape, dtype=np.float32): fan_in, fan_out = _compute_fans(shape) scale = self.scale if self.mode == 'fan_in': scale /= max(1., fan_in) elif self.mode == 'fan_out': scale /= max(1., fan_out) else: scale /= max(1., float(fan_in + fan_out) / 2) np.random.seed(self.seed) if self.distribution == 'truncated_normal': mean = 0.0 # 0.879... = scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) stddev = np.sqrt(scale) / .87962566103423978 # Mimic the behavior of tf.random.truncated_normal, which truncates # at mean +/- 2 standard deviations lower_clip = mean - 2 * stddev upper_clip = mean + 2 * stddev a = (lower_clip - mean) / stddev b = (upper_clip - mean) / stddev return stats.truncnorm.rvs( a=a, b=b, loc=mean, scale=stddev, size=shape, ).astype(dtype) elif self.distribution == 'untruncated_normal': mean = 0.0 stddev = np.sqrt(scale) return np.random.normal( loc=mean, scale=stddev, size=shape, ).astype('float32') else: limit = np.sqrt(3. * scale) return np.random.uniform( low=-limit, high=limit, size=shape, ).astype(dtype) def get_config(self): return { 'scale': self.scale, 'mode': self.mode, 'distribution': self.distribution, 'seed': self.seed } ================================================ FILE: monolith/gpu_runner.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPU Runner."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging
import sys
import os
import time

import tensorflow as tf
import horovod.tensorflow as hvd
from mpi4py import MPI

from monolith.base_runner import BaseRunner
from monolith.core import model_registry

FLAGS = flags.FLAGS

flags.DEFINE_string("task", default=None, help="Name of the task class to run.")
flags.DEFINE_string(
    "model_dir",
    default=None,
    help=("The directory where the model and summaries are stored."))
flags.DEFINE_integer(
    "save_checkpoints_steps",
    default=None,
    help=
    ("Save checkpoint every save_checkpoints_steps. If None, no checkpoint saved."
    ))
flags.DEFINE_enum("mode", "train", ["train_and_eval", "train", "eval"],
                  "Job mode.")


class GPURunner(BaseRunner):
  """Runs a task through a tf.estimator Estimator in train / eval / train_and_eval mode.

  When `task_param.accelerator == "horovod"` the runner initializes Horovod,
  pins one GPU per process and broadcasts variables from rank 0 before training.
  """

  def __init__(self, task_param, *args, **kwargs):
    """Stores the task params and snapshots the command-line flags.

    Args:
      task_param: Task parameter object; this class reads its `accelerator`,
        `train`, `eval` and `input` attributes and calls `instantiate()` on it.
      *args: Forwarded to `BaseRunner.__init__`.
      **kwargs: Forwarded to `BaseRunner.__init__`.
    """
    super(GPURunner, self).__init__(*args, **kwargs)
    # TODO(youlong.cheng): all the parse logic should generate a hyperparam class.
    self._model_dir = FLAGS.model_dir
    self._save_checkpoints_steps = FLAGS.save_checkpoints_steps
    #TODO(hemang.jangle) Allow subclass task_params to override tpu_runner params
    self._task_param = task_param
    self._mode = FLAGS.mode

  def create_estimator(self, model_fn):
    """Creates the Estimator.

    Args:
      model_fn: Model function handed to the Estimator.

    Returns:
      A `tf.compat.v1.estimator.Estimator` whose `params` carry the per-replica
      batch sizes, the accelerator name, the replica count and the Horovod rank
      (0 when Horovod is not used).
    """
    if self._task_param.accelerator == "horovod":
      # Horovod: save checkpoints only on worker 0 to prevent other workers from
      # corrupting them. @Hao.sheng: However, we still need to use the same
      # model_dir so each worker knows where to load the checkpoint in the
      # train_and_eval mode.
      model_dir = self._model_dir  #if hvd.rank() == 0 else None
      save_checkpoints_steps = self._save_checkpoints_steps if hvd.rank(
      ) == 0 else None
      config = tf.compat.v1.ConfigProto()
      config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
      config.gpu_options.allow_growth = True
      # Each process only sees the GPU matching its local rank.
      config.gpu_options.visible_device_list = str(hvd.local_rank())
      # NOTE: `config` is rebound here from the session ConfigProto to the
      # RunConfig that wraps it.
      config = tf.estimator.RunConfig(
          model_dir=model_dir,
          save_checkpoints_steps=save_checkpoints_steps,
          session_config=config)
      num_gpus = hvd.size()
    else:
      num_gpus = 1
      config = tf.compat.v1.estimator.RunConfig(
          model_dir=self._model_dir,
          save_checkpoints_steps=self._save_checkpoints_steps)

    return tf.compat.v1.estimator.Estimator(
        model_fn=model_fn,
        params={
            "train_batch_size": self._task_param.train.per_replica_batch_size,
            "eval_batch_size": self._task_param.eval.per_replica_batch_size,
            "accelerator": self._task_param.accelerator,
            "num_replicas": num_gpus,
            "hvd_rank": hvd.rank() if self._task_param.accelerator == "horovod" else 0
        },
        config=config)

  def run(self):
    """Runs the job selected by the `mode` flag.

    'train' trains up to `train.max_steps`; 'eval' evaluates once and writes a
    summary; any other mode (train_and_eval) alternates training for
    `eval.steps_per_eval` steps with an evaluation until `train.max_steps`.
    """
    # Resume from the global step stored in model_dir, or start from 0 when it
    # cannot be read.
    try:
      current_step = tf.train.load_variable(self._model_dir,
                                            tf.compat.v1.GraphKeys.GLOBAL_STEP)
    except (TypeError, ValueError, tf.errors.NotFoundError):
      current_step = 0

    logging.info("Current step :{}".format(current_step))

    task = self._task_param.instantiate()
    input_fn_train = task.create_input_fn(tf.estimator.ModeKeys.TRAIN)
    input_fn_eval = task.create_input_fn(tf.estimator.ModeKeys.EVAL)
    model_fn = task.create_model_fn()

    if self._task_param.accelerator == "horovod":
      # Horovod: initialize Horovod.
      hvd.init()
      # Horovod: pin GPU to be used to process local rank (one GPU per process)
      gpus = tf.config.experimental.list_physical_devices('GPU')
      for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
      if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

    # Must come after hvd.init(): create_estimator() queries hvd.rank()/size().
    est = self.create_estimator(model_fn)
    start_timestamp = time.time()  # This time will include compilation time
    if self._mode == 'train':
      if self._task_param.accelerator == "horovod":
        # Horovod: BroadcastGlobalVariablesHook broadcasts initial variable states from
        # rank 0 to all other processes. This is necessary to ensure consistent
        # initialization of all workers when training is started with random weights or
        # restored from a checkpoint.
        bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
        est.train(input_fn_train,
                  max_steps=self._task_param.train.max_steps,
                  hooks=[bcast_hook])
      else:
        est.train(input_fn_train, max_steps=self._task_param.train.max_steps)
    elif self._mode == 'eval':
      eval_output_dir = os.path.join(self._model_dir, 'eval')
      tf.io.gfile.makedirs(eval_output_dir)
      total_examples = self._task_param.input.eval_examples
      eval_batch_size = self._task_param.eval.per_replica_batch_size
      # Floor division: a trailing partial batch is not evaluated.
      num_steps = total_examples // eval_batch_size
      logging.info(
          "Evaluation: total_examples:{} eval_batch_size:{} num_steps: {}".
          format(total_examples, eval_batch_size, num_steps))
      eval_results = est.evaluate(input_fn_eval, steps=num_steps)
      logging.info("Eval results: {}".format(eval_results))
      # Summary writer writes out eval metrics.
      summary_writer = tf.compat.v1.summary.FileWriter(eval_output_dir)
      self.write_summary(eval_results, summary_writer, current_step)
      summary_writer.close()
    else:  # train_and_eval
      steps_per_eval = self._task_param.eval.steps_per_eval
      max_steps = self._task_param.train.max_steps
      eval_output_dir = os.path.join(self._model_dir, 'eval')
      tf.io.gfile.makedirs(eval_output_dir)
      while current_step < self._task_param.train.max_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + steps_per_eval, max_steps)
        if self._task_param.accelerator == "horovod":
          bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
          est.train(input_fn_train,
                    max_steps=next_checkpoint,
                    hooks=[bcast_hook])
        else:
          est.train(input_fn_train, max_steps=next_checkpoint)
        current_step = next_checkpoint
        logging.info(
            "Finished training up to step {}. Elapsed seconds {}.".format(
                next_checkpoint,
                time.time() - start_timestamp))

        total_examples = self._task_param.input.eval_examples
        eval_batch_size = self._task_param.eval.per_replica_batch_size
        num_steps = total_examples // eval_batch_size
        # Under Horovod only rank 0 evaluates and writes summaries.
        if self._task_param.accelerator != "horovod" or hvd.rank() == 0:
          logging.info("Starting to evaluate.")
          # NOTE(review): the reason for this fixed 10s pause is not visible
          # here -- confirm before changing it.
          time.sleep(10)
          #eval_results = hvd.allreduce(eval_results)
          eval_results = est.evaluate(input_fn_eval, steps=num_steps)
          logging.info("Eval results at step {}: {}".format(
              next_checkpoint, eval_results))
          # Summary writer writes out eval metrics.
          summary_writer = tf.compat.v1.summary.FileWriter(eval_output_dir)
          self.write_summary(eval_results, summary_writer, current_step)
          summary_writer.close()
        # Horovod: Make sure all workers are synced at the end of one round
        # https://github.com/horovod/horovod/issues/159
        # https://github.com/horovod/horovod/issues/1380
        if self._task_param.accelerator == "horovod":
          MPI.COMM_WORLD.barrier()

      # `max_steps` is only bound in this branch, so this closing log has to
      # stay inside it.
      elapsed_time = int(time.time() - start_timestamp)
      logging.info(
          "Finished training up to step {}. Elapsed seconds {}.".format(
              max_steps, elapsed_time))


def main(unused_argv):
  """Looks up the params of the task named by --task and runs a GPURunner on them."""
  task_name = FLAGS.task
  task_param = model_registry.GetParams(task_name)
  logging.info("task_param: {}".format(str(task_param)))
  runner = GPURunner(task_param)
  runner.run()


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.disable_v2_behavior()
  app.run(main)

================================================
FILE: monolith/monolith_workspace.bzl
================================================
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
load("@rules_python//python:pip.bzl", "pip_install")

def monolith_workspace():
    """Adds monolith workspace's dependencies."""

    # msgpack-c 3.3.0, built with the BUILD file kept under //third_party.
    http_archive(
        name = "msgpack",
        build_file = "//third_party:msgpack/msgpack.BUILD",
        strip_prefix = "msgpack-3.3.0",
        sha256 = "6e114d12a5ddb8cb11f669f83f32246e484a8addd0ce93f274996f1941c1f07b",
        urls = ["https://github.com/msgpack/msgpack-c/releases/download/cpp-3.3.0/msgpack-3.3.0.tar.gz"],
    )

    # Python dependencies listed in //third_party/pip_deps:requirements.txt.
    pip_install(
        name = "pip_deps",
        requirements = "//third_party/pip_deps:requirements.txt",
        python_interpreter = "python3",
    )

    # gperftools 2.7 with a local patch applied (-p1).
    http_archive(
        name = "gperftools",
        build_file = "//third_party:gperftools/gperftools.BUILD",
        sha256 = "81bb34f546ac8cddd064f8935805f8eb19e3c9661188e127b4c90526e944ebff",
        urls = [
            "https://github.com/gperftools/gperftools/releases/download/gperftools-2.7/gperftools-2.7.zip",
        ],
        patches = ["//third_party:gperftools/gperftools.patch"],
        patch_args = ["-p1"],
        strip_prefix = "gperftools-2.7",
    )

================================================
FILE: monolith/native_training/BUILD
================================================
load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
load("@pip_deps//:requirements.bzl", "requirement")
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_custom_op_library")

package(
default_visibility = [ "//monolith:__subpackages__", ], ) py_library( name = "monolith_export", srcs = ["monolith_export.py"], visibility = ["//visibility:public"], ) py_binary( name = "demo", srcs = ["demo.py"], deps = [ ":cpu_training", ":model", ":native_task", ], ) py_library( name = "model", srcs = ["model.py"], deps = [ ":feature", ":input", ":native_task", "//monolith/core:base_model_params", "//monolith/core:model_registry", "//monolith/native_training/metric:deep_insight_ops", ], ) py_library( name = "input", srcs = ["input.py"], ) py_library( name = "test_utils", testonly = 1, srcs = ["test_utils.py"], deps = [ ":entry", ":utils", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", ], ) py_library( name = "clip_ops", srcs = ["clip_ops.py"], deps = [ ":device_utils", "//monolith:utils", "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "clip_ops_test", srcs = ["clip_ops_test.py"], deps = [ ":clip_ops", ], ) py_proto_library( name = "hash_table_ops_py_proto", srcs = ["hash_table_ops.proto"], ) py_library( name = "hash_table_ops", srcs = ["hash_table_ops.py"], deps = [ ":basic_restore_hook", ":distributed_serving_ops", ":entry", ":feature", ":graph_meta", ":hash_filter_ops", ":hash_table_ops_py_proto", ":hash_table_utils", ":save_utils", ":utils", "//monolith:utils", "//monolith/native_training/model_export:export_context", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", "//monolith/native_training/runtime/ops:gen_monolith_ops", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "hash_table_ops_test", srcs = ["hash_table_ops_test.py"], deps = [ ":hash_filter_ops", ":hash_table_ops", ":learning_rate_functions", ], ) py_proto_library( name = "multi_hash_table_ops_py_proto", srcs = ["multi_hash_table_ops.proto"], ) py_library( name = "multi_hash_table_ops", srcs = ["multi_hash_table_ops.py"], deps = [ ":basic_restore_hook", ":distributed_serving_ops", 
":entry", ":feature", ":graph_meta", ":hash_filter_ops", ":hash_table_utils", ":multi_hash_table_ops_py_proto", ":multi_type_hash_table", ":save_utils", ":utils", "//monolith:utils", "//monolith/native_training/model_export:export_context", "//monolith/native_training/proto:ckpt_info_py_proto", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", "//monolith/native_training/runtime/ops:gen_monolith_ops", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "multi_hash_table_ops_test", srcs = ["multi_hash_table_ops_test.py"], deps = [ ":hash_filter_ops", ":learning_rate_functions", ":multi_hash_table_ops", ":test_utils", ":utils", ], ) py_binary( name = "hash_table_utils", srcs = ["hash_table_utils.py"], ) py_test( name = "hash_table_utils_test", srcs = ["hash_table_utils_test.py"], deps = [ ":hash_table_ops", ":hash_table_utils", ], ) py_binary( name = "hash_table_ops_benchmark", srcs = ["hash_table_ops_benchmark.py"], deps = [ ":hash_filter_ops", ":hash_table_ops", ], ) py_library( name = "distribution_ops", srcs = ["distribution_ops.py"], deps = [ "//idl:example_py_proto", "//monolith:utils", "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "distribution_ops_test", srcs = ["distribution_ops_test.py"], deps = [ ":distribution_ops", ], ) py_test( name = "fused_embedding_to_layout_test", srcs = ["fused_embedding_to_layout_test.py"], deps = [ ":distribution_ops", "//monolith/native_training/data:parsers_py", ], ) py_test( name = "distribution_ops_fused_test", srcs = ["distribution_ops_fused_test.py"], deps = [ ":distribution_ops", ], ) py_binary( name = "distribution_ops_benchmark", srcs = ["distribution_ops_benchmark.py"], deps = [ ":distribution_ops", ], ) py_binary( name = "distribution_ops_fused_benchmark", srcs = ["distribution_ops_fused_benchmark.py"], deps = [ ":distribution_ops", ], ) py_library( name = "distributed_ps", srcs = ["distributed_ps.py"], deps = [ ":distribution_ops", 
":hash_table_ops", ":hash_table_utils", ":logging_ops", ":multi_type_hash_table", ":native_task_context", ":tensor_utils", ":utils", "//monolith/native_training/data:parsers_py", ], ) py_test( name = "distributed_ps_test", srcs = ["distributed_ps_test.py"], deps = [ ":distributed_ps", ":distributed_ps_factory", ":hash_filter_ops", ":learning_rate_functions", ":multi_hash_table_ops", ":test_utils", "//monolith/native_training/data:feature_utils_py", ], ) py_library( name = "distributed_ps_factory", srcs = ["distributed_ps_factory.py"], deps = [ ":distributed_ps", ":distributed_ps_sync", ":distribution_ops", ":embedding_combiners", ":entry", ":hash_filter_ops", ":hash_table_ops", ":multi_hash_table_ops", ":multi_type_hash_table", ":utils", ], ) py_test( name = "distributed_ps_factory_test", srcs = ["distributed_ps_factory_test.py"], deps = [ ":distributed_ps_factory", ":hash_filter_ops", ":test_utils", ], ) py_binary( name = "distributed_ps_benchmark", srcs = ["distributed_ps_benchmark.py"], deps = [ ":distributed_ps", ":hash_filter_ops", ":multi_type_hash_table", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", ], ) py_library( name = "embedding_combiners", srcs = ["embedding_combiners.py"], deps = [ ":device_utils", ":distribution_ops", ":ragged_utils", ], ) py_test( name = "embedding_combiners_test", srcs = ["embedding_combiners_test.py"], deps = [ ":embedding_combiners", ], ) py_library( name = "distributed_ps_sync", srcs = ["distributed_ps_sync.py"], deps = [ ":distributed_ps", ":distribution_ops", ":feature_utils", ":multi_type_hash_table", ":prefetch_queue", ], ) py_test( name = "distributed_ps_sync_test", srcs = ["distributed_ps_sync_test.py"], deps = [ ":distributed_ps", ":distributed_ps_sync", ":learning_rate_functions", ":multi_hash_table_ops", ":test_utils", ], ) py_library( name = "multi_type_hash_table", srcs = ["multi_type_hash_table.py"], deps = [ ":device_utils", ":distribution_ops", ":entry", ":hash_filter_ops", 
":hash_table_ops", ":hash_table_utils", ":prefetch_queue", ], ) py_test( name = "multi_type_hash_table_test", srcs = ["multi_type_hash_table_test.py"], deps = [ ":hash_filter_ops", ":hash_table_ops", ":learning_rate_functions", ":multi_type_hash_table", ":test_utils", ":utils", ], ) py_library( name = "hash_filter_ops", srcs = ["hash_filter_ops.py"], deps = [ ":basic_restore_hook", ":save_utils", ":utils", "//monolith/native_training/model_export:export_context", "//monolith/native_training/runtime/ops:gen_monolith_ops", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "hash_filter_ops_test", srcs = ["hash_filter_ops_test.py"], deps = [ ":hash_filter_ops", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", ], ) py_library( name = "touched_key_set_ops", srcs = ["touched_key_set_ops.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "touched_key_set_ops_test", srcs = ["touched_key_set_ops_test.py"], deps = [":touched_key_set_ops"], ) py_library( name = "distributed_serving_ops", srcs = ["distributed_serving_ops.py"], deps = [ "//monolith/agent_service:agent", "//monolith/native_training/runtime/ops:gen_monolith_ops", "//monolith/native_training/runtime/parameter_sync:parameter_sync_py_proto", ], ) py_test( name = "distributed_serving_ops_test", srcs = ["distributed_serving_ops_test.py"], deps = [ ":distributed_serving_ops", ":hash_table_ops", ], ) py_library( name = "entry", srcs = ["entry.py"], visibility = ["//visibility:public"], deps = [ ":monolith_export", "//monolith:utils", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", ], ) py_test( name = "entry_test", srcs = ["entry_test.py"], deps = [ ":entry", ":learning_rate_functions", ], ) py_library( name = "feature", srcs = ["feature.py"], visibility = ["//visibility:public"], deps = [ "embedding_combiners", ":device_utils", ":distribution_ops", ":entry", ":learning_rate_functions", 
":monolith_export", ":prefetch_queue", ":ragged_utils", "//monolith:utils", "//monolith/native_training/model_export:export_context", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", ], ) py_library( name = "distribution_utils", srcs = ["distribution_utils.py"], visibility = ["//visibility:public"], deps = [ "//monolith/native_training/metric:metric_hook", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_library( name = "feature_utils", srcs = ["feature_utils.py"], visibility = ["//visibility:public"], deps = [ ":clip_ops", ":feature", ":native_task", ":prefetch_queue", ], ) py_test( name = "feature_utils_test", srcs = ["feature_utils_test.py"], deps = [ ":feature_utils", ":prefetch_queue", ], ) py_test( name = "feature_test", srcs = ["feature_test.py"], deps = [ ":feature", ":learning_rate_functions", ], ) py_library( name = "native_task", srcs = ["native_task.py"], visibility = ["//visibility:public"], deps = [ ":feature", ":prefetch_queue", "//monolith/core:base_task", "//monolith/native_training/model_export:export_context", ], ) py_library( name = "native_model", srcs = ["native_model.py"], visibility = ["//visibility:public"], deps = [ ":dense_reload_utils", ":entry", ":estimator", ":feature", ":file_ops", ":graph_utils", ":mlp_utils", ":monolith_export", ":native_task", ":native_task_context", ":utils", "//monolith:utils", "//monolith/core:base_layer", "//monolith/native_training:feature_utils", "//monolith/native_training/data:feature_list", "//monolith/native_training/data:utils", "//monolith/native_training/layers", "//monolith/native_training/metric:metric_hook", "//monolith/native_training/metric:utils", "//monolith/native_training/model_dump:dump_utils", ], ) py_library( name = "cpu_training_additional_deps", ) py_library( name = "cpu_training", srcs = ["cpu_training.py"], visibility = ["//visibility:public"], deps = [ ":barrier_ops", ":basic_restore_hook", ":cluster_manager", ":cpu_training_additional_deps", 
":device_utils", ":distributed_ps_factory", ":distribution_ops", ":distribution_utils", ":env_utils", ":feature", ":gflags_utils", ":hash_table_ops", ":hash_table_utils", ":hvd_lib", ":logging_ops", ":mlp_utils", ":multi_type_hash_table", ":native_task", ":native_task_context", ":net_utils", ":prefetch_queue", ":ps_benchmark", ":save_utils", ":service_discovery", ":session_run_hooks", ":signal_utils", ":sync_hooks", ":sync_training_hooks", ":tensor_utils", ":utils", ":variables", ":yarn_runtime", "//monolith/agent_service:replica_manager", "//monolith/core:hyperparams", "//monolith/native_training:dense_reload_utils", "//monolith/native_training:hash_filter_ops", "//monolith/native_training/alert", "//monolith/native_training/data/training_instance:parser_utils", "//monolith/native_training/hooks:ckpt_hooks", "//monolith/native_training/hooks:ckpt_info", "//monolith/native_training/hooks:feature_engineering_hooks", "//monolith/native_training/hooks:hook_utils", "//monolith/native_training/hooks:ps_check_hooks", "//monolith/native_training/hooks:session_hooks", "//monolith/native_training/hooks/server:server_lib", "//monolith/native_training/metric:deep_insight_ops", "//monolith/native_training/metric:metric_hook", "//monolith/native_training/model_dump:dump_utils", "//monolith/native_training/model_export", "//monolith/native_training/model_export:export_context", "//monolith/native_training/model_export:export_hooks", "//monolith/native_training/model_export:export_utils", "//monolith/native_training/model_export:saved_model_exporters", "//monolith/native_training/proto:debugging_info_py_proto", "//monolith/native_training/runtime/hash_table:embedding_hash_table_py_proto", requirement("numpy"), requirement("pyarrow"), requirement("cityhash"), ], ) py_test( name = "cpu_training_test", size = "large", srcs = ["cpu_training_test.py"], data = ["cpu_training_distributed_test_binary"], shard_count = 5, deps = [ ":cpu_training", ":native_task", ":service_discovery", 
"//monolith/native_training/debugging:debugging_server", ], ) py_test( name = "cpu_training_multi_hash_table_test", size = "large", srcs = ["cpu_training_test.py"], args = ["--use_native_multi_hash_table"], main = "cpu_training_test.py", shard_count = 5, deps = [ ":cpu_training_test", ], ) py_binary( name = "cpu_training_distributed_test_binary", srcs = ["cpu_training_distributed_test_binary.py"], deps = [ ":cluster_manager", ":cpu_training", ":native_task", ":service_discovery", ], ) py_test( name = "cpu_sync_training_test", srcs = ["cpu_sync_training_test.py"], deps = [ ":cpu_training", ":native_task", ], ) py_test( name = "model_comp_test", srcs = ["model_comp_test.py"], deps = [ ":cpu_training", ":estimator", ":native_model", ], ) py_library( name = "native_task_context", srcs = ["native_task_context.py"], deps = ["//monolith/agent_service:backends"], ) py_library( name = "utils", srcs = ["utils.py"], deps = [ "//idl:proto_parser_py_proto", "//monolith/core:base_layer", "//monolith/core:hyperparams", "//monolith/core:py_utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "utils_test", srcs = ["utils_test.py"], deps = [":utils"], ) py_library( name = "file_ops", srcs = ["file_ops.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "file_ops_test", srcs = ["file_ops_test.py"], deps = [ ":file_ops", "//monolith:utils", ], ) py_library( name = "env_utils", srcs = ["env_utils.py"], visibility = ["//visibility:public"], ) py_test( name = "env_utils_test", srcs = ["env_utils_test.py"], deps = [ ":env_utils", ], ) py_library( name = "logging_ops", srcs = ["logging_ops.py"], deps = [ "//monolith:utils", "//monolith/native_training/runtime/ops:gen_monolith_ops", "//monolith/native_training/runtime/ops:logging_ops_py_proto", ], ) py_test( name = "logging_ops_test", srcs = ["logging_ops_test.py"], deps = [ ":logging_ops", ], ) py_library( name = "session_run_hooks", srcs = ["session_run_hooks.py"], ) py_test( 
name = "session_run_hooks_test", srcs = ["session_run_hooks_test.py"], deps = [ ":session_run_hooks", requirement("freezegun"), ], ) py_library( name = "hvd_lib", srcs = ["hvd_lib.py"], ) py_library( name = "static_reshape_op", srcs = ["static_reshape_op.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "static_reshape_op_test", srcs = ["static_reshape_op_test.py"], deps = [ ":static_reshape_op", ], ) py_library( name = "sync_training_hooks", srcs = ["sync_training_hooks.py"], deps = [ ":distributed_serving_ops", ":hash_table_ops", ":hvd_lib", ":native_task", "//monolith/agent_service:backends", "//monolith/native_training/data:datasets_py", ], ) py_test( name = "sync_training_hooks_test", srcs = ["sync_training_hooks_test.py"], deps = [ ":sync_training_hooks", ], ) py_library( name = "service_discovery", srcs = ["service_discovery.py"], deps = [ ":consul", ":mlp_utils", ":zk_utils", ], ) py_test( name = "service_discovery_test", srcs = ["service_discovery_test.py"], deps = [ ":service_discovery", ], ) py_library( name = "ragged_utils", srcs = ["ragged_utils.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "ragged_utils_test", srcs = ["ragged_utils_test.py"], deps = [ ":ragged_utils", ], ) py_library( name = "tensor_utils", srcs = ["tensor_utils.py"], deps = [ ":static_reshape_op", ], ) py_test( name = "tensor_utils_test", srcs = ["tensor_utils_test.py"], deps = [ ":tensor_utils", ], ) py_library( name = "graph_meta", srcs = ["graph_meta.py"], ) py_library( name = "graph_utils", srcs = ["graph_utils.py"], ) py_library( name = "consul", srcs = ["consul.py"], ) py_test( name = "consul_test", srcs = ["consul_test.py"], deps = [ ":consul", ], ) py_library( name = "barrier_ops", srcs = ["barrier_ops.py"], deps = [ ":basic_restore_hook", ], ) py_test( name = "barrier_ops_test", srcs = ["barrier_ops_test.py"], deps = [ ":barrier_ops", ], ) py_library( name = "basic_restore_hook", 
srcs = ["basic_restore_hook.py"], ) py_test( name = "basic_restore_hook_test", srcs = ["basic_restore_hook_test.py"], deps = [ ":basic_restore_hook", ], ) py_library( name = "prefetch_queue", srcs = ["prefetch_queue.py"], deps = [ ":nested_tensors", ":utils", ], ) py_test( name = "prefetch_queue_test", srcs = ["prefetch_queue_test.py"], deps = [ ":prefetch_queue", ], ) py_proto_library( name = "monolith_checkpoint_state_proto", srcs = [":monolith_checkpoint_state.proto"], ) py_library( name = "save_utils", srcs = ["save_utils.py"], deps = [ ":dense_reload_utils", ":monolith_checkpoint_state_proto", ":session_run_hooks", ":utils", "//monolith/native_training:native_task_context", "//monolith/native_training/metric:cli", ], ) py_test( name = "save_utils_test", srcs = ["save_utils_test.py"], deps = [ ":save_utils", requirement("freezegun"), ], ) py_library( name = "sync_hooks", srcs = ["sync_hooks.py"], ) py_test( name = "sync_hooks_test", srcs = ["sync_hooks_test.py"], deps = [ ":sync_hooks", ], ) py_test( name = "restore_test", srcs = ["restore_test.py"], deps = [ ":hash_table_ops", ":save_utils", ":utils", ], ) py_library( name = "variables", srcs = ["variables.py"], deps = [ ":graph_meta", ], ) py_test( name = "variables_test", srcs = ["variables_test.py"], deps = [ ":test_utils", ":variables", ], ) py_library( name = "ps_benchmark", srcs = ["ps_benchmark.py"], deps = [ ":logging_ops", ":native_task", ":utils", "//monolith/native_training/optimizers:adamom", ], ) py_test( name = "ps_benchmark_test", srcs = ["ps_benchmark_test.py"], deps = [ ":cpu_training", ":ps_benchmark", ], ) py_library( name = "estimator", srcs = ["estimator.py"], deps = [ ":cpu_training", ":distribution_utils", ":env_utils", ":monolith_export", ":native_task", ":runner_utils", ":service_discovery", ":zk_utils", "//monolith:utils", "//monolith/agent_service:backends", "//monolith/agent_service:replica_manager", "//monolith/agent_service:utils", "//monolith/native_training/data:item_pool_hook", 
"//monolith/native_training/model_export:saved_model_exporters", ], ) # This test is buggy. py_test( name = "estimator_test", srcs = ["estimator_test.py"], deps = [ ":cpu_training", ":estimator", ":input", ":model", ":native_task", ":service_discovery", ":utils", "//monolith/native_training/data/training_instance:instance_dataset_ops_py", "//monolith/native_training/data/training_instance:parse_instance_ops_py", "//monolith/native_training/model_export:saved_model_exporters", ], ) # This test is buggy. # py_test( # name = "estimator_dist_test", # srcs = ["estimator_dist_test.py"], # deps = [ # ":cpu_training", # ":estimator", # ":input", # ":model", # ":native_task", # ":service_discovery", # ":utils", # "//monolith/native_training/data/training_instance:instance_dataset_ops_py", # "//monolith/native_training/data/training_instance:parse_instance_ops_py", # "//monolith/native_training/model_export:saved_model_exporters", # "//monolith/native_training/tasks/reckon:params", # ], # ) py_library( name = "zk_utils", srcs = ["zk_utils.py"], srcs_version = "PY3", deps = [ ":env_utils", requirement("kazoo"), ], ) py_library( name = "device_utils", srcs = ["device_utils.py"], deps = [ ":distribution_utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "device_utils_test", srcs = ["device_utils_test.py"], deps = [ ":device_utils", ], ) py_library( name = "gflags_utils", srcs = ["gflags_utils.py"], ) py_test( name = "gflags_utils_test", srcs = ["gflags_utils_test.py"], deps = [ ":gflags_utils", ], ) py_library( name = "runner_utils", srcs = ["runner_utils.py"], deps = [ ":cpu_training", ":env_utils", ":gflags_utils", ":monolith_checkpoint_state_proto", ":service_discovery", ], ) py_test( name = "runner_utils_test", srcs = ["runner_utils_test.py"], deps = [ ":runner_utils", ], ) py_library( name = "learning_rate_functions", srcs = ["learning_rate_functions.py"], ) py_test( name = "learning_rate_functions_test", srcs = ["learning_rate_functions_test.py"], 
deps = [ ":learning_rate_functions", ], ) py_library( name = "yarn_runtime", srcs = ["yarn_runtime.py"], deps = [ ":net_utils", "//monolith/native_training/proto:primus_am_service_py_proto", "//monolith/native_training/proto:primus_am_service_py_proto_grpc", ], ) py_test( name = "yarn_runtime_test", srcs = ["yarn_runtime_test.py"], deps = [ "yarn_runtime", ], ) py_library( name = "net_utils", srcs = ["net_utils.py"], ) py_test( name = "net_utils_test", srcs = ["net_utils_test.py"], deps = [ ":net_utils", ], ) py_library( name = "cluster_manager", srcs = ["cluster_manager.py"], deps = [ ":service_discovery", "//monolith/native_training/metric:cli", ], ) py_test( name = "cluster_manager_test", srcs = ["cluster_manager_test.py"], deps = [ ":cluster_manager", ], ) py_library( name = "signal_utils", srcs = ["signal_utils.py"], ) py_test( name = "signal_utils_test", srcs = ["signal_utils_test.py"], deps = [ ":signal_utils", ], ) py_library( name = "gen_seq_mask", srcs = ["gen_seq_mask.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "gen_seq_mask_test", srcs = ["gen_seq_mask_test.py"], deps = [ ":gen_seq_mask", "//monolith:utils", ], ) py_test( name = "serving_ps_test", srcs = ["serving_ps_test.py"], data = ["//idl:example_cc_proto"], deps = [ ":distribution_ops", "//idl:example_py_proto", "//monolith:utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_library( name = "dense_reload_utils", srcs = ["dense_reload_utils.py"], deps = [ ":basic_restore_hook", "//monolith/native_training/model_export:export_context", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "dense_reload_utils_test", srcs = ["dense_reload_utils_test.py"], deps = [ ":dense_reload_utils", "//monolith:utils", ], ) py_library( name = "nested_tensors", srcs = ["nested_tensors.py"], ) py_test( name = "nested_tensors_test", srcs = ["nested_tensors_test.py"], deps = [ ":nested_tensors", ], ) py_library( name = "mlp_utils", srcs = 
["mlp_utils.py"], deps = [ ":distribution_utils", ":yarn_runtime", "//monolith/native_training/model_export:export_context", ], ) ================================================ FILE: monolith/native_training/alert/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_library", "py_test") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") package(default_visibility = ["//monolith/native_training/alert:__subpackages__"]) # Including all public interfaces. py_library( name = "alert", visibility = ["//visibility:public"], deps = [ ":alert_manager", ":alert_py_proto", ], ) py_proto_library( name = "alert_py_proto", srcs = ["alert.proto"], ) py_library( name = "alert_manager_internal_deps", ) py_library( name = "alert_manager", srcs = ["alert_manager.py"], deps = [ ":alert_manager_internal_deps", ":alert_py_proto", ], ) py_test( name = "alert_manager_test", srcs = ["alert_manager_test.py"], deps = [ ":alert_manager", ], ) ================================================ FILE: monolith/native_training/alert/alert.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// Top-level alert configuration. Bundles the alert message settings, the
// optional Kafka and training alert settings, and the monitoring schedule.
message AlertProto {
  // Alert message settings (currently only the `user` field).
  optional AlertMessageProto alert_message = 1;
  // Kafka alert settings (topic and consumer group).
  optional KafkaAlertProto kafka_alert = 2;
  // Training alert settings (a prefix string).
  optional TrainingAlertProto training_alert = 3;

  // Delay, in seconds, before monitoring is started.
  optional int64 start_delay_sec = 1000;
  // Presumably the number of seconds between two consecutive checks; defaults
  // to 1800 -- TODO confirm against the consumer of this proto.
  // NOTE(review): the field number is 10001, not 1001. It looks unintended
  // next to 1000 above, but it must not be changed now because that would
  // break compatibility with already-serialized messages.
  optional int64 check_interval_sec = 10001 [ default = 1800 ];
}
class BarrierOp:
  """A barrier that is placed to block workers until it is removed.

  The barrier state is a boolean vector of length `capacity`. Slot 0 is the
  barrier flag itself (set by `place_barrier`, cleared by `remove_barrier`);
  slots 1..capacity-1 are set by workers while they are blocked in
  `wait_until_barrier_removed`.

  A `threading.Lock` serializes placing/removing the barrier within one
  process.
  """

  def __init__(self,
               capacity,
               is_chief=True,
               wait_seconds=1,
               name_prefix="default",
               barrier_callbacks=None):
    """Builds the variables and ops backing the barrier.

    Args:
      capacity: Length of the barrier vector (slot 0 plus one slot per worker).
      is_chief: Selects the collection the barrier variable is added to.
      wait_seconds: Sleep interval between two polls while a worker waits.
      name_prefix: Prefix of the name scope the ops are created in.
      barrier_callbacks: Optional list of callables invoked as
        `callback(action, session)` when the barrier is placed or observed.
    """
    self._capacity = capacity
    self._wait_seconds = wait_seconds
    with tf.name_scope(name_prefix + "_barrier_op"):
      # For non-chief workers, barrier vars are treated as global variables.
      collections = [tf.compat.v1.GraphKeys.LOCAL_VARIABLES
                    ] if is_chief else [tf.compat.v1.GraphKeys.VARIABLES]
      self._barrier_vars = tf.compat.v1.get_variable("barrier_var",
                                                     initializer=[False] *
                                                     capacity,
                                                     collections=collections)
      self._idx_ph = tf.compat.v1.placeholder(tf.int32,
                                              shape=[],
                                              name="index_placeholder")
      self._place_op = self._barrier_vars[self._idx_ph].assign(True)
      self._remove_op = self._barrier_vars[self._idx_ph].assign(False)
      self._barrier_placed_tensor = self._barrier_vars[0]
      self._barrier_callbacks = barrier_callbacks or []
      # String variable carrying the action name that was passed to
      # `place_barrier`; waiting workers read it and forward it to callbacks.
      self._action = tf.compat.v1.get_variable(
          "barrier_op_action",
          dtype=tf.string,
          initializer="",
          trainable=False,
          collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES])
      self._action_placeholder = tf.compat.v1.placeholder(
          tf.string, [], "barrier_op_action_placeholder")
      self._action_assign = self._action.assign(self._action_placeholder,
                                                read_value=False)
    self._lock = threading.Lock()

  def place_barrier(self, session, action: str = ""):
    """Sets slot 0, stores `action` and runs the barrier callbacks.

    Raises:
      BarrierAlreadyPlacedError: If slot 0 is already set.
    """
    with self._lock:
      if self.is_barrier_placed(session):
        raise BarrierAlreadyPlacedError()
      session.run([self._place_op, self._action_assign],
                  feed_dict={
                      self._action_placeholder: action,
                      self._idx_ph: 0
                  })
      self._run_barrier_callbacks(action, session)

  def remove_barrier(self, session):
    """Clears slot 0 unconditionally."""
    with self._lock:
      # We are more generous about removing barrier
      # We don't check barrier state here
      session.run(self._remove_op, feed_dict={self._idx_ph: 0})

  def is_barrier_placed(self, session):
    """Returns the current value of slot 0."""
    return session.run(self.barrier_placed_tensor)

  @property
  def barrier_placed_tensor(self):
    """Tensor reading slot 0 of the barrier vector."""
    return self._barrier_placed_tensor

  @property
  def capacity(self):
    """Length of the barrier vector."""
    return self._capacity

  def is_barrier_removed(self, session):
    """Returns True when slot 0 is not set."""
    return not self.is_barrier_placed(session)

  def wait_until_barrier_removed(self, session, index):
    """Marks slot `index` as blocked and polls until slot 0 is cleared.

    Args:
      session: Session used to run the barrier ops.
      index: Slot of the calling worker; must satisfy 0 < index < capacity
        because slot 0 is the barrier flag itself.

    Raises:
      ValueError: If `index` is outside (0, capacity).
    """
    with self._lock:
      if index <= 0 or index >= self._capacity:
        # The check rejects 0 as well, so the message says "positive".
        raise ValueError(
            "Index [{}] must be positive and less than capacity [{}]. ".format(
                index, self._capacity))
      session.run(self._place_op, feed_dict={self._idx_ph: index})
      action = session.run(self._action).decode()
      self._run_barrier_callbacks(action, session)
    # Polling happens outside the lock so that `remove_barrier` can be called
    # on this same object from another thread while we are waiting.
    while not self.is_barrier_removed(session):
      logging.log_every_n_seconds(
          logging.INFO,
          "The worker {} waits until barrier removed.".format(index), 60)
      time.sleep(self._wait_seconds)
    session.run(self._remove_op, feed_dict={self._idx_ph: index})

  def is_all_blocked(self, session):
    """Returns True when every slot of the barrier vector is set."""
    barriers = session.run(self._barrier_vars)
    count = sum(barriers)
    return count == self._capacity

  def is_none_blocked(self, session):
    """Returns True when no slot of the barrier vector is set."""
    barriers = session.run(self._barrier_vars)
    count = sum(barriers)
    return count == 0

  def get_unblocked_indices(self, session):
    """Returns the indices of the slots that are currently not set."""
    barriers = session.run(self._barrier_vars)
    return [i for i in range(self._capacity) if not barriers[i]]

  def get_blocked_indices(self, session):
    """Returns the indices of the slots that are currently set."""
    barriers = session.run(self._barrier_vars)
    return [i for i in range(self._capacity) if barriers[i]]

  # The session annotation is a string so the class can be defined without
  # evaluating `tf` at definition time.
  def _run_barrier_callbacks(self, action: str,
                             session: "tf.compat.v1.Session"):
    """Invokes every registered callback as `callback(action, session)`."""
    for callback in self._barrier_callbacks:
      callback(action, session)
class BarrierOpsTest(tf.test.TestCase):
  """Tests for `barrier_ops.BarrierOp` and `barrier_ops.BarrierHook`."""

  def test_basic(self):
    """Place / double-place / remove on a single BarrierOp."""
    barrier_op = barrier_ops.BarrierOp(2, False)
    with tf.compat.v1.Session() as sess:
      self.evaluate(tf.compat.v1.global_variables_initializer())
      self.evaluate(tf.compat.v1.local_variables_initializer())
      barrier_op.place_barrier(sess)
      self.assertTrue(barrier_op.is_barrier_placed(sess))
      # Placing a barrier that is already placed must be rejected.
      with self.assertRaises(barrier_ops.BarrierAlreadyPlacedError):
        barrier_op.place_barrier(sess)
      barrier_op.remove_barrier(sess)
      self.assertTrue(barrier_op.is_barrier_removed(sess))

  def _run(self, train_op, sess, step=1):
    """Thread target: runs `train_op` `step` times in `sess`."""
    for i in range(step):
      sess.run(train_op)

  def test_barrier_hook_not_blocked(self):
    """Without a placed barrier the hooked session runs all steps."""
    with tf.compat.v1.Graph().as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      train_op = tf.compat.v1.assign_add(global_step, 1)
      barrier_op = barrier_ops.BarrierOp(2, False)
      hook = barrier_ops.BarrierHook(1, barrier_op)
      with tf.compat.v1.Session() as sess:
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(tf.compat.v1.local_variables_initializer())
        mon_sess = monitored_session._HookedSession(sess, [hook])
        worker = threading.Thread(target=self._run,
                                  args=(train_op, mon_sess, 5))
        worker.daemon = True
        worker.start()
        worker.join()
        self.assertEqual(5, sess.run(global_step))

  def test_barrier_hook_blocked(self):
    """A placed barrier blocks the worker after one step until removed."""
    with tf.compat.v1.Graph().as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      train_op = tf.compat.v1.assign_add(global_step, 1)
      # Set to True by the callback once it sees the expected action.
      called_variable = tf.Variable(False, trainable=False)
      barrier_action = "test_action"

      def action_callback(action, session):
        if action == barrier_action:
          session.run(called_variable.assign(True))

      barrier_op = barrier_ops.BarrierOp(2,
                                         False,
                                         barrier_callbacks=[action_callback])
      hook = barrier_ops.BarrierHook(1, barrier_op)
      with tf.compat.v1.Session() as sess:
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(tf.compat.v1.local_variables_initializer())
        mon_sess = monitored_session._HookedSession(sess, [hook])
        barrier_op.place_barrier(sess, action=barrier_action)
        worker = threading.Thread(target=self._run,
                                  args=(train_op, mon_sess, 5))
        worker.daemon = True
        worker.start()
        # Busy-wait until both slots (barrier flag and worker 1) are set.
        while not barrier_op.is_all_blocked(sess):
          time.sleep(0.1)
        # Hook is pending.
        # Exactly one step has run: the hook blocks in after_run of step 1.
        self.assertEqual(1, sess.run(global_step))
        self.assertEqual(sess.run(called_variable), True)
        # Removing the barrier from this thread releases the worker.
        barrier_op.remove_barrier(sess)
        worker.join()
        self.assertTrue(barrier_op.is_none_blocked(sess))
        self.assertEqual(5, sess.run(global_step))
class CheckpointRestorerHook(session_run_hook.SessionRunHook):
  """Session hook that drives `CheckpointRestorerListener` callbacks.

  The hook performs no restore itself: right after the session is created it
  calls `before_restore` on every listener and then `after_restore` on every
  listener. Any actual restore work lives in the listeners.
  """

  def __init__(self, listeners=None):
    """Initializes a `CheckpointRestorerHook`.

    Args:
      listeners: List of `CheckpointRestorerListener` subclass instances whose
        callbacks are run by this hook. `None` means no listeners.
    """
    logging.info("Create CheckpointRestorerHook.")
    self._listeners = listeners if listeners else []

  def begin(self):
    """Forwards `begin` to every listener, in registration order."""
    for listener in self._listeners:
      listener.begin()

  def after_create_session(self, session, coord):
    """Runs the restore callbacks once the session exists."""
    self._restore(session)

  def _restore(self, session):
    """Calls all `before_restore` callbacks, then all `after_restore` ones."""
    logging.info("Calling checkpoint restorer listeners.")
    for listener in self._listeners:
      listener.before_restore(session)
    # The hook itself restores nothing between the two callback rounds.
    for listener in self._listeners:
      listener.after_restore(session)
class CheckpointRestorerHookTest(tf.test.TestCase):
  """Tests the listener callbacks fired by `CheckpointRestorerHook`."""

  def test_restore_only_in_after_create_session(self):
    """Listener callbacks fire once at session creation and never again."""
    with tf.compat.v1.Graph().as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      train_op = tf.compat.v1.assign_add(global_step, 1)
      listener = CountCheckpointRestorerListener()
      hook1 = basic_restore_hook.CheckpointRestorerHook(listeners=[listener])
      hook2 = CountHook()
      with tf.compat.v1.train.SingularMonitoredSession(
          hooks=[hook1, hook2]) as sess:
        # after_create_session
        self.assertEqual(
            {
                'begin': 1,
                'before_restore': 1,
                'after_restore': 1,
            }, listener.get_counts())
        self.assertEqual(
            {
                'after_create_session': 1,
                'before_run': 0,
                'after_run': 0,
                'end': 0,
            }, hook2.get_counts())
        for _ in range(2):
          sess.run(train_op)
      # After the session is closed: running steps did not trigger any
      # further listener callbacks, and the counting hook saw two runs + end.
      self.assertEqual(
          {
              'begin': 1,
              'before_restore': 1,
              'after_restore': 1,
          }, listener.get_counts())
      self.assertEqual(
          {
              'after_create_session': 1,
              'before_run': 2,
              'after_run': 2,
              'end': 1,
          }, hook2.get_counts())

  def test_two_listeners_with_restorer(self):
    """Every registered listener receives the same callbacks."""
    with tf.compat.v1.Graph().as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      train_op = tf.compat.v1.assign_add(global_step, 1)
      listener1 = CountCheckpointRestorerListener()
      listener2 = CountCheckpointRestorerListener()
      hook = basic_restore_hook.CheckpointRestorerHook(
          listeners=[listener1, listener2])
      with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess:
        self.assertEqual(
            {
                'begin': 1,
                'before_restore': 1,
                'after_restore': 1,
            }, listener1.get_counts())
        # Compare against the *second* listener; comparing listener1 with
        # itself would pass even if listener2 was never called.
        self.assertEqual(listener1.get_counts(), listener2.get_counts())
def clip_by_global_norm(t_list: List[tf.Tensor],
                        clip_norm: tf.Tensor,
                        use_norm=None) -> Tuple[List[tf.Tensor], tf.Tensor]:
  """Clips values of multiple tensors by the ratio of the sum of their norms.

  Given a list of tensors `t_list` and a clipping ratio `clip_norm`, this
  operation returns a list of clipped tensors `list_clipped` and the global
  norm (`global_norm`) of all tensors in `t_list`. Optionally, if you've
  already computed the global norm for `t_list`, you can specify the global
  norm with `use_norm`.

  To perform the clipping, the values `t_list[i]` are set to:

      t_list[i] * clip_norm / max(global_norm, clip_norm)

  where:

      global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))

  If `clip_norm > global_norm` then the entries in `t_list` remain as they
  are, otherwise they're all shrunk by the global ratio.

  If `global_norm == infinity` then the entries in `t_list` are all set to
  `NaN` to signal that an error occurred.

  Args:
    t_list: A `list` of `Tensors`. Other sequence types (e.g. tuples) are
      rejected.
    clip_norm: A 0-D (scalar) `Tensor` > 0. The clipping ratio.
    use_norm: A 0-D (scalar) `Tensor` of type `float` (optional). The global
      norm to use. If not provided, the norm is computed here.

  Returns:
    list_clipped: A list of `Tensors` of the same type as `t_list`.
    global_norm: A 0-D (scalar) `Tensor` representing the global norm. For an
      empty `t_list` the inputs are returned unchanged together with the
      Python int `0`.

  Raises:
    TypeError: If `t_list` is not a `list`.
  """
  with tf.name_scope('clip_by_global_norm'):
    if not isinstance(t_list, list):
      raise TypeError("t_list should be a list")
    if len(t_list) == 0:
      return t_list, 0
    if use_norm is not None:
      # Caller supplied the norm: clip with it and hand it back as-is.
      return gen_clip_ops.monolith_clip_by_global_norm(
          t_list, use_norm, clip_norm), use_norm
    if device_utils.within_placement_context_of("GPU"):
      # The fused op's result is returned directly; presumably it yields the
      # (list_clipped, global_norm) pair itself -- TODO confirm.
      return gen_clip_ops.monolith_clip_by_global_norm_fused(t_list, clip_norm)
    # Not in a GPU placement context here (that case returned above), so the
    # norm always comes from TensorFlow's implementation; the former
    # GPU-specific `_global_norm` alternative on this path was unreachable.
    global_norm = tf.linalg.global_norm(t_list)
    list_clipped = gen_clip_ops.monolith_clip_by_global_norm(
        t_list, global_norm, clip_norm)
    return list_clipped, global_norm
class ClipOpsTest(tf.test.TestCase):
  """Checks `clip_ops.clip_by_global_norm` against fixed and TF results."""

  def _test_clip_by_global_norm(self, inputs, clip_norm, expected=None):
    """Clips `inputs` and compares with `expected`.

    When `expected` is None the reference result is computed with
    `tf.clip_by_global_norm`. The inputs are fetched in the same run to
    verify that clipping did not modify them in place.
    """
    with tf.compat.v1.Session() as sess, test_util.use_gpu():
      tensors = [ops.convert_to_tensor(t, dtype=tf.float32) for t in inputs]
      clip_result = clip_ops.clip_by_global_norm(tensors, clip_norm)
      (clipped_values, _), fetched_inputs = sess.run([clip_result, tensors])
      if expected is None:
        expected, _ = sess.run(tf.clip_by_global_norm(tensors, clip_norm))
      self.assertAllClose(clipped_values, expected)
      # If the op had clipped the input memory in place, the fetched inputs
      # would no longer match the original values.
      self.assertAllClose(fetched_inputs, inputs)

  def test_clip_by_global_norm(self):
    """Covers clipping, no-op, zero norm, infinite and large gradients."""
    # Simple example
    self._test_clip_by_global_norm(
        inputs=[[-3.0, 0.0, 0.0], [4.0, 0.0, 0.0]],
        clip_norm=4.0,
        expected=[[-2.4, 0.0, 0.0], [3.2, 0.0, 0.0]])
    # Uneven shape example
    self._test_clip_by_global_norm(
        inputs=[[-3.0, 0.0, 0.0], [0.0, 0.0, 4.0, 0.0]],
        clip_norm=4.0,
        expected=[[-2.4, 0.0, 0.0], [0.0, 0.0, 3.2, 0.0]])
    # No clipping.
    self._test_clip_by_global_norm(
        inputs=[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
        clip_norm=4.0,
        expected=[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
    # Zero norm.
    self._test_clip_by_global_norm(
        inputs=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
        clip_norm=4.0,
        expected=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
    # Exploded grad: an infinite norm turns every entry into NaN.
    all_nan = np.full((2, 3), np.nan)
    inf = float('inf')
    self._test_clip_by_global_norm([[inf, inf, inf], [inf, inf, inf]], 4.0,
                                   all_nan)
    # Large grad: many tensors, checked against tf.clip_by_global_norm.
    dense_shapes = [
        (328, 128), (128,), (128,), (128, 64), (64,), (64,), (1,),
        (256, 256), (256,), (256,), (256, 128), (128,), (128,), (128, 1),
        (1,), (1,), (2488, 256), (256,), (256,), (3184, 256), (256,), (256,),
        (96, 128), (128,), (128,), (128, 64), (64,), (64,), (1,), (64, 16),
        (16,), (16,), (1609, 2048), (2048,), (2048,), (2048, 1024), (1024,),
        (1024,), (1024, 512), (512,), (512,), (512, 256), (256,), (256,),
        (256, 1), (1,), (1,), (96, 64), (64,), (64,), (64, 1), (1,), (1,),
    ]
    random_grads = [np.random.uniform(size=shape) for shape in dense_shapes]
    self._test_clip_by_global_norm(random_grads, 1.0)
def get_training_cluster(
    discovery: ServiceDiscovery,
    worker_addr: str,
    index: int,
    num_redundant_ps: int,
    num_required_ps: int,
    num_workers: int,
    model_dir: str,
    uuid: str,
    model_name: str = None,
    cluster_type: str = "stable") -> Tuple[Dict[str, List], Dict]:
  """Builds the cluster dict and the task dict for this process.

  Index 0 is the chief; any other index `i` is worker `i - 1`. When
  `num_redundant_ps` is truthy the ps addresses are exchanged through a file
  under `model_dir`: the chief reuses the file if it already lists exactly
  `num_required_ps` addresses, otherwise it queries discovery and writes the
  file; workers wait for the file. Without redundant ps, both roles query
  discovery directly.

  Args:
    discovery: Service discovery used to look up ps and chief addresses.
    worker_addr: Address of this process.
    index: 0 for the chief, otherwise worker index + 1.
    num_redundant_ps: Truthy to exchange the ps list through a file.
    num_required_ps: Number of ps addresses the cluster must contain.
    num_workers: Total process count including the chief.
    model_dir: Directory that holds the ps cluster file.
    uuid: File name component of the ps cluster file.
    model_name: Passed to the ps query on the chief (metric tag).
    cluster_type: Passed to the ps query on the chief (metric tag).

  Returns:
    A `(cluster, task)` tuple: `cluster` has the keys "chief", "worker" and
    "ps"; `task` has the keys "type" and "index".
  """
  is_chief = index == 0
  if is_chief:
    if not num_redundant_ps:
      # By default, the ps cluster is queried by discovery.
      ps_addrs = _query_ps_cluster(discovery, num_required_ps, model_name,
                                   cluster_type)
    else:
      ps_file = _get_ps_cluster_file_name(model_dir, uuid)
      # A restarted chief first tries the ps list persisted earlier.
      ps_addrs = _fetch_ps_cluster_from_file(ps_file, timeout=0)
      if len(ps_addrs) != num_required_ps:
        # Nothing usable in the file: query discovery and persist the result.
        ps_addrs = _query_ps_cluster(discovery, num_required_ps, model_name,
                                     cluster_type)
        _save_ps_cluster_to_file(ps_file, ps_addrs)
    placeholder_workers = [
        "0.0.0.0:{}".format(i) for i in range(1, num_workers)
    ]
    cluster = {
        "chief": [worker_addr],
        "worker": placeholder_workers,
        "ps": ps_addrs,
    }
    task = {"type": "chief", "index": 0}
  else:
    chief_addr = _query_chief_addr(discovery)
    # Due to current TF limitation (TF_CONFIG doesn't support dict),
    # a placeholder worker list is provided; only our own slot is real.
    worker_index = index - 1
    placeholder_workers = [
        "0.0.0.0:{}".format(i) for i in range(1, num_workers)
    ]
    placeholder_workers[worker_index] = worker_addr
    if not num_redundant_ps:
      # By default, the ps cluster is queried by discovery.
      ps_addrs = _query_ps_cluster(discovery, num_required_ps)
    else:
      # Wait for the ps list written to the file by the chief.
      ps_file = _get_ps_cluster_file_name(model_dir, uuid)
      ps_addrs = _fetch_ps_cluster_from_file(ps_file)
    cluster = {
        "chief": [chief_addr],
        "worker": placeholder_workers,
        "ps": ps_addrs,
    }
    task = {"type": "worker", "index": worker_index}
  assert len(cluster["ps"]) == num_required_ps
  return cluster, task
def _fetch_ps_cluster_from_file(file_name: str, timeout=1800):
  """Reads the comma-separated ps address list from `file_name`.

  The file is re-read until it has non-empty content or `timeout` seconds
  have elapsed; a missing file counts as "no content yet". Between attempts
  `_cluster_query_failure_handler()` is called.

  Args:
    file_name: Path of the ps cluster file.
    timeout: Maximum number of seconds to keep retrying. `0` means exactly
      one attempt.

  Returns:
    The list of ps addresses, or an empty list if nothing could be read.
  """
  ps_str = ""
  # A monotonic clock keeps the deadline immune to wall-clock adjustments.
  start_time = time.monotonic()
  while True:
    try:
      with tf.io.gfile.GFile(file_name) as f:
        ps_str = f.read()
    except tf.errors.NotFoundError:
      # The file has not been written yet; retry until the deadline.
      pass
    # `>=` guarantees that timeout=0 never falls through to a retry sleep.
    if bool(ps_str) or time.monotonic() - start_time >= timeout:
      break
    _cluster_query_failure_handler()
  ps_addrs = ps_str.split(",") if ps_str else []
  return ps_addrs
import os import unittest from monolith.native_training import cluster_manager class ClusterManagerTest(unittest.TestCase): def testBasic(self): ps_addrs = ["0.0.0.0:{}".format(i) for i in range(3)] file_name = cluster_manager._get_ps_cluster_file_name( model_dir=os.path.join(os.environ["TEST_TMPDIR"], "ClusterManagerTest", self._testMethodName), uuid=self._testMethodName) cluster_manager._save_ps_cluster_to_file(file_name, ps_addrs) new_ps_addrs = cluster_manager._fetch_ps_cluster_from_file(file_name) self.assertEqual(ps_addrs, new_ps_addrs) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/native_training/consul.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Consul client from bytedance pylib. 
""" import json import logging import os import socket import sys import threading import time import traceback from typing import Dict from six.moves.http_client import HTTPConnection class ConsulException(Exception): pass class UnixHTTPConnection(HTTPConnection): def __init__(self, path, **kwargs): kwargs["host"] = "localhost" HTTPConnection.__init__(self, **kwargs) self.path = path def connect(self): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.connect(self.path) self.sock = sock class Client: def __init__(self): self._lock = threading.Lock() self._cache = {} self._consul_sock = "/opt/tmp/sock/consul.sock" self._consul_host = os.environ.get("CONSUL_HTTP_HOST") or os.environ.get( "TCE_HOST_IP") if not self._consul_host: if os.path.isfile(self._consul_sock): self._consul_host = self._consul_sock else: self._consul_host = "127.0.0.1" self._consul_port = int(os.environ.get("CONSUL_HTTP_PORT") or 2280) def lookup(self, name, timeout=3, cachetime=0): now = time.time() if cachetime > 0: cache = self._cache.get(name) if cache and now - cache["cachetime"] <= cachetime: return cache["ret"] timeout = timeout if cache else 30 with self._lock: ret = self.lookup(name, timeout) else: ret = self._lookup(name, timeout) with self._lock: self._cache[name] = { "ret": ret, "cachetime": now, } return ret def _lookup(self, name, timeout): if self._consul_host.startswith("/"): conn = UnixHTTPConnection(self._consul_host) else: conn = HTTPConnection(self._consul_host, self._consul_port, timeout=timeout) conn.request("GET", "/v1/lookup/name?name=" + name + "&addr-family=dual-stack") response = conn.getresponse() status = response.status data = response.read() conn.close() if status != 200: logging.error("consul: %s %s", status, data.decode("utf8")) return [] return json.loads(data.decode("utf8")) def register(self, name, port, tags=None, check_script=None, host=None): d = { "id": "%s-%s" % (name, port), "name": name, "port": int(port), "check": { "ttl": "60s", } } if 
tags is not None: d["tags"] = ["%s:%s" % (k, v) for k, v in tags.items()] if check_script: d["check"] = {"interval": "30s", "script": check_script} if not host: host = self._consul_host conn = HTTPConnection(host, self._consul_port, timeout=15) conn.request("PUT", "/v1/agent/service/register", json.dumps(d)) response = conn.getresponse() status = response.status data = response.read() if status != 200: raise ConsulException(data.decode("utf8")) def _health_check(): while True: now = time.time() try: conn.request("GET", f"/v1/agent/check/pass/service:{name}-{port}") conn.getresponse().read() except socket.error: print(traceback.format_exc(), file=sys.stderr) time.sleep(2) # Immediately retry now -= 30 time.sleep(max(30 + now - time.time(), 0)) th = threading.Thread(name=f"ConsulHealthCheck-{name}-{port}", target=_health_check, daemon=True) th.start() # Maybe in the future, we want to garbage collect threads. def deregister(self, name, port, host=None): host = host or self._consul_host conn = HTTPConnection(host, self._consul_port, timeout=15) conn.request("PUT", "/v1/agent/service/deregister/%s-%s" % (name, port)) response = conn.getresponse() status = response.status data = response.read() if status != 200: raise ConsulException(data.decode("utf8")) conn.close() ================================================ FILE: monolith/native_training/consul_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import json import unittest from unittest import mock import numpy as np from absl import logging from six.moves.http_client import OK from monolith.native_training import consul _HTTP_CONNECTION_TARGET = "monolith.native_training.consul.HTTPConnection" class ConsulTest(unittest.TestCase): def test_lookup(self): with mock.patch(_HTTP_CONNECTION_TARGET) as MockHttpConnection: resp = mock.MagicMock() resp.status = OK data = [{"Port": 1234, "Host": "192.168.0.1", "Tags": {"index": "0"}}] resp.read.return_value = json.dumps(data).encode("utf-8") MockHttpConnection.return_value.getresponse.return_value = resp client = consul.Client() result = client.lookup("test_name") self.assertEqual(result, data) def test_register(self): with mock.patch(_HTTP_CONNECTION_TARGET) as MockHttpConnection: resp = mock.MagicMock() resp.status = OK MockHttpConnection.return_value.getresponse.return_value = resp client = consul.Client() client = client.register("test_name", 12345) def test_deregister(self): with mock.patch(_HTTP_CONNECTION_TARGET) as MockHttpConnection: resp = mock.MagicMock() resp.status = OK MockHttpConnection.return_value.getresponse.return_value = resp client = consul.Client() client = client.deregister("test_name", 12345) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/native_training/cpu_sync_training_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import numpy as np
import tensorflow as tf

# Note: this needs to be set before monolith.native_training, to ensure the
# imports work correctly.
os.environ["MONOLITH_WITH_HOROVOD"] = "True"
from monolith.native_training import cpu_training, embedding_combiners, feature, device_utils
from monolith.native_training.native_task import NativeTask
from monolith.native_training import entry
from monolith.native_training.data.training_instance.python.parser_utils import advanced_parse

import horovod.tensorflow as hvd

# Every test below writes its model_dir underneath this folder.
test_folder = os.environ["TEST_TMPDIR"]


class FeatureTask(NativeTask):
  """A test task that will collect some information in model_fn.

  The input is one ragged int64 "feature" tensor repeated 5 times. model_fn
  looks up an embedding for a slice created with `add_feature_slice(5)`, uses
  the sum of that embedding as loss and applies the gradients through the
  feature factory.
  """

  def create_input_fn(self, _):

    def input_fn():
      return tf.data.Dataset.from_tensors({
          "feature": tf.ragged.constant([[1, 2, 3, 4]], dtype=np.int64)
      }).map(advanced_parse).repeat(5)

    return input_fn

  def create_model_fn(self):

    def model_fn(features, mode, config, **kwargs):
      slot = self.ctx.feature_factory.create_feature_slot(
          feature.FeatureSlotConfig(name="slot"))
      s = slot.add_feature_slice(5)
      fc = feature.FeatureColumnV1(slot, "feature")
      embedding = fc.embedding_lookup(s)
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=tf.constant(0))
      all_embeddings = [fc.get_all_embeddings_concat()]
      loss = tf.reduce_sum(embedding)
      grads = tf.gradients(loss, all_embeddings)
      print1 = tf.print("embedding: ", embedding)
      print2 = tf.print("all_embeddings: ", all_embeddings)
      # The print ops are control dependencies of the train op, so they run
      # whenever the train op runs.
      with tf.control_dependencies([print1, print2]):
        train_op = tf.group(
            self._ctx.feature_factory.apply_gradients(zip(
                grads, all_embeddings)))
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=train_op,
                                        loss=loss,
                                        predictions=tf.constant(0))

    return model_fn


class EmbeddingUpdateTask(NativeTask):
  """A test task that will compare TF and monolith embedding update.

  The same Adagrad settings (learning_rate=0.1, initial_accumulator_value=1)
  are used for the monolith slot and for a plain `tf.Variable`; an
  `assert_equal` op on both embeddings is a control dependency of the train op.
  """

  def create_input_fn(self, _):

    def input_fn():
      return tf.data.Dataset.from_tensors({
          "feature": tf.ragged.constant([[1, 2, 3, 4]], dtype=np.int64),
          "tf_feature": tf.constant([[0, 1, 2, 3]], dtype=np.int64),
      }).map(advanced_parse).repeat(10)

    return input_fn

  def create_model_fn(self):

    def model_fn(features, mode, config):
      slot = self.ctx.feature_factory.create_feature_slot(
          feature.FeatureSlotConfig(
              name="slot",
              default_vec_initializer=entry.ConstantsInitializer(0),
              default_vec_optimizer=entry.AdagradOptimizer(
                  learning_rate=0.1, initial_accumulator_value=1)))
      s = slot.add_feature_slice(5)
      fc = feature.FeatureColumnV1(slot, "feature")
      embedding = fc.embedding_lookup(s)
      # Reference implementation: a zero-initialized TF variable looked up
      # with the ids from "tf_feature" and summed over axis 1.
      tf_embeddings = tf.Variable(initial_value=tf.zeros(shape=(4, 5)),
                                  name='embedding')
      tf_embedding = tf.reduce_sum(tf.nn.embedding_lookup(
          params=tf_embeddings, ids=features["tf_feature"]),
                                   axis=1)
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=tf.constant(0))
      all_embeddings = [fc.get_all_embeddings_concat()]
      loss = tf.reduce_sum(embedding)
      tf_loss = tf.reduce_sum(tf_embedding)
      grads = tf.gradients(loss, all_embeddings)
      tf_grads = tf.gradients(tf_loss, tf_embeddings)
      gs = tf.compat.v1.train.get_or_create_global_step()
      print1 = tf.print(gs, "embedding: ", embedding)
      print2 = tf.print(gs, "tf_embedding: ", tf_embedding)
      assert_equal = tf.compat.v1.assert_equal(embedding, tf_embedding)
      # The train op groups both updates and the global step increment, and
      # depends on the prints and on the equality assertion.
      with tf.control_dependencies([print1, print2, assert_equal]):
        train_op = tf.group(
            self._ctx.feature_factory.apply_gradients(zip(
                grads, all_embeddings)),
            tf.compat.v1.train.AdagradOptimizer(
                learning_rate=0.1, initial_accumulator_value=1).apply_gradients(
                    zip(tf_grads, [tf_embeddings])), gs.assign_add(1))
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=train_op,
                                        loss=loss + tf_loss,
                                        predictions=tf.constant(0))

    return model_fn


class FloatFeatureTask(NativeTask):
  """A test task that will use float feature in model_fn."""

  def create_input_fn(self, _):

    def input_fn():
      return tf.data.Dataset.from_tensors({
          "ragged_feature": tf.ragged.constant([[0, 0]], dtype=np.int64),
          "float_feature": tf.constant([[1.]], dtype=tf.float32)
      }).map(advanced_parse)

    return input_fn

  def create_model_fn(self):

    def model_fn(features, mode, **kwargs):
      slot = self.ctx.feature_factory.create_feature_slot(
          feature.FeatureSlotConfig(name="slot"))
      s = slot.add_feature_slice(5)
      fc = feature.FeatureColumnV1(slot, "ragged_feature")
      embedding = fc.embedding_lookup(s)
      # Predictions come from the dense float feature only.
      float_feature = features["float_feature"]
      predictions = tf.reduce_sum(float_feature, axis=-1)
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
      all_embeddings = [fc.get_all_embeddings_concat()]
      loss = tf.reduce_sum(embedding)
      grads = tf.gradients(loss, all_embeddings)
      train_op = tf.group(
          self._ctx.feature_factory.apply_gradients(zip(grads, all_embeddings)))
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=train_op,
                                        loss=loss,
                                        predictions=predictions)

    return model_fn


class NonFeatureTask(NativeTask):
  """A test task whose model_fn creates no feature slot at all."""

  def create_input_fn(self, _):

    def input_fn():
      return tf.data.Dataset.from_tensors([1])

    return input_fn

  def create_model_fn(self):

    def model_fn(features, mode, config):
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=tf.group(features),
                                        loss=tf.constant(0.0),
                                        predictions=tf.constant(0))

    return model_fn


class SequenceFeatureTask(NativeTask):
  """A test task that will use a sequence feature (FirstN combiner) in model_fn."""

  def create_input_fn(self, mode):
    del mode

    def input_fn():
      return tf.data.Dataset.from_tensors({
          "sequence_feature":
              tf.ragged.constant([[1, 2], [], [3, 4, 5]], dtype=np.int64),
      }).map(advanced_parse)

    return input_fn

  def create_model_fn(self):

    def model_fn(features, mode, **kwargs):
      slot = self.ctx.feature_factory.create_feature_slot(
          feature.FeatureSlotConfig(name="slot"))
      s = slot.add_feature_slice(5)
      fc = feature.FeatureColumnV1(slot,
                                   "sequence_feature",
                                   combiner=embedding_combiners.FirstN(2))
      embedding = fc.embedding_lookup(s)
      sequence_feature = features["sequence_feature"]
      predictions = tf.reduce_sum(sequence_feature, axis=-1)
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
      all_embeddings = [fc.get_all_embeddings_concat()]
      # Unlike the other tasks, the loss is the sum over all_embeddings, not
      # over the looked-up slice embedding.
      loss = tf.reduce_sum(all_embeddings)
      grads = tf.gradients(loss, all_embeddings)
      train_op = tf.group(
          self._ctx.feature_factory.apply_gradients(zip(grads, all_embeddings)))
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=train_op,
                                        loss=loss,
                                        predictions=predictions)

    return model_fn


class CpuSyncTrainTest(tf.test.TestCase):
  """Trains each task through CpuTraining with enable_sync_training=True.

  All tests use num_ps=0 and take the worker count from `hvd.size()`.
  """

  def test_cpu_training_feature(self):
    hvd.init()
    p = FeatureTask.params()
    p.name = "feature_task"
    task = FeatureTask(p)
    training = cpu_training.CpuTraining(
        cpu_training.CpuTrainingConfig(num_workers=hvd.size(),
                                       num_ps=0,
                                       reorder_fids_in_data_pipeline=True,
                                       embedding_prefetch_capacity=1,
                                       enable_sync_training=True), task)
    run_config = tf.estimator.RunConfig(
        model_dir=os.path.join(test_folder, "test_cpu_sync_training_feature"),
        device_fn=device_utils.default_device_fn)
    est = tf.estimator.Estimator(training.create_model_fn(), config=run_config)
    est.train(training.create_input_fn(tf.estimator.ModeKeys.TRAIN), steps=2)

  def test_embedding_update(self):
    # Runs 10 steps with prefetching disabled (embedding_prefetch_capacity=0);
    # the equality assertion lives inside EmbeddingUpdateTask's train op.
    hvd.init()
    p = EmbeddingUpdateTask.params()
    p.name = "embedding_update_task"
    task = EmbeddingUpdateTask(p)
    training = cpu_training.CpuTraining(
        cpu_training.CpuTrainingConfig(num_workers=hvd.size(),
                                       num_ps=0,
                                       reorder_fids_in_data_pipeline=True,
                                       embedding_prefetch_capacity=0,
                                       enable_sync_training=True), task)
    run_config = tf.estimator.RunConfig(
        model_dir=os.path.join(test_folder, "test_embedding_update"),
        device_fn=device_utils.default_device_fn)
    est = tf.estimator.Estimator(training.create_model_fn(), config=run_config)
    est.train(training.create_input_fn(tf.estimator.ModeKeys.TRAIN), steps=10)

  def test_cpu_training_float_feature(self):
    hvd.init()
    p = FloatFeatureTask.params()
    p.name = "float_feature_task"
    task = FloatFeatureTask(p)
    training = cpu_training.CpuTraining(
        cpu_training.CpuTrainingConfig(num_workers=hvd.size(),
                                       num_ps=0,
                                       reorder_fids_in_data_pipeline=True,
                                       embedding_prefetch_capacity=1,
                                       enable_sync_training=True), task)
    run_config = tf.estimator.RunConfig(
        model_dir=os.path.join(test_folder,
                               "test_cpu_sync_training_float_feature"),
        device_fn=device_utils.default_device_fn)
    est = tf.estimator.Estimator(training.create_model_fn(), config=run_config)
    est.train(training.create_input_fn(tf.estimator.ModeKeys.TRAIN), steps=2)

  def test_cpu_training_sequence_feature(self):
    # The only test in this class that sets enable_embedding_postpush=True.
    hvd.init()
    p = SequenceFeatureTask.params()
    p.name = "sequence_feature_task"
    task = SequenceFeatureTask(p)
    training = cpu_training.CpuTraining(
        cpu_training.CpuTrainingConfig(num_workers=hvd.size(),
                                       num_ps=0,
                                       reorder_fids_in_data_pipeline=True,
                                       embedding_prefetch_capacity=1,
                                       enable_sync_training=True,
                                       hashtable_init_capacity=100000,
                                       enable_embedding_postpush=True), task)
    run_config = tf.estimator.RunConfig(
        model_dir=os.path.join(test_folder,
                               "test_cpu_training_sequence_feature"),
        device_fn=device_utils.default_device_fn)
    est = tf.estimator.Estimator(training.create_model_fn(), config=run_config)
    est.train(training.create_input_fn(tf.estimator.ModeKeys.TRAIN), steps=2)

  def test_cpu_training_non_feature(self):
    # reorder_fids_in_data_pipeline is not passed here, unlike the tests above.
    hvd.init()
    p = NonFeatureTask.params()
    p.name = "non_feature_task"
    task = NonFeatureTask(p)
    training = cpu_training.CpuTraining(
        cpu_training.CpuTrainingConfig(num_workers=hvd.size(),
                                       num_ps=0,
                                       embedding_prefetch_capacity=1,
                                       hashtable_init_capacity=100000,
                                       enable_sync_training=True), task)
    run_config = tf.estimator.RunConfig(
        model_dir=os.path.join(test_folder,
                               "test_cpu_sync_training_non_feature"),
        device_fn=device_utils.default_device_fn)
    est = tf.estimator.Estimator(training.create_model_fn(), config=run_config)
    est.train(training.create_input_fn(tf.estimator.ModeKeys.TRAIN), steps=2)


class DistributedSyncTrainTest(tf.test.TestCase):
  """Runs FeatureTask through cpu_training.distributed_sync_train."""

  def test_basic(self):
    hvd.init()
    model_dir = os.path.join(test_folder, "sync_training_basic")
    params = FeatureTask.params()
    params.name = "test_task"
    params.train.max_steps = 2
    # TODO(zouxuan): async push breaks the test, needs further triage.
    cpu_training.distributed_sync_train(
        cpu_training.DistributedCpuTrainingConfig(
            model_dir=model_dir,
            enable_sync_training=True,
            reorder_fids_in_data_pipeline=True,
            embedding_prefetch_capacity=1,
            hashtable_init_capacity=100000,
            enable_embedding_postpush=False), params)

  def test_sparse_pipelining(self):
    # Same setup as test_basic plus both pipelined all-to-all flags, 4 steps.
    hvd.init()
    model_dir = os.path.join(test_folder, "sync_training_pipelined")
    params = FeatureTask.params()
    params.name = "test_task"
    params.train.max_steps = 4
    cpu_training.distributed_sync_train(
        cpu_training.DistributedCpuTrainingConfig(
            model_dir=model_dir,
            enable_sync_training=True,
            reorder_fids_in_data_pipeline=True,
            embedding_prefetch_capacity=1,
            enable_pipelined_bwda2a=True,
            enable_pipelined_fwda2a=True,
            hashtable_init_capacity=100000,
            enable_embedding_postpush=False), params)


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()


================================================
FILE: monolith/native_training/cpu_training.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. """This module defines how to run a native task in CPU training environment. CpuTraining defines those conversion. """ import contextlib import copy import dataclasses import getpass import json import os import io import platform import socket import threading import timeit import sys import traceback from datetime import datetime from typing import Callable, Dict, Iterable, List, Set, Tuple, Union from urllib.parse import urlparse import time from absl import logging from absl import flags import numpy as np import tensorflow as tf from tensorflow.python.lib.io import file_io from tensorflow.python.training.summary_io import SummaryWriterCache from tensorflow.python.ops import resources from tensorflow.python.ops import variables as tfvariables from tensorflow.python.ops.control_flow_ops import NoOp from monolith.agent_service.agent_service_pb2 import ServerType from monolith.agent_service.backends import SyncBackend from monolith.core.hyperparams import InstantiableParams from monolith.native_training import barrier_ops from monolith.native_training import basic_restore_hook from monolith.native_training import cluster_manager from monolith.native_training import device_utils from monolith.native_training import distributed_ps_factory, distributed_ps from monolith.native_training import distribution_ops from monolith.native_training import distributed_ps_sync from monolith.native_training import embedding_combiners from monolith.native_training import entry from monolith.native_training import feature from monolith.native_training import gflags_utils from monolith.native_training import hash_filter_ops from monolith.native_training import hash_table_ops from monolith.native_training import hvd_lib from monolith.native_training import multi_hash_table_ops from monolith.native_training import logging_ops from monolith.native_training import mlp_utils from 
monolith.native_training import monolith_checkpoint_state_pb2 from monolith.native_training import multi_type_hash_table from monolith.native_training import native_task from monolith.native_training import native_task_context from monolith.native_training import net_utils from monolith.native_training import ps_benchmark from monolith.native_training import save_utils from monolith.native_training import session_run_hooks from monolith.native_training import sync_hooks from monolith.native_training import sync_training_hooks from monolith.native_training import tensor_utils from monolith.native_training import utils from monolith.native_training import variables from monolith.native_training import distributed_serving_ops from monolith.native_training import yarn_runtime from monolith.native_training.alert import alert_manager from monolith.native_training.data import datasets from monolith.native_training.hash_table_utils import infer_dim_size from monolith.native_training.distributed_serving_ops import ParameterSyncClient from monolith.native_training.hash_filter_ops import FilterType from monolith.native_training.hooks import ckpt_hooks from monolith.native_training.hooks import ckpt_info from monolith.native_training.hooks import ps_check_hooks from monolith.native_training.hooks import hook_utils from monolith.native_training.hooks import session_hooks from monolith.native_training.hooks import feature_engineering_hooks from monolith.native_training.hooks.server import server_lib as server_hook_lib from monolith.native_training.metric import cli from monolith.native_training.metric.metric_hook import Tf2ProfilerHook, NVProfilerHook from monolith.native_training.metric.metric_hook import ByteCCLTelemetryHook from monolith.native_training.metric.metric_hook import ThroughputMetricHook from monolith.native_training.model_export import export_hooks from monolith.native_training.model_export import export_utils from monolith.native_training.model_export import 
saved_model_exporters from monolith.native_training.model_export import export_context from monolith.native_training.model_export.export_context import \ is_exporting, is_exporting_distributed, is_dry_run_or_exporting, ExportMode from monolith.native_training.native_task import NativeTask from monolith.native_training.prefetch_queue import \ enqueue_dicts_with_queue_return, EnqueueHook from monolith.native_training import prefetch_queue from monolith.native_training.proto import debugging_info_pb2 from monolith.native_training.runtime.hash_table import \ embedding_hash_table_pb2 from monolith.native_training.runtime.parameter_sync import \ parameter_sync_pb2 from monolith.native_training.service_discovery import ServiceDiscovery from monolith.native_training.service_discovery import TfConfigServiceDiscovery from monolith.native_training.service_discovery import MLPServiceDiscovery from monolith.native_training.data.training_instance.python import parser_utils from monolith.native_training.model_dump.dump_utils import DumpUtils, DRY_RUN from monolith.native_training.data.parsers import ParserCtx, get_default_parser_ctx from monolith.native_training.dense_reload_utils import CustomRestoreListenerKey, CustomRestoreListener from monolith.native_training.data.item_pool_hook import ItemPoolSaveRestoreHook, POOL_KEY from monolith.native_training.distribution_utils import get_sync_run_hooks, \ update_session_config_for_gpu, get_mpi_rank, get_mpi_size, get_mpi_local_rank flags.DEFINE_string( "monolith_chief_alert_proto", "", "The text format of alert proto. 
Will only be activated by chief.") FLAGS = flags.FLAGS def _combine_slices_as_table( slices: List[feature.SliceConfig], hashtable_config: entry.HashTableConfig) -> entry.HashTableConfigInstance: table_config = embedding_hash_table_pb2.EmbeddingHashTableConfig() entry_config = table_config.entry_config learning_rate_fns = list() if is_exporting(): entry_config.entry_type = embedding_hash_table_pb2.EntryConfig.EntryType.SERVING for s in slices: entry_config.segments.append(s.segment) learning_rate_fns.append(s.learning_rate_fn) hashtable_config.mutate_table(table_config) return entry.HashTableConfigInstance(table_config, learning_rate_fns) def _lookup_embedding_ids( hash_table: multi_type_hash_table.BaseMultiTypeHashTable, name_to_embedding_ids: Dict[str, tf.RaggedTensor]) -> Dict[str, tf.Tensor]: name_to_ids = {k: v.values for k, v in name_to_embedding_ids.items()} return hash_table.lookup(name_to_ids) def _convert_parquets_to_instance(parquets_path, instance_path): import pyarrow.parquet as pq from struct import pack from cityhash import CityHash64 from idl.matrix.proto.proto_parser_pb2 import Instance # choose latest date in parquets_path if not tf.io.gfile.isdir(parquets_path): raise ValueError(f"Argument parquet_path is not a directory. 
{parquets_path}") valid_dates = [fn for fn in tf.io.gfile.listdir(parquets_path) if fn.isdigit() and len(fn)==8 and tf.io.gfile.isdir(os.path.join(parquets_path, fn))] if len(valid_dates) == 0: raise ValueError(f"No vaild subdirectory in parquet_path: {parquets_path}") selected_date = max(valid_dates) parquets_path = os.path.join(parquets_path, selected_date) logging.info(f"start to convert parquets files to a instance pb file, latest parquet_path={parquets_path}") # collect item_to_fids dict item_to_fids = {} parquet_files = [os.path.join(parquets_path, fn) for fn in tf.io.gfile.listdir(parquets_path) if fn.endswith(".snappy.parquet")] if len(parquet_files) == 0: raise ValueError(f"None of .snappy.parquet file found in {parquets_path}") logging.info(f"{len(parquet_files)} .snappy.paruqet file found.") for file_id, file_path in enumerate(parquet_files): logging.info(f"{file_id+1}/{len(parquet_files)} start to parse parquet file: {file_path}") with tf.io.gfile.GFile(file_path, "rb") as f: f_bin = f.read() pq_data = pq.read_table(io.BytesIO(f_bin)) logging.info(f"{len(pq_data)} items detected.") item_id_col = pq_data['item_id'].to_pylist() fids_col = pq_data['fids'].to_pylist() for i in range(len(pq_data)): item_id = CityHash64(str(item_id_col[i])) & ((1<<63)-1) fids = [int(fid) for fid in fids_col[i].split()] if item_id in item_to_fids: logging.info(f"{item_id} already in dict, use latest") item_to_fids[item_id] = fids logging.info(f"convert finished, totally {len(item_to_fids)} items collected.") # generate instance pb file logging.info(f"start to generate items instance pb file to {instance_path}.") with tf.io.gfile.GFile(instance_path, "wb") as f: for item_id, fids in item_to_fids.items(): inst = Instance() inst.line_id.item_id = item_id for fid in fids: inst.fid.append(fid) serialized = inst.SerializeToString() f.write(pack(" CpuTrainingConfig: if export_context.is_exporting(): return self._serving_config_do_not_refer_directly return 
self._config_do_not_refer_directly @property def feature_configs( self ) -> Tuple[Dict[str, entry.HashTableConfigInstance], Dict[str, List[int]], Dict[str, embedding_combiners.Combiner]]: if export_context.is_exporting(): return self._serving_feature_configs_do_not_refer_directly return self._feature_configs_do_not_refer_directly def _init_fused_layout_params(self) -> None: parse_ctx = get_default_parser_ctx() parse_ctx.enable_fused_layout = self.config.enable_fused_layout if parse_ctx.enable_fused_layout: parse_ctx.sharding_sparse_fids_op_params = None # same param to fused_layout (feature_name_config, feature_to_unmerged_slice_dims, feature_to_combiner) = self.feature_configs parse_ctx.sharding_sparse_fids_op_params = distributed_ps.PartitionedHashTable.gen_feature_configs( num_ps=self.config.num_workers if self._params.train.use_gpu_emb_table else self.config.num_ps, feature_name_to_config=feature_name_config, layout_configs=self._task.layout_dict, feature_to_combiner=feature_to_combiner, feature_to_unmerged_slice_dims=feature_to_unmerged_slice_dims, use_native_multi_hash_table=self.config.use_native_multi_hash_table, unique=lambda: False if is_exporting() else True, transfer_float16=False, enable_gpu_emb=self._params.train.use_gpu_emb_table, use_gpu=export_context.get_current_export_ctx().with_remote_gpu if export_context.is_exporting() else self.config.enable_gpu_training) logging.info( f"_init_fused_layout_params {export_context.is_exporting()} {self._params.train.use_gpu_emb_table} {parse_ctx.sharding_sparse_fids_op_params.enable_gpu_emb} {parse_ctx.sharding_sparse_fids_op_params.use_gpu}" ) def create_input_fn(self, mode): input_fn = self._task.create_input_fn(mode) enable_reorder = (mode != tf.estimator.ModeKeys.PREDICT and self.config.reorder_fids_in_data_pipeline and not self.config.enable_fused_layout) use_dataservice = self.config.use_dataservice feature_name_config = self.feature_configs[0] embedding_feature_names = feature_name_config.keys() def 
input_fn_factory(input_fn, enable_reorder, use_dataservice, feature_name_config, embedding_feature_names): def reorder_parse_fn(*args): logging.info( 'Wrapping parser to dedup and reorder fids in data pipeline...') # features = parse_fn(*args, **kwargs) features = args[0] # CpuTraining.create_model_fn: def model_fn embedding_ragged_ids = { k: v for k, v in features.items() if k in embedding_feature_names } dense_features = { k: v for k, v in features.items() if k not in embedding_feature_names } if self.config.use_native_multi_hash_table: # when multi hash table is used, this is unmerged merged_slot_dims = multi_hash_table_ops.infer_dims( feature_name_config) sorted_slot_keys = sorted(embedding_feature_names) sorted_input = [ embedding_ragged_ids[k].values for k in sorted_slot_keys ] else: merged_slot_to_id, merged_slot_to_sizes = self._dummy_merged_table._get_merged_to_indexed_tensor( {k: v.values for k, v in embedding_ragged_ids.items()}) merged_slot_dims = self._dummy_merged_table.get_table_dim_sizes() sorted_slot_keys = sorted(merged_slot_to_id.keys()) sorted_input = [merged_slot_to_id[k] for k in sorted_slot_keys] reordered_pack = distribution_ops.fused_reorder_by_indices( sorted_input, self.config.num_workers, merged_slot_dims) reordered_pack = (*reordered_pack, get_req_time(dense_features)) if self.config.use_native_multi_hash_table: # DistributedMultiTypeHashTableMpi.lookup lookup_args = reordered_pack else: # merged_multi_type_hash_table.lookup lookup_args = ( merged_slot_to_sizes, # DistributedMultiTypeHashTableMpi.lookup reordered_pack) # Results include the following intermediate tensors res = ( dense_features, # Dense features # CpuTraining.create_model_fn: def model_fn ( parser_utils.RaggedEncodingHelper.expand( embedding_ragged_ids, # Sparse Features with_precomputed_nrows=True, with_precomputed_value_rowids=False # Because most GPU-downstream poolings # are not using value_rowids anymore, # we choose not to precompute it here. 
def create_model_fn(self):
  """Creates the model_fn used by this training job.

  Defines a closure that creates the hash filters and the embedding hash
  table selected by ``self.config`` and hands that closure to
  ``self._get_pipelined_model_fn`` while the native task context is active.

  Returns:
    The value returned by ``self._get_pipelined_model_fn``.
  """

  def create_hash_table_and_filters_fn():
    """Creates the hash filters and the hash table selected by the config.

    Also sets ``self._enable_hash_filter`` when any slot has a positive
    occurrence threshold, and copies the slot expire-time config into every
    feature's table config.

    Returns:
      A ``(hash_table, hash_filters)`` tuple.
    """
    (feature_name_config, feature_to_unmerged_slice_dims,
     feature_to_combiner) = self.feature_configs
    logging.vlog(
        1, "feature_to_unmerged_slice_dims: {}".format(
            feature_to_unmerged_slice_dims))
    # Pack the per-slot occurrence thresholds into a proto for the filters.
    slot_occurrence_threshold_config = embedding_hash_table_pb2 \
        .SlotOccurrenceThresholdConfig()
    for slot, occurrence_threshold in self._slot_to_occurrence_threshold.items(
    ):
      slot_occurrence_threshold = slot_occurrence_threshold_config.slot_occurrence_thresholds.add(
      )
      slot_occurrence_threshold.slot = slot
      slot_occurrence_threshold.occurrence_threshold = occurrence_threshold
      # A single positive threshold is enough to switch the hash filter on.
      if occurrence_threshold > 0:
        self._enable_hash_filter = True

    # In the sync training, hash filter and hashtables are inside worker.
    if is_exporting():
      # No filters are created while exporting; keep at least one placeholder
      # because the in-worker branches below read hash_filters[0].
      hash_filters = [None] * max(1, self.config.num_ps)
    else:
      # The GPU device scope is only entered when the GPU embedding table is
      # requested; otherwise a no-op context is used.
      with device_utils.maybe_device_if_allowed(
          '/device:GPU:0'
      ) if self._params.train.use_gpu_emb_table else contextlib.nullcontext():
        hash_filters = hash_filter_ops.create_hash_filters(
            self.config.num_ps,
            self._enable_hash_filter,
            config=slot_occurrence_threshold_config.SerializeToString(),
            filter_capacity=self.config.filter_capacity,
            filter_split_num=self.config.filter_split_num,
            filter_type=self.config.filter_type)
    # Pack the per-slot expire times and attach the same proto to every
    # feature's table config.
    slot_to_expire_time_config = embedding_hash_table_pb2.SlotExpireTimeConfig(
    )
    for slot, expire_time in self._slot_to_expire_time.items():
      slot_expire_time = slot_to_expire_time_config.slot_expire_times.add()
      slot_expire_time.slot = slot
      slot_expire_time.expire_time = expire_time
    for config in feature_name_config.values():
      config.table_config.slot_expire_time_config.CopyFrom(
          slot_to_expire_time_config)
    # Placeholder sync clients; real ones are only created for realtime
    # training outside of exporting.
    sync_clients = [None] * max(1, self.config.num_ps)
    if self.config.enable_realtime_training and not is_exporting():
      sync_clients = distributed_serving_ops.create_parameter_sync_clients(
          self.config.num_ps)
    if not self.config.enable_full_sync_training:
      # Tables are created through the num_ps-based factories.
      if self.config.enable_fused_layout:
        return distributed_ps_factory.create_partitioned_hash_table(
            num_ps=self.config.num_ps,
            use_native_multi_hash_table=self.config.
            use_native_multi_hash_table,
            max_rpc_deadline_millis=self.config.max_rpc_deadline_millis,
            hash_filters=hash_filters,
            sync_clients=sync_clients), hash_filters
      elif self.config.use_native_multi_hash_table:
        return distributed_ps_factory.create_native_multi_hash_table(
            self.config.num_ps,
            feature_name_config,
            hash_filters,
            sync_clients=sync_clients,
            max_rpc_deadline_millis=self.config.max_rpc_deadline_millis,
        ), hash_filters
      else:
        return distributed_ps_factory.create_multi_type_hash_table(
            self.config.num_ps,
            feature_name_config,
            hash_filters,
            sync_clients=sync_clients,
            reduce_network_packets=True,
            max_rpc_deadline_millis=self.config.max_rpc_deadline_millis,
        ), hash_filters
    else:
      # Full sync training: the queue/pipelining switches are read from the
      # config by name and coerced to int.
      queue_configs = {
          k: int(getattr(self.config, k))
          for k in ("embedding_prefetch_capacity", "enable_async_optimize",
                    "enable_pipelined_fwda2a", "enable_pipelined_bwda2a")
      }
      if self.config.enable_fused_layout:
        # With the GPU embedding table the partition count is num_workers and
        # the filter / sync-client lists are repeated num_workers times.
        return distributed_ps_factory.create_partitioned_hash_table(
            num_ps=self.config.num_workers
            if self._params.train.use_gpu_emb_table else self.config.num_ps,
            use_native_multi_hash_table=self.config.
            use_native_multi_hash_table,
            max_rpc_deadline_millis=self.config.max_rpc_deadline_millis,
            hash_filters=hash_filters * self.config.num_workers
            if self._params.train.use_gpu_emb_table else hash_filters,
            sync_clients=sync_clients * self.config.num_workers
            if self._params.train.use_gpu_emb_table else sync_clients,
            enable_gpu_emb=self._params.train.use_gpu_emb_table,
            queue_configs=queue_configs), hash_filters
      elif self.config.use_native_multi_hash_table:
        # In-worker tables take only the first filter and sync client.
        return distributed_ps_factory.create_in_worker_native_multi_hash_table(
            self.config.num_workers,
            feature_name_config,
            hash_filter=hash_filters[0],
            sync_client=sync_clients[0],
            queue_configs=queue_configs), hash_filters
      else:
        return distributed_ps_factory.create_in_worker_multi_type_hash_table(
            self.config.num_workers,
            feature_name_config,
            hash_filters[0],
            sync_client=sync_clients[0],
            queue_configs=queue_configs), hash_filters

  with native_task_context.with_ctx(
      make_native_task_context(self.config, self._sync_backend)):
    return self._get_pipelined_model_fn(create_hash_table_and_filters_fn)

def _generate_valid_features(self) -> Dict[str, tf.Tensor]:
  """Generates a valid feature dict which can be fed into model_fn in TRAIN mode.

  Returns:
    The result of ``tf.data.experimental.get_single_element`` applied to the
    dataset produced by the task's TRAIN input_fn.
  """
  input_fn = self._task.create_input_fn(tf.estimator.ModeKeys.TRAIN)
  dataset = input_fn()
  return tf.data.experimental.get_single_element(dataset)

def _collect_feature_name_to_table_config(
    self
) -> Tuple[Dict[str, entry.HashTableConfigInstance], Dict[str, List[int]],
           Dict[str, embedding_combiners.Combiner]]:
  """Collects per-feature table configs by running the task's model_fn once.

  The model_fn is invoked in a fresh graph flagged with ``DRY_RUN`` and with a
  ``feature.DummyFeatureFactory`` installed on the task context; the factory
  is then queried for the table configs it recorded.

  Side effects: replaces ``feature_factory``, ``layout_factory`` and
  ``async_function_mgr`` on ``self._task.ctx``, and sets
  ``self._slot_to_occurrence_threshold`` and ``self._slot_to_expire_time``.

  Returns:
    A tuple ``(feature_to_config, feature_to_unmerged_slice_dims,
    feature_to_combiner)``; each element is a dict keyed by feature name.
  """
  per_replica_batch_size = self._params.train.per_replica_batch_size
  # NOTE(review): the extent of this `with` block was reconstructed from a
  # whitespace-collapsed view of the file -- confirm it ends after the
  # model_fn call.
  with tf.Graph().as_default() as g, ParserCtx(enable_fused_layout=False):
    setattr(g, DRY_RUN, True)
    feature_factory = feature.DummyFeatureFactory(per_replica_batch_size)
    self._task.ctx.feature_factory = feature_factory
    self._task.ctx.layout_factory = None
    self._task.ctx.async_function_mgr = prefetch_queue.AsyncFunctionMgr(
        is_async=False)
    # The returned tensor is not used here; presumably the call is kept so the
    # graph has a global step before model_fn runs -- TODO confirm.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    model_fn = self._task.create_model_fn()
    features = self._generate_valid_features()
    model_fn(features=features,
             mode=tf.estimator.ModeKeys.TRAIN,
             config=tf.estimator.RunConfig())
  table_to_config = feature_factory.get_table_name_to_table_config()
  feature_to_config: Dict[str, entry.HashTableConfigInstance] = {}
  feature_to_unmerged_slice_dims: Dict[str, List[int]] = {}
  feature_to_combiner: Dict[str, embedding_combiners.Combiner] = {}
  # Flatten table -> features into per-feature lookups.
  for k, table_config in table_to_config.items():
    for feature_name in table_config.feature_names:
      assert not feature_name in feature_to_config, "Feature must only belongs to one table."
      feature_to_config.update({
          feature_name:
              _combine_slices_as_table(table_config.slice_configs,
                                       table_config.hashtable_config)
      })
      feature_to_unmerged_slice_dims[
          feature_name] = table_config.unmerged_slice_dims
      feature_to_combiner[feature_name] = table_config.feature_to_combiners[
          feature_name]
  # A positive hashtable_init_capacity overrides every table's initial
  # capacity.
  if self.config.hashtable_init_capacity > 0:
    for conf in feature_to_config.values():
      conf.table_config.initial_capacity = self.config.hashtable_init_capacity
  self._slot_to_occurrence_threshold = feature_factory.slot_to_occurrence_threshold
  self._slot_to_expire_time = feature_factory.slot_to_expire_time
  if self.config.enable_full_sync_training:
    # To improve hash table performance
    for config in feature_to_config.values():
      config.table_config.entry_type = embedding_hash_table_pb2.EmbeddingHashTableConfig.RAW
  return feature_to_config, feature_to_unmerged_slice_dims, feature_to_combiner
self._task.ctx.feature_factory = None return self._task.create_model_fn() def get_hooks_for_restore(model_dir: str, hash_filters: List[tf.Tensor], ps_monitor: save_utils.PsMonitor): if not model_dir: return () basename = os.path.join(model_dir, "model.ckpt") restore_listeners = [ hash_table_ops.HashTableCheckpointRestorerListener( basename, ps_monitor), multi_hash_table_ops.MultiHashTableCheckpointRestorerListener( basename, ps_monitor), hash_filter_ops.HashFilterCheckpointRestorerListener( basename, hash_filters, self._enable_hash_filter, enable_save_restore=( not self.config.enable_full_sync_training and self.config.filter_type != FilterType.PROBABILISTIC_FILTER)), CustomRestoreListener( self.config.reload_alias_map, self.config.clear_nn, self.config.continue_training, model_dir=model_dir, enable_alias_map_auto_gen=self.config.enable_alias_map_auto_gen) ] return (basic_restore_hook.CheckpointRestorerHook( listeners=restore_listeners),) def get_saver_listeners_for_exporting(save_path: str, export_dir_base: str = None, dense_only=False, exempt_checkpoint_paths=None, include_graphs=None): # TODO(leqi.zou): Add a test for this when graceful shutdown is implemented. if is_exporting(): # Safety check. To prevent infinite recursion. raise ValueError( "Logic corrupted. Try to call exporting listeners inside exporting." ) if dense_only and self._params.serving.export_mode is not ExportMode.DISTRIBUTED: raise ValueError( "Please set params.serving.export_mode = ExportMode.DISTRIBUTED. " "Only DISTRIBUTED mode is allowed when dense_only=True, got", self._params.serving.export_mode) if not self._params.serving.export_when_saving: return [] model_dir = os.path.dirname(save_path) serving_input_receiver_fn = self.create_serving_input_receiver_fn() if not serving_input_receiver_fn: raise ValueError("A valid serving_input_receiver_fn must be provided ", "if exporting is enabled. 
Got ", serving_input_receiver_fn) if not export_dir_base: export_dir_base = os.path.join(model_dir, self._params.serving.export_dir_base) # TODO(leqi.zou): Needs to do lifecycle management for exported model. exporter = create_exporter(self, model_dir=model_dir, warmup_file=self.config.warmup_file, export_dir_base=export_dir_base, dense_only=dense_only, include_graphs=include_graphs, export_context_list=self._export_context_list) barrier_listeners = [] if self.config.enable_sync_training and not dense_only and not self.config.enable_partial_sync_training: barrier_listeners.append( sync_training_hooks.SyncTrainingBarrierSaverListener()) return barrier_listeners + [ export_hooks.ExportSaverListener( save_path, serving_input_receiver_fn, exporter, exempt_checkpoint_paths=exempt_checkpoint_paths, dense_only=dense_only) ] def get_hooks_for_save(model_dir: str, hash_filters: List[tf.Tensor], barrier_op: barrier_ops.BarrierOp, ps_monitor: save_utils.PsMonitor): logging.info("get_hooks_for_save model_dir is " + model_dir) if not model_dir: raise ValueError("model_dir must be provided") hooks = [] exempt_checkpoint_paths = list() monolith_ckpt_state = save_utils.get_monolith_checkpoint_state( model_dir, remove_invalid_path=True) ckpt_state = tf.train.get_checkpoint_state(model_dir) if monolith_ckpt_state and monolith_ckpt_state.exempt_model_checkpoint_paths: exempt_checkpoint_paths = [ os.path.basename(p) for p in monolith_ckpt_state.exempt_model_checkpoint_paths ] logging.info( 'Exempt checkpoint paths: {}'.format(exempt_checkpoint_paths)) existing_checkpoint_paths = set([ os.path.basename(p) for p in ckpt_state.all_model_checkpoint_paths ]) assert all( [p in existing_checkpoint_paths for p in exempt_checkpoint_paths]) is_root_node = not self.config.enable_sync_training or self.config.index == 0 def create_saver(): return save_utils.PartialRecoverySaver( sharded=not self.config.enable_full_sync_training, max_to_keep=self.config.checkpoints_max_to_keep, 
keep_checkpoint_every_n_hours=24, ps_monitor=ps_monitor, exempt_checkpoint_paths=exempt_checkpoint_paths, skip_save=not is_root_node, model_dir=model_dir) saver = create_saver() tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.SAVERS, saver) basename = os.path.join(model_dir, "model.ckpt") save_listeners = [ hash_table_ops.HashTableCheckpointSaverListener(basename), multi_hash_table_ops.MultiHashTableCheckpointSaverListener( basename, write_ckpt_info=is_root_node), hash_filter_ops.HashFilterCheckpointSaverListener( basename, hash_filters, self._enable_hash_filter, enable_save_restore=(not self.config.enable_full_sync_training)) ] include_graphs = None if self.config.enable_full_sync_training: include_graphs = [f"ps_{self.config.index}"] if is_root_node: include_graphs.append("entry") include_graphs.append("dense_0") save_listeners += get_saver_listeners_for_exporting( basename, exempt_checkpoint_paths=exempt_checkpoint_paths, include_graphs=include_graphs) if self.config.enable_model_ckpt_info and is_root_node: save_listeners.append(ckpt_info.FidSlotCountSaverListener(model_dir)) if self.config.feature_eviction_on_save: save_listeners.extend([ hash_table_ops.HashTableRestorerSaverListener(basename), multi_hash_table_ops.MultiHashTableRestorerSaverListener(basename), ]) save_checkpoints_secs = self.config.save_checkpoints_secs or self._params.train.save_checkpoints_secs save_checkpoints_steps = self.config.save_checkpoints_steps or self._params.train.save_checkpoints_steps if save_checkpoints_secs is None and save_checkpoints_steps is None: save_checkpoints_steps = 100000000 if (save_checkpoints_secs is not None) and (save_checkpoints_steps is not None): raise ValueError( "Can not provide both save_checkpoints_secs and save_checkpoints_steps." ) # We do not use barrier for the sync training. guard_listeners = [] if not self.config.enable_sync_training: guard_listeners.append( ckpt_hooks.BarrierSaverListener( barrier_op, max_pending_seconds=self._params.train. 
max_pending_seconds_for_barrier)) # In the rare case, we need to do the first save. # Otherwise, partial_recovery won't work and will go through initialization phase. # It is supposed to be # should_do_first_save = self.config.partial_recovery and ckpt_state is None # Here we just make it false because there are issues with uninitialized iterator. should_do_first_save = False if self.config.enable_model_dump: save_utils.NoFirstSaveCheckpointSaverHook._in_model_dump_mode = True saver_hook = save_utils.NoFirstSaveCheckpointSaverHook( model_dir, save_secs=save_checkpoints_secs, save_steps=save_checkpoints_steps, saver=saver, listeners=save_listeners, guard_saver_listeners=guard_listeners, save_graph_def=is_root_node, tide_start_hour=self.config.tide_start_hour, tide_start_minute=self.config.tide_start_minute, tide_end_hour=self.config.tide_end_hour, tide_end_minute=self.config.tide_end_minute, tide_save_secs=self.config.tide_save_secs, ignore_save_errors=self.config.enable_realtime_training, is_dense_only=False, use_native_multi_hash_table=self.config.use_native_multi_hash_table, no_first_save=not should_do_first_save) hooks.append(saver_hook) if not self.config.enable_sync_training: server_hook = server_hook_lib.ServerHook(model_dir, barrier_op, saver_hook) hooks.extend([server_hook]) dense_only_save_checkpoints_steps = self.config.dense_only_save_checkpoints_steps or self._params.train.dense_only_save_checkpoints_steps dense_only_save_checkpoints_secs = self.config.dense_only_save_checkpoints_secs or self._params.train.dense_only_save_checkpoints_secs if (dense_only_save_checkpoints_steps or dense_only_save_checkpoints_secs) and is_root_node: dense_saver = create_saver() dense_model_dir = os.path.join(model_dir, "dense_only") stats = tf.train.get_checkpoint_state(dense_model_dir) if stats: dense_saver.recover_last_checkpoints(stats.all_model_checkpoint_paths) dense_basename = os.path.join(dense_model_dir, "model.ckpt") export_dir_base = os.path.join(model_dir, 
self._params.serving.export_dir_base) save_utils.NoFirstSaveCheckpointSaverHook._has_dense_only = True dense_guard_listeners = guard_listeners if self.config.dense_only_stop_training_when_save else [] dense_saver_hook = save_utils.NoFirstSaveCheckpointSaverHook( dense_model_dir, save_secs=dense_only_save_checkpoints_secs, save_steps=dense_only_save_checkpoints_steps, saver=dense_saver, listeners=get_saver_listeners_for_exporting( dense_basename, export_dir_base=export_dir_base, dense_only=True, exempt_checkpoint_paths=exempt_checkpoint_paths), guard_saver_listeners=dense_guard_listeners, tide_start_hour=self.config.tide_start_hour, tide_start_minute=self.config.tide_start_minute, tide_end_hour=self.config.tide_end_hour, tide_end_minute=self.config.tide_end_minute, tide_save_secs=self.config.tide_save_secs, ignore_save_errors=self.config.enable_realtime_training, is_dense_only=True) if self.config.enable_sync_training and self.config.enable_realtime_training: hooks.append( sync_training_hooks.SyncTrainingSaverControlHook( model_dir, dense_saver_hook.timer)) hooks.append(dense_saver_hook) return tuple(hooks) def get_slow_start_hook(slow_start_steps: int): if slow_start_steps: return (session_run_hooks.CustomGlobalStepWaiterHook( int(slow_start_steps * np.log(1 + self.config.index)), max_non_tide_wait_minute=self.config.max_slow_start_wait_minute),) return () def get_tide_stopping_hook(): if self.config.tide_start_hour is not None and self.config.tide_end_hour is not None: return (session_run_hooks.TideStoppingHook( self.config.tide_start_hour, self.config.tide_start_minute, self.config.tide_end_hour, self.config.tide_end_minute),) return () def get_hooks_for_metrics(model_dir: str, save_steps: int): hooks = [] if self._params.metrics.enable_tf2_profiler_hook and is_chief(self.config): start_step = self.config.profile_some_steps_from end_step = None if start_step is None else start_step + 10 save_steps = self.config.profile_save_steps_interval hooks.append( 
Tf2ProfilerHook( logdir=model_dir, init_step_range=[start_step, end_step], save_steps=save_steps)) if self.config.profile_with_nvprof_from_to and is_chief(self.config): s, e = self.config.profile_with_nvprof_from_to.split(',') save_steps = self.config.profile_save_steps_interval hooks.append( NVProfilerHook(init_step_range=[int(s), int(e)], save_steps=save_steps)) if self._params.metrics.enable_throughput_hook and is_chief(self.config): hooks.append( ThroughputMetricHook( model_name=self.config.model_name, start_time_secs=self.config.containers_ready_time_secs, cluster_type=self.config.cluster_type)) return tuple(hooks) def variable_prefetch_enabled(): return not self.config.enable_sync_training and self.config.enable_variable_prefetch def get_cached_variable_context(): if variable_prefetch_enabled(): return tf.variable_creator_scope(variables.cached_variable_creator) return contextlib.nullcontext() def get_partitioner_variable_context(): if not is_exporting() and self.config.enable_variable_partition: # TODO(leqi.zou): This only works for tf.compat.v1.get_variable, # but not for tf.Variable. # # Finally, we can use something similar to PSStrategy to solve # this problem. 
logging.info("partition max_shards={}".format(self.config.num_ps)) return tf.compat.v1.variable_scope( "", partitioner=tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=self.config.num_ps)) return contextlib.nullcontext() def get_variable_prefetch_hooks(): if variable_prefetch_enabled(): return (variables.FetchAllCachedVariablesHook(),) return () def get_itempool_hook(model_dir, mode): pools = tf.compat.v1.get_collection(POOL_KEY) if pools and mode != tf.estimator.ModeKeys.PREDICT: logging.info("append itempool_save_restore_hook in training_hooks") save_checkpoints_secs = self.config.save_checkpoints_secs or self._params.train.save_checkpoints_secs save_checkpoints_steps = self.config.save_checkpoints_steps or self._params.train.save_checkpoints_steps if save_checkpoints_secs is None and save_checkpoints_steps is None: save_checkpoints_steps = 100000000 item_pool_hook = ItemPoolSaveRestoreHook( model_dir=model_dir, save_steps=save_checkpoints_steps, mode=mode) if hasattr(self._task, 'add_training_hook'): self._task.add_training_hook(item_pool_hook) def model_fn(features: Dict[str, tf.Tensor], mode: str, config: tf.estimator.RunConfig): hash_table, hash_filters = create_hash_table_and_filters_fn() logging.info( f'\n> hash_table: {hash_table}\n> hash_filters: {hash_filters}') # For prefetch queue, collect auxiliary tensors get_itempool_hook(config.model_dir, mode=mode) auxiliary_bundle = {} async_function_mgr = prefetch_queue.AsyncFunctionMgr( is_async=self.config.enable_variable_postpush) self._task.ctx.async_function_mgr = async_function_mgr eof_key = sync_training_hooks.EofAwareTask.EOF_KEY if '2' in features: auxiliary_bundle[eof_key] = features.pop('2') logging.info(f'eof: {auxiliary_bundle[eof_key]}, {id(auxiliary_bundle[eof_key])}') features = features.pop('1') elif eof_key in features: auxiliary_bundle[eof_key] = features.pop(eof_key) logging.info(f'eof: {auxiliary_bundle[eof_key]}, {id(auxiliary_bundle[eof_key])}') def 
call_raw_model_fn(features): raw_model_fn = self._task.create_model_fn() with get_cached_variable_context(), get_partitioner_variable_context(): spec = raw_model_fn(features=features, mode=mode, config=config) return spec if self.config.enable_fused_layout: self._task.ctx.feature_factory = None lookup_callable_fn = hash_table.lookup( features, auxiliary_bundle, ret_lookup_callable_fn=True, embedding_prefetch_capacity=self.config.embedding_prefetch_capacity) # args are data we will transfer to remote deivce if needed. args = (auxiliary_bundle, features) logging.info( f"remote input: auxiliary_bundle[{auxiliary_bundle}], features:[{features}]" ) def call_model_fn(args): # add lookup_callable_fn here to support with_remote_gpu auxiliary_bundle_ = args[0] features_ = args[1] layout_embeddings = lookup_callable_fn(auxiliary_bundle_, features_) logging.info( f"hash_table lookup when enable_fused_layout res: {layout_embeddings} {auxiliary_bundle_} {features_}" ) auxiliary_bundle_.update(features_) # set layout_factory, this step must after embedding_prefetch self._task.ctx.layout_factory = feature.EmbeddingLayoutFactory( hash_table, layout_embeddings, auxiliary_bundle=auxiliary_bundle_, async_function_mgr=async_function_mgr, async_push=self.config.enable_embedding_postpush) return call_raw_model_fn(features_) else: if self.config.reorder_fids_in_data_pipeline: features, res_pack = features["1"] features = { k: v for k, v in features.items() if not isinstance(v, tf.RaggedTensor) } embedding_ragged_ids, res_pack = res_pack name_to_ids = { k: None for k in embedding_ragged_ids.keys() # None to keep interface, we will only use the keys } auxiliary_bundle["features"] = features auxiliary_bundle[ "embedding_ragged_ids"] = parser_utils.RaggedEncodingHelper.contract( embedding_ragged_ids) logging.info( f"input: auxiliary_bundle[{auxiliary_bundle}], features:[{features}]" ) embeddings, auxiliary_bundle = hash_table.lookup( name_to_ids, auxiliary_bundle, res_pack) 
dequeued_embeddings = embeddings else: embedding_ragged_ids: Dict[str, tf.RaggedTensor] = { k: v for k, v in features.items() if k in embedding_feature_names } features: Dict[str, tf.Tensor] = { # Dense features or labels k: v for k, v in features.items() if not isinstance(v, tf.RaggedTensor) } auxiliary_bundle["features"] = features auxiliary_bundle["embedding_ragged_ids"] = embedding_ragged_ids # 'feature_name' -> name_to_ids: Dict[str, tf.Tensor] = { k: v.values for k, v in embedding_ragged_ids.items() } # for MergedMultiTypeHashTable, lookup returns: embeddings: Dict[str, tf.Tensor] = hash_table.lookup(name_to_ids) logging.info( f"input: auxiliary_bundle[{auxiliary_bundle}], features:[{features}]" ) (dequeued_embeddings, auxiliary_bundle), q = enqueue_dicts_with_queue_return( (embeddings, auxiliary_bundle), capacity=self.config.embedding_prefetch_capacity) if q: hash_table.add_queue_hook(EnqueueHook(q)) dequeued_embedding_ragged_ids = auxiliary_bundle.pop( "embedding_ragged_ids") dequeued_features = auxiliary_bundle.pop("features") # record dequeued features, primarily for evaluation purposes for name, tensor in dequeued_embedding_ragged_ids.items(): tensor.feature_name = name tf.compat.v1.add_to_collection('dequeued_sparse_features', tensor) for name, tensor in dequeued_features.items(): try: tensor.feature_name = name tf.compat.v1.add_to_collection('dequeued_features', tensor) except Exception as e: logging.error(f'tensor name is {name}') logging.error(f'exception is {str(e)}') embedding_slices = feature.create_embedding_slices( dequeued_embeddings, dequeued_embedding_ragged_ids, feature_to_combiner, feature_to_unmerged_slice_dims) args = ( dequeued_features, embedding_slices, ) # To enable remote inference, we need to transfer all tensors # which are needed by using remote predict. All tensors used should be # listed in the parameter, otherwise export graph will complain that # tensor may come from another graph. # # Notice this is only for the inference. 
In the training, TensorFlow # will automatically add send/recv if tensors are on the different graph. def call_model_fn(args): self._task.ctx.layout_factory = None if self.config.enable_full_sync_training: # TODO(zouxuan): enable this for async training later on. self._task.ctx.feature_factory = _FusedCpuFeatureFactory( hash_table, dequeued_embeddings, args[1], get_req_time(args[0]), auxiliary_bundle=auxiliary_bundle, use_native_multi_hash_table=self.config. use_native_multi_hash_table) else: self._task.ctx.feature_factory = _CpuFeatureFactory( hash_table, dequeued_embedding_ragged_ids, dequeued_embeddings, args[1], get_req_time(args[0]), async_function_mgr=async_function_mgr, async_push=self.config.enable_embedding_postpush) return call_raw_model_fn({**args[0], **dequeued_embeddings}) spec = None if export_context.is_exporting( ) and export_context.get_current_export_ctx().with_remote_gpu: def remote_call(tensors): with tf.device("/device:GPU:0"): spec = call_model_fn(tensors) return spec.predictions g = export_context.get_current_export_ctx().sub_graph("dense_0") with g.as_default(): helper = export_utils.RemotePredictHelper("gpu_remote_call", args, remote_call) predictions = helper.call_remote_predict( model_name=f"{native_task_context.get().model_name or ''}:dense_0", old_model_name="dense_0") spec = tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) else: spec = call_model_fn(args) if not is_exporting(): ps_monitor = save_utils.PsMonitor( self.config.num_ps ) if self.config.partial_recovery and self.config.num_ps > 1 else None training_hooks = () training_chief_hooks = () ckpt_helper = ckpt_hooks.WorkerCkptHelper(config.model_dir, self.config.index) # SetCurrentSessionHook must present first. 
training_hooks += (session_hooks.SetCurrentSessionHook(),) barrier_op = None enable_sync_hook = self.config.num_workers > 1 and not self.config.enable_sync_training sync_hook_helper = sync_hooks.TrainingHooksHelper( enable_sync_hook, self.config.num_workers, self.config.index, chief_timeout_seconds=self.config.chief_timeout_secs) if not self.config.enable_sync_training: barrier_op = barrier_ops.BarrierOp( self.config.num_workers, is_chief=is_chief(self.config), barrier_callbacks=[ ckpt_helper.create_save_iterator_callback(), ]) training_hooks += sync_hook_helper.training_hooks + ( barrier_ops.BarrierHook(self.config.index, barrier_op),) if self._params.mode == tf.estimator.ModeKeys.TRAIN: training_hooks += get_slow_start_hook( self._params.train.slow_start_steps) training_hooks += (ckpt_helper.create_restorer_hook(),) training_chief_hooks += (ps_check_hooks.PsHealthCheckerHook( ps_check_hooks.Config(barrier_op=barrier_op, num_ps=self.config.num_ps)),) # Prefetch hooks should be put after control hooks (like slow start hook). # Just in case, we have shuffle in dataset and we start reading too much # data before we actually start the training. training_hooks += tuple(hash_table.get_queue_hooks()) training_hooks += get_variable_prefetch_hooks() training_hooks += tuple(self._task.ctx.async_function_mgr.hooks) """ Make sure sync hook running after restore hook and before save hook. 
The running order is similar to: 'after_create_session' : restore, chief start, worker start 'end' : worker end, save, chief end """ training_chief_hooks += ( get_hooks_for_restore(config.model_dir, hash_filters, ps_monitor) + sync_hook_helper.training_chief_hooks + get_hooks_for_save( config.model_dir, hash_filters, barrier_op, ps_monitor) + get_hooks_for_metrics(config.model_dir, config.save_summary_steps)) predicting_hooks = (get_hooks_for_restore(config.model_dir, hash_filters, ps_monitor)) if self.config.enable_partial_sync_training and self.config.index != 0: elements = [] local_init_ops = tf.compat.v1.get_collection( tf.compat.v1.GraphKeys.LOCAL_INIT_OP) if local_init_ops: elements.extend(local_init_ops) else: local_init_op = tf.compat.v1.train.Scaffold.get_or_default( 'local_init_op', tf.compat.v1.GraphKeys.LOCAL_INIT_OP, tf.compat.v1.train.Scaffold.default_local_init_op) elements.append(local_init_op) init_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.INIT_OP) if init_ops: elements.extend(init_ops) else: def default_init_op(): return tf.group( tfvariables.global_variables_initializer(), resources.initialize_resources(resources.shared_resources())) init_op = tf.compat.v1.train.Scaffold.get_or_default( 'init_op', tf.compat.v1.GraphKeys.INIT_OP, default_init_op) elements.append(init_op) logging.info(f'local_init_op is {elements}') scaffold = tf.compat.v1.train.Scaffold( local_init_op=tf.group(elements) if elements else None, ready_for_local_init_op=NoOp()) else: scaffold = None spec = spec._replace( training_chief_hooks=training_chief_hooks + spec.training_chief_hooks, training_hooks=training_hooks + spec.training_hooks, prediction_hooks=predicting_hooks + spec.prediction_hooks, scaffold=scaffold) logging.info("Training Chief Hooks: {}".format( spec.training_chief_hooks)) logging.info("Training Hooks: {}".format(spec.training_hooks)) logging.info(f'dequeue input: auxiliary_bundle[{auxiliary_bundle}], features:[{features}]') dequeued_eof = 
@dataclasses.dataclass
class DistributedCpuTrainingConfig(CpuTrainingConfig):
  """The training config for distributed training.

  Extends ``CpuTrainingConfig`` with the fields below.

  attributes:
    :param model_dir: The directory where the model is loaded/saved.
    :param tensorboard_log_path: The logdir of tensorboard, use model_dir instead if empty
    :param intra_op_parallelism_threads: intra_op parallelism threads.
    :param inter_op_parallelism_threads: inter_op parallelism threads.
    :param num_extra_ps: The number of extra ps for ps benchmark.
    :param num_redundant_ps: The number of redundant ps for quickly starting. We will pick
      |num_ps| from |num_ps + num_extra_ps + num_redundant_ps| ps
    :param uuid: uuid of cpu training.
    :param operation_timeout_in_ms: Global timeout for all blocking operations in this session.
    :param session_creation_timeout_secs: Max time workers should wait for a session to become available.
    :param max_retry_times: Maximum retry times for workers to start train.
    :param retry_wait_in_secs: Sleep time interval to wait for worker retry.
    :param fountain_zk_host: zk_host for fountain service.
    :param fountain_model_name: model_name for fountain service.
    :param dc_aware: data-center aware or not.
  """
  model_dir: str = ""
  tensorboard_log_path: str = ""
  intra_op_parallelism_threads: int = 8
  inter_op_parallelism_threads: int = 16
  num_extra_ps: int = 0
  num_redundant_ps: int = 0
  uuid: str = ""
  operation_timeout_in_ms: int = -1
  session_creation_timeout_secs: int = 7200
  max_retry_times: int = 0
  retry_wait_in_secs: int = 30
  fountain_zk_host: str = ""
  fountain_model_name: str = ""
  dc_aware: bool = False


def _prepare_server(target: str, config: DistributedCpuTrainingConfig):
  """Do some preparation before we register the server to the server discovery.

  Only acts when ``config.server_type`` is ``"ps"``: opens a session against
  ``target`` and runs the machine-info op for this ps index once.

  Args:
    target: Session target passed to ``tf.compat.v1.Session``.
    config: Training config; ``server_type`` and ``index`` are read.
  """
  if config.server_type == "ps":
    session_config = cluster_manager.generate_session_config()
    with tf.compat.v1.Session(target, config=session_config) as sess:
      # Creates machine info so the following access won't create new machine info.
      sess.run(
          logging_ops.machine_info(
              shared_name=ps_check_hooks.get_ps_machine_info_shared_name(
                  config.index)))


def _shutdown_ps(target, cluster, task, num_ps):
  """Signals every ps to shut down by enqueueing into its shared queue.

  For each ps index in ``range(num_ps)`` a FIFO queue with shared name
  ``"ps_queue_<index>"`` is created under that ps's device and the value 1 is
  enqueued into it.

  Args:
    target: Session target passed to ``tf.compat.v1.Session``.
    cluster: Cluster dict; it is deep-copied, so the caller's dict is not
      modified.
    task: Task dict, forwarded together with the trimmed cluster to
      ``cluster_manager.generate_session_config``.
    num_ps: Number of ps to signal.
  """
  cluster = copy.deepcopy(cluster)
  # Workers have already shut down, so drop them from the copied cluster.
  if "worker" in cluster:
    del cluster["worker"]
  session_config = cluster_manager.generate_session_config((cluster, task))
  with tf.compat.v1.Session(target, config=session_config) as sess:
    for i in range(num_ps):
      logging.info('Try to shutdown ps {}'.format(i))
      # NOTE(review): whether the enqueue runs inside or just after this
      # device scope was reconstructed from a whitespace-collapsed view --
      # confirm against the formatted file.
      with tf.device(utils.ps_device(i)):
        queue = tf.queue.FIFOQueue(1,
                                   tf.int32,
                                   shared_name="ps_queue_" + str(i))
        sess.run(queue.enqueue(1))
      logging.info('Shutdown ps {} successfully!'.format(i))
sync_run_step)) t.start() logging.info( 'Ps {} started a thread for parameter sync!'.format(ps_index)) # Try to dequeue, if success means chief will finish soon. sess.run(queue.dequeue()) finished = True if t: t.join() logging.info("Ps {} shutdown successfully!".format(ps_index)) def _get_blocked_addrs(cluster: Dict, ignored_jobs: Set = {}): cluster_spec = tf.train.ClusterSpec(cluster) addrs = set() for job in cluster_spec.jobs: if job not in ignored_jobs: for addr in cluster_spec.job_tasks(job): addrs.add(addr) return list(addrs) class NodeAliveCheckerError(Exception): def __init__(self, msg): super(NodeAliveCheckerError, self).__init__(self) self.msg = msg def __str__(self): return self.msg def _do_worker_train(config: DistributedCpuTrainingConfig, params: InstantiableParams, cluster: Dict, task: Dict, user_hooks = None): params.mode = config.mode native_task = params.instantiate() if not isinstance(native_task, NativeTask): raise ValueError( "distributed train only support NativeTask. 
Got {}".format(native_task)) if params.serving.with_remote_gpu and config.enable_model_dump: # 当with_remote_gpu=True时,export时会在dense subgraph中运行model_fn # 导致dump下来的infer model不完整,缺少serving_input_receiver_fn信息 raise ValueError("unsupport enable_model_dump while with_remote_gpu=True") session_config = cluster_manager.generate_session_config((cluster, task)) session_config.operation_timeout_in_ms = config.operation_timeout_in_ms check_addrs = _get_blocked_addrs(cluster=cluster, ignored_jobs={'worker'}) alive_checker = net_utils.NodeAliveChecker(check_addrs, timeout=60) if not alive_checker.all_nodes_alive(): raise NodeAliveCheckerError("{} is unreachable".format(','.join( alive_checker.get_dead_nodes()))) os.environ["TF_CONFIG"] = json.dumps({"cluster": cluster, "task": task}) try: update_session_config_for_gpu(session_config) run_config = tf.estimator.RunConfig( model_dir=config.model_dir, session_config=session_config, save_summary_steps=config.save_summary_steps * config.num_workers, log_step_count_steps=config.log_step_count_steps * config.num_workers, session_creation_timeout_secs=config.session_creation_timeout_secs, device_fn=config.device_fn) training = CpuTraining(config, native_task) if config.enable_partial_sync_training or config.use_dataservice: training = sync_training_hooks.EofAwareTask(training, config.use_dataservice) estimator = tf.estimator.Estimator(training.create_model_fn(), config=run_config) if is_chief(config): _save_debugging_info(config, cluster, training) run_hooks = get_sync_run_hooks(False) if user_hooks is not None: run_hooks += user_hooks estimator.train(training.create_input_fn(config.mode), hooks=run_hooks, max_steps=params.train.max_steps) if is_chief(config) and config.enable_resource_constrained_roughsort and config.mode==tf.estimator.ModeKeys.TRAIN: logging.info(f"roughsort_items_use_parquet: {config.roughsort_items_use_parquet}") if config.roughsort_items_use_parquet: items_path = os.path.join(config.model_dir, 
"candidate_items.pb") _convert_parquets_to_instance(config.roughsort_candidate_items_path, items_path) else: items_path = config.roughsort_candidate_items_path logging.info("Start to evaluate item data...") # params.p.only_save_item_cache_hashtable = True params.mode = tf.estimator.ModeKeys.PREDICT native_task = params.instantiate() training = CpuTraining(config, native_task) if config.enable_partial_sync_training or config.use_dataservice: training = sync_training_hooks.EofAwareTask(training, config.use_dataservice) estimator = tf.estimator.Estimator(training.create_model_fn(), config=run_config) estimator.train(training._task.create_item_input_fn( items_path), max_steps=params.train.max_steps) finally: # TODO(leqi.zou): we have some thread safety issue in the test. if "TF_CONFIG" in os.environ: del os.environ["TF_CONFIG"] return estimator _EXTRA_PS_BENCHMARK_SECS = 120 def _run_ps_benchmark(config: DistributedCpuTrainingConfig, num_ps_required: int, cluster: dict, task: dict, user_hooks): config = copy.deepcopy(config) cluster = copy.deepcopy(cluster) bm_params = ps_benchmark.PsBenchMarkTask.params() ps_list = copy.copy(cluster["ps"]) bm_params.bm_config = ps_benchmark.BenchmarkConfig( ps_list=ps_list, num_ps_required=num_ps_required, num_workers=config.num_workers, index=config.index, benchmark_secs=_EXTRA_PS_BENCHMARK_SECS) config.num_ps += config.num_extra_ps config.model_dir = os.path.join(config.model_dir, "benchmark_dir") config.operation_timeout_in_ms = int(_EXTRA_PS_BENCHMARK_SECS * 1000 + 30 * 1000) logging.info("Run PS benchmark") _do_worker_train(config, bm_params, cluster, task, user_hooks) cluster["ps"] = ps_list return cluster def _save_debugging_info(config: DistributedCpuTrainingConfig, cluster: dict, training: CpuTraining): debugging_info = debugging_info_pb2.DebuggingInfo() debugging_info.cluster.chief_addr = cluster["chief"][0] for addr in cluster["ps"]: debugging_info.cluster.ps_addrs.append(addr) debugging_info.num_workers = 
config.num_workers for k, v in training.feature_configs[0].items(): feature_name_config = debugging_info.feature_name_configs.add() feature_name_config.feature_name = k feature_name_config.config_str = str(v) debugging_info_file_name = utils.get_debugging_info_file_name( config.model_dir) tf.io.gfile.makedirs(os.path.dirname(debugging_info_file_name)) file_io.atomic_write_string_to_file(debugging_info_file_name, debugging_info.SerializeToString()) def _get_replica_device_setter(config): if config.task_type: worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id) else: worker_device = '/job:worker' if config.num_ps_replicas > 0: from tensorflow.python.training import device_setter return tf.compat.v1.train.replica_device_setter( ps_tasks=config.num_ps_replicas, worker_device=worker_device, merge_devices=True, ps_ops=list(device_setter.STANDARD_PS_OPS), cluster=config.cluster_spec) else: return None def _do_worker_feature_engineering(target, config: DistributedCpuTrainingConfig, params: InstantiableParams, cluster: Dict, task: Dict): logging.info("do worker feature engineering. mode: %s.", config.mode) params.mode = config.mode native_task = params.instantiate() if not isinstance(native_task, NativeTask): raise ValueError( "distributed train only support NativeTask. 
Got {}".format(native_task)) session_config = cluster_manager.generate_session_config((cluster, task)) session_config.operation_timeout_in_ms = config.operation_timeout_in_ms check_addrs = _get_blocked_addrs(cluster=cluster, ignored_jobs={'worker'}) alive_checker = net_utils.NodeAliveChecker(check_addrs, timeout=60) if not alive_checker.all_nodes_alive(): raise NodeAliveCheckerError("{} is unreachable".format(','.join( alive_checker.get_dead_nodes()))) os.environ["TF_CONFIG"] = json.dumps({"cluster": cluster, "task": task}) run_config = tf.estimator.RunConfig( model_dir=config.model_dir, session_config=session_config, save_summary_steps=config.save_summary_steps * config.num_workers, log_step_count_steps=config.log_step_count_steps * config.num_workers, session_creation_timeout_secs=config.session_creation_timeout_secs) device_fn = _get_replica_device_setter(run_config) with tf.Graph().as_default() as g, g.device(device_fn): dataset = native_task.input_fn(config.mode) itr = tf.compat.v1.data.make_initializable_iterator(dataset) nxt_elem = itr.get_next() # TODO(ltli): 后续转 ExampleBatch 的逻辑转移到 C++,效率更高 fe_save_hook = feature_engineering_hooks.FeatureEngineeringSaveHook( config, nxt_elem) with tf.compat.v1.train.MonitoredTrainingSession( target, hooks=[fe_save_hook], config=session_config) as sess: sess.run(itr.initializer) while not sess.should_stop(): sess.run(nxt_elem) if "TF_CONFIG" in os.environ: del os.environ["TF_CONFIG"] logging.info("finish worker feature engineering. 
mode: %s.", config.mode) return None def make_config_backward_compatible(model_dir: str, config: CpuTrainingConfig): # Will remove this compatible logic after 1/1/2023 if config.use_native_multi_hash_table is None: monolith_ckpt = save_utils.get_monolith_checkpoint_state(model_dir) if monolith_ckpt is not None and monolith_ckpt.builtin_hash_table_type in ( monolith_checkpoint_state_pb2.MonolithCheckpointState.UNKNOWN, monolith_checkpoint_state_pb2.MonolithCheckpointState.CUCKOO_HASH_MAP): config.use_native_multi_hash_table = False def distributed_train(config: DistributedCpuTrainingConfig, discovery: ServiceDiscovery, params: InstantiableParams, sync_backend: SyncBackend = None, user_hooks = None): """Trains the server in a distributed fashion.""" if config.index is None: raise ValueError("Index can't be none.") if config.num_ps is None: raise ValueError("Num ps can't be none.") if config.num_workers is None: raise ValueError("Num workers can't be none.") if not config.server_type in ["ps", "worker"]: raise ValueError("Unknown server type. 
type: {}".format(config.server_type)) if not config.model_dir: raise ValueError("model dir can't be empty.") if is_chief(config): FLAGS.monolith_alert_proto = FLAGS.monolith_chief_alert_proto make_config_backward_compatible(config.model_dir, config) server_config = tf.compat.v1.ConfigProto( intra_op_parallelism_threads=config.intra_op_parallelism_threads, inter_op_parallelism_threads=config.inter_op_parallelism_threads) if isinstance(discovery, (MLPServiceDiscovery, TfConfigServiceDiscovery)): addr = discovery.addr config.index = discovery.index server = tf.distribute.Server({"local": [addr]}, config=server_config) else: assert isinstance(discovery, ServiceDiscovery) ip = yarn_runtime.get_local_host() server = tf.distribute.Server( {"local": [net_utils.concat_ip_and_port(ip, 0)]}, config=server_config) addr = urlparse(server.target).netloc _prepare_server(server.target, config) discovery.register(config.server_type, config.index, addr) logging.info("Started %s %d at %s.", config.server_type, config.index, addr) estimator = None metric_heart_beat_thread = None if config.server_type == "ps": if not config.model_name: if isinstance(params, InstantiableParams): default_name = f'di_name_{params.cls.__name__}' else: default_name = f'di_name_{params.__class__.__name__}' config.model_name = params.metrics.deep_insight_name or default_name with native_task_context.with_ctx( make_native_task_context(config, sync_backend)): _join_ps(server.target, config.index, sync_backend) elif config.server_type == "worker": num_retries, worker_failover_cnt = 0, 0 max_retries = config.max_retry_times or ( 6 if config.partial_recovery and not config.enable_sync_training else 0) cluster, task = {}, {} num_required_ps = config.num_ps + config.num_extra_ps def _get_cluster_and_task(): cluster, task = cluster_manager.get_training_cluster( discovery, addr, config.index, config.num_redundant_ps, num_required_ps, config.num_workers, config.model_dir, config.uuid, params.metrics.deep_insight_name, 
config.cluster_type) filtered_cluster = copy.copy(cluster) if config.submit_time_secs and config.index == 0 and params.metrics.deep_insight_name: container_ready_elapsed_time = int( time.time()) - config.submit_time_secs logging.info( "Containers ready took {}s.".format(container_ready_elapsed_time)) tags = { "model_name": config.model_name or params.metrics.deep_insight_name, "cluster_type": config.cluster_type } cli.get_cli(utils.get_metric_prefix()).emit_timer( "container_ready_elapsed_time.all", container_ready_elapsed_time, tags) config.containers_ready_time_secs = int(time.time()) if config.num_extra_ps: filtered_cluster = _run_ps_benchmark(config, config.num_ps, filtered_cluster, task, user_hooks) return filtered_cluster, task cluster, task = _get_cluster_and_task() if is_chief(config): metric_heart_beat_thread = _MetricsHeartBeatThread() metric_heart_beat_thread.start() captured_exception = None start_ts = datetime.timestamp(datetime.now()) logging.info("Worker Start %s", str(start_ts)) logging.info("only_feature_engineering: {}.".format( config.only_feature_engineering)) try: while True: try: if config.only_feature_engineering: estimator = _do_worker_feature_engineering(server.target, config, params, cluster, task) else: if config.enable_gpu_training: device_utils.enable_gpu_training() params.train.use_gpu_emb_table = False estimator = _do_worker_train(config, params, cluster, task, user_hooks) break except (tf.errors.DeadlineExceededError, tf.errors.UnavailableError, NodeAliveCheckerError) as e: worker_failover_cnt += 1 tags = { "model_name": config.model_name, "worker_index": str(config.index) } cli.get_cli(utils.get_metric_prefix()).emit_timer( "worker_failover_cnt", f'worker_failover_cnt: {worker_failover_cnt}, msg: {e}', tags) time.sleep(config.retry_wait_in_secs) old_cluster = cluster cluster, task = _get_cluster_and_task() if cluster == old_cluster: logging.info('Temporary error: %s. 
Retrying...', str(e)) continue num_retries += 1 if num_retries <= max_retries: logging.error( 'error is "{}", we try to the {}-th retry, sleep for {} seconds!' .format(e, num_retries, config.retry_wait_in_secs)) else: captured_exception = e raise e except Exception as e: captured_exception = e raise e finally: if is_chief(config): try: if metric_heart_beat_thread: metric_heart_beat_thread.stop() if config.num_redundant_ps or config.num_extra_ps: num_required_ps += config.num_redundant_ps # Query the total ps cluster for shutdown. cluster, task = cluster_manager.get_training_cluster( discovery, addr, config.index, config.num_redundant_ps, num_required_ps, config.num_workers, config.model_dir, config.uuid) # In the realtime training, we want to keep ps alive so we can # restart chief without side effect. if not config.enable_realtime_training or config.force_shutdown_ps: _shutdown_ps(server.target, cluster, task, num_required_ps) finally: if captured_exception is None: yarn_runtime.maybe_finish_application() else: success = yarn_runtime.maybe_kill_application( str(captured_exception)) end_ts = datetime.timestamp(datetime.now()) logging.info("Worker End %s, Cost: %s(s)", str(end_ts), str(end_ts - start_ts)) logging.info("Finished %s %d.", config.server_type, config.index) return estimator def distributed_sync_train(config: DistributedCpuTrainingConfig, params: InstantiableParams, sync_backend: SyncBackend = None, user_hooks = None): """ This is the entry point for synchronous distributed training. This system allows the model to train in a half sync manner as well, when set embedding_prefetch_capacity value > 0. All the dense parameters are synced via allreduce and no asynchronicity is allowed for dense paramters. No Worker num is needed, the system derives the number of workers via MPI API. Args: config: the configs for monolith cpu training. params: the parameters for the model and other modules. 
""" assert get_mpi_rank() == config.index, \ "Given RunConfig.index should be consistent with hvd.rank()." # To remove this contraint future if config.enable_gpu_training: device_utils.enable_gpu_training() params.train.use_gpu_emb_table = True task = params.instantiate() if not isinstance(task, NativeTask): raise ValueError( "distributed train only support NativeTask. Got {}".format(task)) training = CpuTraining(config, task) training = sync_training_hooks.EofAwareTask(training, config.use_dataservice) session_config = tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=False) # CPU Configs session_config.intra_op_parallelism_threads = config.intra_op_parallelism_threads session_config.inter_op_parallelism_threads = config.inter_op_parallelism_threads # GPU Configs update_session_config_for_gpu(session_config) # By default the grappler (meta_optimizer) is enabled. # session_config.graph_options.rewrite_options.disable_meta_optimizer = True session_config.graph_options.rewrite_options.memory_optimization = 1 if os.environ.get('TF_XLA_FLAGS', None): session_config.graph_options.optimizer_options.global_jit_level = 1 # We reduce the frequency of saving to HDFS summary, otherwise it slows down # the training. # TODO(zouxuan): always use the TF v2 summary with flush will fix this issue. 
class Nop(object): def nop(*args, **kwargs): pass def __getattr__(self, _): return self.nop # only rank 0 writes events if config.index != 0: SummaryWriterCache.get = lambda _: Nop() run_config = tf.estimator.RunConfig( model_dir=config.model_dir, device_fn=device_utils.default_device_fn, session_config=session_config, save_summary_steps=None if config.index != 0 else int( os.environ.get('MONOLITH_SAVE_SUMMARY_INTERVAL', config.save_summary_steps)), log_step_count_steps=params.train.max_steps if config.index != 0 else int( os.environ.get('MONOLITH_ROOT_LOG_INTERVAL', config.log_step_count_steps))) estimator = tf.estimator.Estimator(training.create_model_fn(), config=run_config) run_hooks = get_sync_run_hooks(True) if sync_backend is not None: run_hooks.append( sync_training_hooks.ParameterSyncHook(sync_backend, config.index)) run_hooks.append(sync_training_hooks.SyncTrainingInfoHook()) if user_hooks is not None: run_hooks += user_hooks estimator.train(training.create_input_fn(config.mode), hooks=run_hooks, max_steps=params.train.max_steps) logging.info("Finished worker %d.", config.index) return estimator def local_train_internal(params: InstantiableParams, conf: CpuTrainingConfig, model_dir: str, steps: int = 100, profiling: bool = False, user_hooks = None) -> tf.estimator.Estimator: """Do a local training. Especially useful in the local demo.""" if tf.compat.v1.executing_eagerly(): raise EnvironmentError( "Local train is not supported in the eager mode. 
Please call `tf.compat.v1.disable_eager_execution()`" ) task = params.instantiate() if conf.num_ps <= 0: session_config = tf.compat.v1.ConfigProto() training = CpuTraining(conf, task) if "TF_CONFIG" in os.environ: del os.environ["TF_CONFIG"] else: training = CpuTraining(conf, task) ps_servers = [] for _ in range(conf.num_ps): ps_servers.append(tf.distribute.Server.create_local_server()) master = tf.distribute.Server.create_local_server() def get_addr(server: tf.distribute.Server): return server.target[len('grpc://'):] cluster = { "chief": [get_addr(master)], "ps": [get_addr(server) for server in ps_servers], } os.environ["TF_CONFIG"] = json.dumps({ "cluster": cluster, "task": { "type": "chief", "index": 0 } }) spec = tf.train.ClusterSpec(cluster) session_config = tf.compat.v1.ConfigProto( cluster_def=spec.as_cluster_def(), allow_soft_placement=True, share_cluster_devices_in_session=True) session_config.experimental.share_session_state_in_clusterspec_propagation = True # grappler doesn't really understand RaggedTensor. 
session_config.graph_options.rewrite_options.disable_meta_optimizer = True config = tf.estimator.RunConfig(model_dir=model_dir, session_config=session_config, save_summary_steps=conf.save_summary_steps, log_step_count_steps=conf.log_step_count_steps) estimator = tf.estimator.Estimator(training.create_model_fn(), config=config) estimator.train(training.create_input_fn(tf.estimator.ModeKeys.TRAIN), hooks=user_hooks, steps=steps) if conf.enable_resource_constrained_roughsort and conf.mode == tf.estimator.ModeKeys.TRAIN: logging.info(f"roughsort_items_use_parquet: {conf.roughsort_items_use_parquet}") if conf.roughsort_items_use_parquet: items_path = os.path.join(conf.model_dir, "candidate_items.pb") _convert_parquets_to_instance(conf.roughsort_candidate_items_path, items_path) else: items_path = conf.roughsort_candidate_items_path params.mode = tf.estimator.ModeKeys.PREDICT # params.p.only_save_item_cache_hashtable = True task = params.instantiate() training = CpuTraining(conf, task) config = tf.estimator.RunConfig(model_dir=model_dir, session_config=session_config, save_summary_steps=conf.save_summary_steps, log_step_count_steps=conf.log_step_count_steps) estimator = tf.estimator.Estimator(training.create_model_fn(), config=config) estimator.train(training._task.create_item_input_fn( items_path), steps=steps) if "TF_CONFIG" in os.environ: del os.environ["TF_CONFIG"] return estimator def local_feature_engineering_internal( params: InstantiableParams, conf: CpuTrainingConfig, model_dir: str, profiling: bool = False) -> tf.estimator.Estimator: """Do a local feature engineer. Especially useful in the local demo.""" if tf.compat.v1.executing_eagerly(): raise EnvironmentError( "Local train is not supported in the eager mode. 
Please call `tf.compat.v1.disable_eager_execution()`" ) task = params.instantiate() if conf.num_ps <= 0: session_config = tf.compat.v1.ConfigProto() if "TF_CONFIG" in os.environ: del os.environ["TF_CONFIG"] else: ps_servers = [] for _ in range(conf.num_ps): ps_servers.append(tf.distribute.Server.create_local_server()) master = tf.distribute.Server.create_local_server() def get_addr(server: tf.distribute.Server): return server.target[len('grpc://'):] cluster = { "chief": [get_addr(master)], "ps": [get_addr(server) for server in ps_servers], } os.environ["TF_CONFIG"] = json.dumps({ "cluster": cluster, "task": { "type": "chief", "index": 0 } }) spec = tf.train.ClusterSpec(cluster) session_config = tf.compat.v1.ConfigProto( cluster_def=spec.as_cluster_def(), allow_soft_placement=True, share_cluster_devices_in_session=True) session_config.experimental.share_session_state_in_clusterspec_propagation = True # grappler doesn't really understand RaggedTensor. session_config.graph_options.rewrite_options.disable_meta_optimizer = True if not model_dir: model_dir = "/tmp/{}/{}".format(getpass.getuser(), params.name) run_config = tf.estimator.RunConfig(model_dir=model_dir, session_config=session_config, save_summary_steps=conf.save_summary_steps, log_step_count_steps=conf.log_step_count_steps) device_fn = _get_replica_device_setter(run_config) if profiling: tf.profiler.experimental.start(model_dir) with tf.Graph().as_default() as g, g.device(device_fn): dataset = task.input_fn(conf.mode) itr = tf.compat.v1.data.make_initializable_iterator(dataset) nxt_elem = itr.get_next() fe_save_hook = feature_engineering_hooks.FeatureEngineeringSaveHook( conf, nxt_elem) with tf.compat.v1.train.MonitoredTrainingSession( master.target, hooks=[fe_save_hook], config=session_config) as sess: sess.run(itr.initializer) while not sess.should_stop(): sess.run(nxt_elem) if profiling: tf.profiler.experimental.stop() if "TF_CONFIG" in os.environ: del os.environ["TF_CONFIG"] return None def 
local_train(params: InstantiableParams, num_ps=0, model_dir: str = None, steps=100, save_checkpoints_steps=50, profiling=False, enable_embedding_prefetch: bool = True, enable_embedding_postpush: bool = True, remove_model_dir_if_exists: bool = True, only_feature_engineering: bool = False): embedding_prefetch_capacity = 1 if enable_embedding_prefetch else 0 conf = CpuTrainingConfig( model_name=params.name, num_ps=num_ps, embedding_prefetch_capacity=embedding_prefetch_capacity, enable_embedding_postpush=enable_embedding_postpush, save_checkpoints_steps=save_checkpoints_steps) if not model_dir: model_dir = "/tmp/{}/{}".format(getpass.getuser(), params.name) if remove_model_dir_if_exists: try: tf.io.gfile.rmtree(model_dir) except tf.errors.NotFoundError: pass make_config_backward_compatible(model_dir, conf) if only_feature_engineering: return local_feature_engineering_internal(params, conf, model_dir, profiling) else: return local_train_internal(params, conf, model_dir, steps, profiling) ================================================ FILE: monolith/native_training/cpu_training_distributed_test_binary.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import io
import os
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from monolith.native_training import cluster_manager
from monolith.native_training import cpu_training
from monolith.native_training import feature
from monolith.native_training import native_task
from monolith.native_training import service_discovery
from monolith.native_training import utils

# Command line flags consumed by this test binary.
flags.DEFINE_integer("test_case", None, "The number of test case.")
flags.DEFINE_string("test_dir", None, "The test folder.")
flags.DEFINE_string("server_type", None,
                    "The type of this process. Can be 'ps' or 'worker'")
flags.DEFINE_integer("index", None,
                     "The index of the current process in servers.")
flags.DEFINE_integer("num_ps", None, "The number of ps")
flags.DEFINE_integer("num_workers", None, "The number of worker")
flags.DEFINE_integer("num_extra_ps", 0, "The number of extra ps.")
flags.DEFINE_integer("num_redundant_ps", 0, "The number of redundant ps.")
flags.DEFINE_string("uuid", "", "uuid")
flags.DEFINE_bool("use_native_multi_hash_table", False,
                  "Use native MultiHashTable.")

FLAGS = flags.FLAGS


def _sleep_short():
  """Sleeps for 0.1 seconds."""
  time.sleep(0.1)


# In the test, we want query as fast as possible.
cluster_manager._cluster_query_failure_handler = _sleep_short cpu_training._EXTRA_PS_BENCHMARK_SECS = 0.5 class SyncHook(tf.estimator.SessionRunHook): def __init__(self, num_workers, index): self._num_workers = num_workers self._index = index self._var = None self._assign_op = None def begin(self): collections = [tf.compat.v1.GraphKeys.LOCAL_VARIABLES ] if self._index == 0 else [tf.compat.v1.GraphKeys.VARIABLES] self._var = tf.compat.v1.get_variable( "TEST_SYNC_VAR", initializer=[False] * self._num_workers, dtype=tf.bool, trainable=False, collections=collections, ) self._assign_op = self._var[self._index].assign(True) def after_create_session(self, session, coord): session.run(self._assign_op) if self._index == 0: # To prevent chief finishing before other workers start while True: if sum(session.run(self._var)) == self._num_workers: break time.sleep(0.5) class FeatureTask(native_task.NativeTask): """A test task that will collect some information in model_fn.""" @classmethod def params(cls): p = super().params() p.define("training_hooks", [], "Training hooks") return p def create_input_fn(self, mode): del mode def input_fn(): return tf.data.Dataset.from_tensors( {"feature": tf.ragged.constant([[0, 1]], dtype=tf.int64)}) return input_fn def create_model_fn(self): def model_fn(mode, features, config): slot = self.ctx.feature_factory.create_feature_slot( feature.FeatureSlotConfig(name="slot")) s = slot.add_feature_slice(5) fc = feature.FeatureColumnV1(slot, "feature") embedding = fc.embedding_lookup(s) if mode == tf.estimator.ModeKeys.PREDICT: return tf.estimator.EstimatorSpec(mode, predictions=tf.constant(0)) all_embeddings = [fc.get_all_embeddings_concat()] grads = tf.gradients(-embedding, all_embeddings) global_step = tf.compat.v1.train.get_or_create_global_step() train_op = tf.group( global_step.assign_add(1), self._ctx.feature_factory.apply_gradients(zip(grads, all_embeddings)), features["feature"]) return tf.estimator.EstimatorSpec( mode, train_op=train_op, 
loss=tf.constant(0.0), training_hooks=[SyncHook(FLAGS.num_workers, FLAGS.index)] + self.p.training_hooks) return model_fn class HostServiceDiscovery(service_discovery.ServiceDiscovery): def __init__(self, base_path: str): self._base_path = base_path def register(self, name: str, index: int, addr: str): os.makedirs(self._named_path(name), exist_ok=True) with io.open(os.path.join(self._named_path(name), str(index)), "w") as writer: writer.write(addr) def deregister(self, name: str, index: int, addr: str): pass def query(self, name: str): basepath = self._named_path(name) if not os.path.exists(basepath): return {} indexes = os.listdir(basepath) result = {} for index in indexes: f = os.path.join(basepath, index) with io.open(f, "r") as reader: addr = reader.read() result[int(index)] = addr return result def _named_path(self, name: str): return os.path.join(self._base_path, name) def test_run(params): model_dir = os.path.join(FLAGS.test_dir, f"{FLAGS.uuid}/model") config = cpu_training.DistributedCpuTrainingConfig( server_type=FLAGS.server_type, index=FLAGS.index, num_ps=FLAGS.num_ps, num_extra_ps=FLAGS.num_extra_ps, num_redundant_ps=FLAGS.num_redundant_ps, num_workers=FLAGS.num_workers, model_dir=model_dir, uuid=FLAGS.uuid, enable_model_ckpt_info=True, use_native_multi_hash_table=FLAGS.use_native_multi_hash_table) # It is not easy to prevent worker doing things too fast params.train.max_pending_seconds_for_barrier = 2 discovery = HostServiceDiscovery( os.path.join(FLAGS.test_dir, f"{FLAGS.uuid}/service_discovery")) cpu_training.distributed_train(config, discovery, params) def test0(): params = FeatureTask.params() params.name = "test_task" test_run(params) def test1(): def no_shutdown(*args, **kwargs): while True: time.sleep(1) cpu_training._shutdown_ps = no_shutdown test0() class RaiseErrorHook(tf.estimator.SessionRunHook): def __init__(self, first): self._first = first def before_run(self, run_context): if self._first: self._first = False raise 
tf.errors.DeadlineExceededError(None, None, "test ddl exceeded error") def test2(): params = FeatureTask.params() params.name = "test_task" first = True params.training_hooks = [RaiseErrorHook(first)] test_run(params) def main(_): test_cases = [test0, test1, test2] test_cases[FLAGS.test_case]() if __name__ == "__main__": app.run(main) ================================================ FILE: monolith/native_training/cpu_training_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from collections import defaultdict import copy import os import subprocess import threading import time from google.protobuf import text_format from typing import Dict, List from unittest import mock from absl import app from absl import flags import numpy as np import tensorflow as tf from tensorflow.python.lib.io import file_io from monolith.native_training import cpu_training from monolith.native_training import entry from monolith.native_training import feature from monolith.native_training import utils from monolith.native_training.debugging import debugging_server from monolith.native_training.model_export import saved_model_exporters from monolith.native_training.model_export.export_context import ExportMode from monolith.native_training.native_task import NativeTask from monolith.native_training.proto import debugging_info_pb2 from monolith.native_training.runtime.hash_table import \ embedding_hash_table_pb2 from monolith.native_training.service_discovery import ServiceDiscovery FLAGS = flags.FLAGS # TODO(leqi.zou): Finally remove this or rework with a better gflag util. 
class FeatureTask(NativeTask):
  """Test task with a single ragged int64 feature named "feature".

  The model looks up a 5-dim embedding slice for the feature. In PREDICT mode
  it returns the sum of the embedding; otherwise it applies gradients to the
  embedding and increments the global step.
  """

  def create_input_fn(self, mode):
    del mode  # The same single-element dataset is used regardless of mode.

    def input_fn():
      ragged_ids = tf.ragged.constant([[0, 0]], dtype=tf.int64)
      return tf.data.Dataset.from_tensors({"feature": ragged_ids})

    return input_fn

  def create_model_fn(self):

    def model_fn(mode, features, config):
      feature_slot = self.ctx.feature_factory.create_feature_slot(
          feature.FeatureSlotConfig(name="slot"))
      feature_slice = feature_slot.add_feature_slice(5)
      column = feature.FeatureColumnV1(feature_slot, "feature")
      looked_up = column.embedding_lookup(feature_slice)
      concat_embeddings = [column.get_all_embeddings_concat()]

      if mode == tf.estimator.ModeKeys.PREDICT:
        summed = tf.math.reduce_sum(looked_up)
        return tf.estimator.EstimatorSpec(mode, predictions=summed)

      gradients = tf.gradients(-looked_up, concat_embeddings)
      step_op = inc_global_step_op()
      apply_op = self._ctx.feature_factory.apply_gradients(
          zip(gradients, concat_embeddings))
      # The raw feature is grouped in as well, as in the original train op.
      train_op = tf.group(step_op, apply_op, features["feature"])
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=train_op,
                                        loss=tf.constant(0.0),
                                        predictions=tf.constant(0))

    return model_fn

  def create_serving_input_receiver_fn(self):

    def serving_input_receiver_fn():
      serving_features = {
          "feature": tf.ragged.constant([[0, 0]], dtype=tf.int64)
      }
      receiver = tf.compat.v1.placeholder(tf.string)
      return tf.estimator.export.ServingInputReceiver(serving_features,
                                                      receiver)

    return serving_input_receiver_fn
class SequenceFeatureTask(NativeTask):
  """Test task whose model_fn consumes a ragged sequence feature.

  The feature column is built with the `first_n(2)` combiner, and the
  predictions are the per-row sums of the raw "sequence_feature" ids.
  """

  def create_input_fn(self, mode):
    del mode  # The same single-element dataset is used regardless of mode.

    def input_fn():
      sequence_ids = tf.ragged.constant([[1, 2], [], [3, 4, 5]],
                                        dtype=np.int64)
      return tf.data.Dataset.from_tensors({
          "sequence_feature": sequence_ids,
      })

    return input_fn

  def create_model_fn(self):

    def model_fn(features, mode, **kwargs):
      feature_slot = self.ctx.feature_factory.create_feature_slot(
          feature.FeatureSlotConfig(name="slot"))
      feature_slice = feature_slot.add_feature_slice(5)
      column = feature.FeatureColumnV1(
          feature_slot,
          "sequence_feature",
          combiner=feature.FeatureColumnV1.first_n(2))
      looked_up = column.embedding_lookup(feature_slice)

      row_sums = tf.reduce_sum(features["sequence_feature"], axis=-1)

      concat_embeddings = [column.get_all_embeddings_concat()]
      gradients = tf.gradients(-looked_up, concat_embeddings)
      step_op = inc_global_step_op()
      apply_op = self._ctx.feature_factory.apply_gradients(
          zip(gradients, concat_embeddings))
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=tf.group(step_op, apply_op),
                                        loss=tf.constant(0.0),
                                        predictions=row_sums)

    return model_fn
class FeatureWithExpireTimeTask(NativeTask):
  """Test task with two feature slots that differ in `expire_time`.

  Slot 1 is configured with expire_time=0 and slot 2 with expire_time=1; both
  use a zeros initializer. Predictions are the two looked-up embeddings
  concatenated along axis 0.
  """

  def create_input_fn(self, mode):
    del mode  # The same single-element dataset is used regardless of mode.

    def input_fn():
      first_ids = tf.ragged.constant([[1 << 48, (1 << 48) + 1]],
                                     dtype=np.int64)
      second_ids = tf.ragged.constant([[2 << 48, (2 << 48) + 1]],
                                      dtype=np.int64)
      request_time = tf.constant([[100]], dtype=tf.int64)
      return tf.data.Dataset.from_tensors({
          "feature_1": first_ids,
          "feature_2": second_ids,
          "req_time": request_time,
      })

    return input_fn

  def create_model_fn(self):

    def model_fn(mode, features, **kwargs):

      def build_column(slot_name, slot_id, expire_time, feature_name):
        # Creates the slot, one 5-dim slice and the column, then looks the
        # slice up. Returns (column, embedding).
        feature_slot = self.ctx.feature_factory.create_feature_slot(
            feature.FeatureSlotConfig(
                name=slot_name,
                slot_id=slot_id,
                expire_time=expire_time,
                default_vec_initializer=entry.ZerosInitializer()))
        feature_slice = feature_slot.add_feature_slice(5)
        column = feature.FeatureColumnV1(feature_slot, feature_name)
        return column, column.embedding_lookup(feature_slice)

      # Slot 1 is built completely before slot 2, as in the original code.
      fc_1, embedding_1 = build_column("slot_1", 1, 0, "feature_1")
      fc_2, embedding_2 = build_column("slot_2", 2, 1, "feature_2")

      predictions = tf.concat([embedding_1, embedding_2], axis=0)
      concat_embeddings = [
          fc_1.get_all_embeddings_concat(),
          fc_2.get_all_embeddings_concat(),
      ]
      gradients = tf.gradients([embedding_1, embedding_2], concat_embeddings)
      step_op = inc_global_step_op()
      apply_op = self._ctx.feature_factory.apply_gradients(
          zip(gradients, concat_embeddings))
      return tf.estimator.EstimatorSpec(mode,
                                        train_op=tf.group(step_op, apply_op),
                                        loss=tf.constant(0.0),
                                        predictions=predictions)

    return model_fn
def test_cpu_training_with_slot_occurrence_threshold(self):
  """Trains once and checks the occurrence threshold recorded for slot 2021.

  The task configures a single slot (slot_id=2021, occurrence_threshold=3), so
  after training the mapping held by CpuTraining must contain exactly that
  entry.
  """
  p = FeatureWithSlotOccurrenceThresholdTask.params()
  p.name = "feature_with_slot_occurrence_task"
  task = FeatureWithSlotOccurrenceThresholdTask(p)
  training = cpu_training.CpuTraining(cpu_training.CpuTrainingConfig(), task)
  model_dir = os.path.join(os.environ["TEST_TMPDIR"],
                           "test_cpu_training_with_slot_occurrence_threshold")
  est = tf.estimator.Estimator(training.create_model_fn(), model_dir)
  est.train(training.create_input_fn(tf.estimator.ModeKeys.TRAIN))

  slot_to_occurrence_threshold = training._slot_to_occurrence_threshold
  # assertLen/assertIn report the actual container on failure, unlike
  # assertEqual(len(...)) / assertTrue(x in y).
  self.assertLen(slot_to_occurrence_threshold, 1)
  self.assertIn(2021, slot_to_occurrence_threshold)
  self.assertEqual(slot_to_occurrence_threshold[2021], 3)
class DistributedTrainTest(tf.test.TestCase):
  """Runs the distributed training test binary as PS and worker subprocesses.

  Every test launches `_DISTRIBUTED_TRAIN_BINARY` once per PS and once per
  worker. Subprocesses that are still running when a test ends (normally or
  through a failure) are killed and reaped so they do not outlive the test.
  """

  # Upper bound, in seconds, for waiting on one subprocess to exit or for the
  # first checkpoint to show up.
  _WAIT_TIMEOUT_SECS = 150

  @staticmethod
  def _kill_processes(processes):
    """Kills every process that is still running and reaps all of them."""
    for process in processes:
      if process.poll() is None:
        process.kill()
    for process in processes:
      process.wait()

  def _run_process(self, args_tmpl: List, num_ps: int, num_workers: int):
    """Starts `num_ps` PS and `num_workers` worker subprocesses.

    Args:
      args_tmpl: Command line shared by all subprocesses; it is copied, never
        mutated.
      num_ps: Number of processes started with --server_type=ps.
      num_workers: Number of processes started with --server_type=worker.

    Returns:
      The started processes in reverse start order, i.e. workers come before
      PS.
    """
    processes = []
    try:
      for server_type, count in (("ps", num_ps), ("worker", num_workers)):
        for i in range(count):
          args = copy.copy(args_tmpl)
          args.append("--server_type={}".format(server_type))
          args.append("--index={}".format(i))
          processes.append(subprocess.Popen(args))
          if server_type == "worker" and i == 0:
            # This is best effort waiting, otherwise the test may take 30 secs
            # to finish. The goal here is to wait for the chief to initialize
            # global variables.
            time.sleep(1)
    except BaseException:
      # Do not leak the processes that were already started.
      self._kill_processes(processes)
      raise
    processes.reverse()
    return processes

  def _run_test(self, args_tmpl: List, num_ps: int, num_workers: int):
    """Starts all subprocesses and asserts that each one exits with code 0."""
    processes = self._run_process(args_tmpl, num_ps, num_workers)
    print(" ".join(args_tmpl), num_ps, num_workers)
    try:
      for process in processes:
        # The timeout is generous because of the 30 secs querying interval.
        self.assertEqual(process.wait(timeout=self._WAIT_TIMEOUT_SECS), 0)
    finally:
      # On success everything has exited already; on failure or timeout this
      # stops the remaining processes.
      self._kill_processes(processes)

  def _test_dir(self):
    """Returns the per-test-method directory under TEST_TMPDIR."""
    return os.path.join(os.environ["TEST_TMPDIR"], "DistributedTrainTest",
                        self._testMethodName)

  def _test_args(self, num_ps, num_workers, case=0):
    """Builds the command line shared by all subprocesses of one test."""
    args = [
        _DISTRIBUTED_TRAIN_BINARY, "--test_case={}".format(case),
        "--test_dir={}".format(self._test_dir()),
        "--num_ps={}".format(num_ps),
        "--num_workers={}".format(num_workers),
        "--uuid={}".format(self._testMethodName),
        f"--use_native_multi_hash_table={bool(FLAGS.use_native_multi_hash_table)}"
    ]
    return args

  # TODO(leqi.zou): Currently, this test mocks too much; we should find a way
  # to elegantly solve the shutdown problem both in test and training.
  # This test may take 30 secs to finish because of the variable
  # initialization problem.
  def test0_basic(self):
    num_ps = 4
    # We have 2 workers and 1 chief
    num_workers = 3
    args_tmpl = self._test_args(num_ps, num_workers)
    self._run_test(args_tmpl, num_ps, num_workers)

  def test0_with_extra_ps(self):
    num_ps = 2
    num_workers = 1
    num_extra_ps = 2
    args_tmpl = self._test_args(num_ps, num_workers)
    args_tmpl.append("--num_extra_ps={}".format(num_extra_ps))
    self._run_test(args_tmpl, num_ps + num_extra_ps, num_workers)

  def test0_with_redundant_ps(self):
    num_ps = 4
    num_workers = 2
    num_redundant_ps = 2
    args_tmpl = self._test_args(num_ps, num_workers)
    args_tmpl.append("--num_redundant_ps={}".format(num_redundant_ps))
    self._run_test(args_tmpl, num_ps + num_redundant_ps, num_workers)

  def test1_with_debugging_server(self):
    if FLAGS.use_native_multi_hash_table:
      # Debugging server doesn't support multi hash table.
      return
    num_ps = 2
    num_workers = 1
    args_tmpl = self._test_args(num_ps, num_workers, case=1)
    processes = self._run_process(args_tmpl, num_ps, num_workers)
    try:
      model_dir = os.path.join(self._test_dir(),
                               "test1_with_debugging_server/model")
      # Wait for the first checkpoint, but fail instead of hanging forever if
      # the subprocesses never produce one.
      deadline = time.monotonic() + self._WAIT_TIMEOUT_SECS
      while not tf.train.get_checkpoint_state(model_dir):
        if time.monotonic() > deadline:
          self.fail("No checkpoint found in {} after {} secs.".format(
              model_dir, self._WAIT_TIMEOUT_SECS))
        time.sleep(1)

      debugging_info_str = file_io.read_file_to_string(
          utils.get_debugging_info_file_name(model_dir), binary_mode=True)
      debugging_info = debugging_info_pb2.DebuggingInfo()
      debugging_info.ParseFromString(debugging_info_str)
      self.assertEqual(debugging_info.num_workers, num_workers)
      self.assertLen(debugging_info.cluster.ps_addrs, num_ps)
      self.assertLen(debugging_info.feature_name_configs, 1)
      self.assertEqual(debugging_info.feature_name_configs[0].feature_name,
                       "feature")

      worker = debugging_server.DebuggingWorker(model_dir)
      self.assertEqual(worker.fetch_variables(["global_step:0", "test"]),
                       {'global_step:0': '1'})
      fids = ["0", "1", "2", "0"]
      result = worker.fetch_features(["feature"] * 3 + ["test"], fids)
      for idx in range(2):
        fid = fids[idx]
        entry_dump = embedding_hash_table_pb2.EntryDump()
        text_format.Parse(result["feature"][fid], entry_dump)
        self.assertLen(entry_dump.num, 5)
      self.assertNotIn("2", result["feature"])
      self.assertNotIn("test", result)
    finally:
      # The subprocesses of this test case are not awaited for a clean exit,
      # so they are always killed, also when an assertion above fails.
      self._kill_processes(processes)

  def test2_temporary_error(self):
    num_ps = 1
    num_workers = 1
    args_tmpl = self._test_args(num_ps, num_workers, case=2)
    self._run_test(args_tmpl, num_ps, num_workers)
# Library target that declares no srcs, hdrs or deps in this BUILD file.
# ":pb_data_lib" lists it in its deps, and alwayslink = 1 keeps anything it
# provides in the final link.
# NOTE(review): presumably a placeholder that is filled in for internal-only
# builds - confirm before removing.
cc_library(
    name = "pb_data_internal_lib",
    alwayslink = 1,
)
"kernels/parquet_dataset_kernel.cc", "kernels/parse_example_lib.cc", "kernels/parse_example_lib.h", "kernels/parse_input_data_kernel.cc", "kernels/parse_sparse_feature.cc", "kernels/parse_sparse_feature.h", "kernels/pb_dataset_kernel.cc", "kernels/ragged_feature_kernel.cc", "kernels/scatter_label_kernel.cc", "kernels/split_flow_dataset_kernel.cc", "kernels/string_to_variant.cc", "kernels/tf_example_to_example_kernel.cc", "kernels/transform_dataset_kernel.cc", "kernels/transform_dataset_kernel.h", "kernels/variant_filter_kernel.cc", "kernels/gen_fid_mask.cc", ], deps = [ ":data_op_config_cc_proto", ":pb_data_internal_lib", "//monolith/native_training/data/kernels/internal:cache_mgr", "//monolith/native_training/data/kernels/internal:datasource_utils", "//monolith/native_training/data/kernels/internal:file_match_split_provider", "//monolith/native_training/data/kernels/internal:label_utils", "//monolith/native_training/data/kernels/internal:value_filter_by_line_id", "//monolith/native_training/data/kernels/internal:value_filter_by_feature", "//monolith/native_training/data/kernels/internal:parquet_example_reader", "//monolith/native_training/data/kernels/internal:relational_utils", "//monolith/native_training/data/kernels/internal:uniq_hashtable", "//monolith/native_training/data/training_instance:data_reader", "//monolith/native_training/data/training_instance:fid", "//monolith/native_training/data/training_instance:instance_utils", "//monolith/native_training/data/training_instance:parse_instance_lib", "//monolith/native_training/data/training_instance:reader_util", "//monolith/native_training/data/transform:transforms", "//monolith/native_training/runtime/common:metrics", "//monolith/native_training/runtime/common:linalg_utils", "//monolith/native_training/runtime/concurrency:queue", "//monolith/native_training/runtime/ops:traceme", "//third_party/nlohmann:json", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash:city", 
"@com_google_absl//absl/random", "@com_google_absl//absl/strings:str_format", "@kafka", ], alwayslink = 1, ) cc_library( name = "pb_data_ops", srcs = [ "ops/feature_utils_ops.cc", "ops/parse_input_data_ops.cc", "ops/pb_dataset_ops.cc", ], copts = ["-DNDEBUG"], deps = [ ":pb_data_lib", "@org_tensorflow//tensorflow/core:framework_headers_lib", ], alwayslink = 1, ) py_library( name = "datasets_py", srcs = [ "datasets.py", ], deps = [ ":feature_list", ":feature_utils_py", ":parsers_py", "//monolith:utils", "//monolith/native_training:mlp_utils", "//monolith/native_training:monolith_export", "//monolith/native_training/data/transform:transforms_py", "//monolith/native_training/distribute:distributed_dataset", "//monolith/native_training/hooks:ckpt_hooks", "//monolith/native_training/runtime/ops:gen_monolith_ops", "@org_tensorflow//tensorflow:tensorflow_py", requirement("kafka_python"), ], ) py_library( name = "parsers_py", srcs = [ "parsers.py", ], deps = [ ":data_op_config_py_proto", ":feature_list", "//idl:example_py_proto", "//idl:proto_parser_py_proto", "//monolith:utils", "//monolith/native_training:logging_ops", "//monolith/native_training:monolith_export", "//monolith/native_training:native_task_context", "//monolith/native_training:utils", "//monolith/native_training/runtime/ops:gen_monolith_ops", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_library( name = "feature_utils_py", srcs = [ "feature_utils.py", ], deps = [ ":data_op_config_py_proto", ":feature_list", ":parsers_py", "//idl:example_py_proto", "//idl:proto_parser_py_proto", "//monolith:utils", "//monolith/native_training:monolith_export", "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_library( name = "feature_list", srcs = ["feature_list.py"], deps = [ ":utils", "//monolith/native_training:utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "feature_list_test", srcs = ["feature_list_test.py"], data = 
["//monolith/native_training/data/test_data:test_feature_lists"], deps = [ ":feature_list", ], ) py_library( name = "data", srcs = [ "__init__.py", ], srcs_version = "PY3", deps = [ ":datasets_py", ":feature_utils_py", ":parsers_py", ], ) py_test( name = "extract_fid_test", srcs = [ "extract_fid_test.py", ], main = "extract_fid_test.py", deps = [ "//idl:example_py_proto", "//idl:proto_parser_py_proto", "//monolith:utils", "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_library( name = "utils", srcs = ["utils.py"], ) py_library( name = "item_pool_hook", srcs = ["item_pool_hook.py"], deps = [ ":datasets_py", ":feature_utils_py", ], ) py_test( name = "item_pool_test", srcs = [ "item_pool_test.py", ], deps = [ ":feature_utils_py", ":feature_list", "//monolith:utils", ], ) py_test( name = "multi_flow_test", srcs = [ "multi_flow_test.py", ], main = "multi_flow_test.py", deps = [ ":datasets_py", ":feature_utils_py", ":parsers_py", "//idl:example_py_proto", "//idl:proto_parser_py_proto", ], ) py_test( name = "negative_gen_test", srcs = [ "negative_gen_test.py", ], main = "negative_gen_test.py", deps = [ ":datasets_py", ":feature_utils_py", ":parsers_py", "//idl:example_py_proto", "//idl:proto_parser_py_proto", ], ) py_binary( name = "kafka_dataset_test", srcs = [ "kafka_dataset_test.py", ], deps = [ ":datasets_py", ":feature_utils_py", ":parsers_py", "//idl:example_py_proto", "//idl:proto_parser_py_proto", "//monolith/native_training/model_export:data_gen_utils", ], ) py_binary( name = "data_service_test", srcs = [ "data_service_test.py", ], deps = [ ":datasets_py", ":feature_utils_py", ], ) py_binary( name = "data_service_parquet_test", srcs = [ "data_service_parquet_test.py", ], deps = [ ":datasets_py", ":feature_utils_py", ], ) exports_files([ "kernels/add_action_kernel.cc", "kernels/instance_reweight_dataset_kernel.cc", "kernels/instance_reweight_dataset_kernel.h", "kernels/negative_gen_dataset_kernel.cc", "kernels/negative_gen_dataset_kernel.h", 
"kernels/df_resource_kernel.h", "kernels/df_resource_kernel.cc", "kernels/split_flow_dataset_kernel.cc", "kernels/merge_flow_dataset_kernel.cc", "kernels/parse_example_lib.cc", "kernels/parse_example_lib.h", "kernels/parse_input_data_kernel.cc", "kernels/pb_dataset_kernel.cc", "kernels/ragged_feature_kernel.cc", "kernels/variant_filter_kernel.cc", "kernels/parquet_dataset_kernel.cc", "kernels/item_pool_kernels.h", "kernels/item_pool_kernels.cc", "ops/feature_utils_ops.cc", "ops/parse_input_data_ops.cc", "ops/pb_dataset_ops.cc", ]) ================================================ FILE: monolith/native_training/data/__init__.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import monolith.native_training.data.datasets as datasets from monolith.native_training.data.datasets import PBDataset, InstanceReweightDataset, NegativeGenDataset, PbType from monolith.native_training.data.parsers import parse_examples, parse_instances, parse_example_batch from monolith.native_training.data.feature_utils import filter_by_fids, filter_by_feature_value, filter_by_value, \ feature_combine, negative_sample, switch_slot, special_strategy ================================================ FILE: monolith/native_training/data/data_op_config.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// Describes which features of a TFRecord example are sparse or dense, and
// which fields carry the label and the instance weight.
message TFRecordFeatureDescription {
  // NOTE(review): the key/value types of this map were missing from the text
  // ("map sparse_features = 1;" is not valid proto2). <string, int32> is an
  // assumption and must be confirmed against the upstream definition.
  map<string, int32> sparse_features = 1;
  repeated string dense_features = 2;
  optional string label = 3 [default = ""];
  optional string instance_weight = 4 [default = ""];
}
import os import time import tensorflow as tf from absl import logging, flags import numpy as np from struct import unpack import zlib import gzip from monolith.native_training.data.datasets import PBDataset, InstanceReweightDataset, PbType, CompressType, \ FilePBDataset, KafkaDataset, CacheOneDataset from monolith.native_training.data.parsers import parse_instances, parse_examples, parse_example_batch from monolith.native_training.data.feature_utils import filter_by_fids, filter_by_value, negative_sample, \ switch_slot, feature_combine, special_strategy from idl.matrix.proto.example_pb2 import Example, ExampleBatch from monolith.native_training.model_export.data_gen_utils import gen_random_data_file, ParserArgs from tensorflow.python.framework import sparse_tensor from monolith.native_training.estimator import RunConfig from monolith.native_training.hooks import session_hooks FLAGS = flags.FLAGS features = { 'f_spm_1': 301, 'f_spm_3': 303, 'f_spm_2': 302, 'f_spm_4': 304, 'f_user_id': 1, 'f_user_ctx_network': 61, 'f_user_id-f_page': 504, 'f_scm': 306, 'f_goods_id': 200, 'f_goods_sale_number_1000': 225, 'f_goods_praise_cnt': 229, 'f_spm': 300, 'f_page': 305, 'f_is_dup': 310, 'f_user_ctx_platform': 52, 'f_goods_title_terms': 209, 'f_goods_tags_terms': 211, 'f_user_test09_array_int32': 554, 'f_user_test15_array_float': 540, 'f_user_test14_array_bool': 543, 'f_user_test12_array_uint64': 551, 'f_user_test10_array_int64': 549 } group_slots = [200,201,202,203,204,205,206,210,211,212,213,214,215,\ 216,217,218,219,220,221,222,223,224,225,230,231,232,233,234,235,236,237,238,239,240,241,242] def parse_inst_exam(tensor: tf.Tensor, out_type): fidv1_features = [ 1, 2, 32, 33, 36, 38, 42, 50, 54, 56, 60, 66, 120, 150, 180, 182, 192, 220, 333, 410, 412, 422, 446 ] if out_type == PbType.INSTANCE: return parse_instances(tensor, fidv1_features, dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'item_id'], 
extra_feature_shapes=[1, 1, 1]) else: return parse_examples( tensor, sparse_features=[f'fc_slot_{slot}' for slot in fidv1_features], dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'item_id'], extra_feature_shapes=[1, 1, 1]) def parse_eb(tensor: tf.Tensor, out_type): if out_type == PbType.INSTANCE: feature_dict = parse_instances( tensor, fidv1_features=list(features.values()), dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'item_id'], extra_feature_shapes=[1, 1, 1]) else: feature_dict = parse_examples( tensor, sparse_features=list(features.keys()), dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'item_id', 'actions'], extra_feature_shapes=[1, 1, 1, 1]) feature_dict['f_page'] = switch_slot(feature_dict['f_page'], slot=306) feature_dict['f_user_id-f_goods_tags_terms'] = feature_combine( feature_dict['f_user_id'], feature_dict['f_goods_tags_terms'], slot=505) return feature_dict class DataOpsTest(tf.test.TestCase): @classmethod def setUpClass(cls): cwd = os.getcwd() cls.patterns = [os.path.join(cwd, "tmp_data", "part-*")] cls._files = [] args = ParserArgs(fidv1_features=[i for i in range(1, 10)], extra_features=[ 'uid', 'sample_rate', 'req_time', 'actions', 'stay_time' ], extra_feature_shapes=[1, 1, 1, 1, 1], batch_size=16, variant_type='instance') for i in range(3): tf.io.gfile.makedirs(os.path.join(cwd, "tmp_data")) file_name = os.path.join(cwd, "tmp_data", f"part-{i}") gen_random_data_file(file_name, args, num_batch=10, sort_id=True, kafka_dump_prefix=False) cls._files.append(file_name) @classmethod def tearDownClass(cls): for file_name in cls._files: tf.io.gfile.remove(file_name) def pb_dataset_target(self, input_pb_type, output_pb_type, filter_fn=None): if input_pb_type == PbType.INSTANCE: lagrangex_header = False has_sort_id, kafka_dump, 
kafka_dump_prefix = True, True, False file_name = "monolith/native_training/data/training_instance/instance.pb" elif input_pb_type == PbType.EXAMPLE: lagrangex_header = False has_sort_id, kafka_dump, kafka_dump_prefix = True, True, False file_name = "monolith/native_training/data/training_instance/example.pb" else: lagrangex_header = True has_sort_id, kafka_dump, kafka_dump_prefix = False, False, False file_name = "monolith/native_training/data/training_instance/examplebatch.data" def parser(tensor: tf.Tensor): if output_pb_type == PbType.PLAINTEXT: return parse_inst_exam(tensor, input_pb_type) elif input_pb_type != PbType.EXAMPLEBATCH: return parse_inst_exam(tensor, output_pb_type) else: return parse_eb(tensor, output_pb_type) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=lagrangex_header, has_sort_id=has_sort_id, kafka_dump=kafka_dump, kafka_dump_prefix=kafka_dump_prefix, input_pb_type=input_pb_type, output_pb_type=output_pb_type) if input_pb_type == PbType.EXAMPLEBATCH: variant_type = 'instance' if output_pb_type == PbType.INSTANCE else 'example' dataset = dataset.instance_reweight( action_priority="2,7,0,1,3,4,5,6,8,9,10,11", reweight= "0:0:1,1:0:1,2:3:-1,3:0:1,4:0:1,5:0:1,6:0:1,7:6:1,8:0:1,9:0:1,10:0:1,11:0:-1", variant_type=variant_type) if filter_fn is not None: dataset = dataset.filter(filter_fn) dataset = dataset.batch(8, drop_remainder=True).map(parser) it = tf.compat.v1.data.make_initializable_iterator(dataset) element = it.get_next() sess.run(it.initializer) count = 0 while True: try: element_num = sess.run(element) # print(element_num) count += 8 except tf.errors.OutOfRangeError: break logging.info("The number of batch is: {}".format(count)) def testInstance2Instance(self): self.pb_dataset_target(input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE) def 
testInstance2Example(self): self.pb_dataset_target(input_pb_type=PbType.INSTANCE, output_pb_type=PbType.EXAMPLE) def testExample2Example(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLE, output_pb_type=PbType.EXAMPLE) def testExample2Instance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLE, output_pb_type=PbType.INSTANCE) def testExampleBatch2Example(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.EXAMPLE) def testExampleBatch2Instance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE) def testInstanceWithPBInstanceDataset(self): self.pb_dataset_target(input_pb_type=PbType.INSTANCE, output_pb_type=PbType.PLAINTEXT) def testExampleWithPBInstanceDataset(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLE, output_pb_type=PbType.PLAINTEXT) def testSetFilterInstance(self): self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_fids(variant, has_actions=[1, 2])) def testSetFilterExample(self): self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.EXAMPLE, filter_fn=lambda variant: filter_by_fids( variant, has_actions=[1, 2], variant_type='example')) def testValueFilterInstance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value( variant, "sample_rate", "ge", 0.8)) def testValueFilterInInstance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value( variant, "chnid", "in", [0, 2, 5])) def testValueFilterEqInstance(self): self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value(variant, "chnid", "eq", 0)) def testValueFilterBewteenInstance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, 
output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value( variant, "sample_rate", "between", [0.1, 0.9])) def testValueFilterStrInstance(self): self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value(variant, "vid", "eq", 'scm')) def testValueFilterAnyInstance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value( variant, "actions", "any", [2, 5, 7])) def testValueFilterAllInstance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value( variant, "actions", "all", [2, 5, 7])) def testValueFilterDiffInstance(self): self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_value( variant, "actions", "diff", [2, 5, 7])) def testSpecialStrategyInstance(self): self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: special_strategy( variant, [2, 5, 7], "2:0.7:-1,5:0.9:1,4:0.2:0,7:1.0:1")) def testValueFilterExample(self): self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.EXAMPLE, filter_fn=lambda variant: filter_by_value( variant, "sample_rate", "ge", 0.8, variant_type='example')) def testExampleBatchPredScalar(self): eb = ExampleBatch() file_name = "monolith/native_training/data/training_instance/examplebatch.data" with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True examples_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) parsed_results = parse_example_batch( examples_placeholder, sparse_features=list(features.keys()), dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'item_id'], 
extra_feature_shapes=[1, 1, 1]) with self.session(config=config) as sess: with open(file_name, 'rb') as stream: stream.read(8) # strip lagrangex_header size = unpack(" 0: all_input_files.append(one_pattern) else: logging.warning(f"pattern not match any files: {one_pattern}") except Exception as e: logging.warning( f"pattern not match any files: {one_pattern} with error: {e}") else: for pattern in pattern_format_list[0]: pattern_recurse(pattern_format_list[1:], *args, pattern) pattern_recurse(all_pattern_format) logging.info(f"all_input_files {all_input_files}") assert len(all_input_files) > 0, "no match files" kwargs['patterns'] = all_input_files # dataset_input_patterns will use DistributedFilePBDataset, "file_name" param will cause conflict # meanwhile "file_name" param is useless, but user code fill this param as default in model.py # to fix this problem, pop this params if "file_name" in kwargs: kwargs.pop("file_name") if FLAGS.dataset_input_use_parquet is not None: kwargs['use_parquet'] = FLAGS.dataset_input_use_parquet if FLAGS.dataset_input_use_tfrecord is not None: kwargs['use_tfrecord'] = FLAGS.dataset_input_use_tfrecord assert not ( FLAGS.dataset_input_use_parquet and FLAGS.dataset_input_use_tfrecord ), "It's not allowed to specify dataset_input_use_parquet=True and dataset_input_use_tfrecord=True" if kwargs.get('kafka_other_metadata', None) is None and FLAGS.kafka_other_metadata is not None: kwargs['kafka_other_metadata'] = FLAGS.kafka_other_metadata try: # the first param is str, batch to streaming, use kafka params for cmd kafka_args = [ kwargs.pop('topics', FLAGS.kafka_topics.split(',')), kwargs.pop('group_id', FLAGS.kafka_group_id), kwargs.pop('servers', FLAGS.kafka_servers) ] assert all(x is not None for x in kafka_args) logging.info('use KafkaDataset!') return KafkaDataset(*kafka_args, **kwargs) except Exception as e: logging.error(str(e)) logging.info("it's not streaming training") tf_record_args = { 'file_name', 'compression_type', 'buffer_size', 
'num_parallel_reads' } def is_kafka_dataset(): # 'topics', 'group_id' and 'servers' are for KafkaDataset return 'topics' in kwargs and 'group_id' in kwargs and 'servers' in kwargs if args is None or len(args) == 0: # all arguments are in kwargs # 'patterns' for DistributedFilePBDataset if 'patterns' in kwargs and not is_kafka_dataset(): logging.info('use DistributedFilePBDataset!') return DistributedFilePBDataset(**kwargs) elif is_kafka_dataset(): logging.info('use KafkaDataset!') return KafkaDataset(**kwargs) elif kwargs.get('use_parquet'): return ParquetDataset(**kwargs) elif kwargs.get('use_tfrecord'): logging.info('use TFRecordDataset!') invalid_args = list(k for k in kwargs if k not in tf_record_args) for k in invalid_args: kwargs.pop(k) logging.info('---kwargs: %s', kwargs) return TFRecordDatasetWrapper(**kwargs) elif 'file_name' in kwargs or len(kwargs) == 0: return FilePBDataset(*args, **kwargs) else: return super(DatasetMetaclass, cls).__call__(*args, **kwargs) elif isinstance(args[0], str): # The first arg is a filename if kwargs.get('use_parquet'): return ParquetDataset(*args, **kwargs) elif kwargs.get('use_tfrecord'): logging.info('use TFRecordDataset!') invalid_args = list(k for k in kwargs if k not in tf_record_args) for k in invalid_args: kwargs.pop(k) logging.info('---kwargs: %s', kwargs) return TFRecordDatasetWrapper(*args, **kwargs) else: logging.info('use FilePBDataset!') return FilePBDataset(*args, **kwargs) elif isinstance(args[0], (list, tuple)): # The first arg is a list, never reach here if len(args) > 1: if isinstance(args[1], str): logging.info('use KafkaDataset!') return KafkaDataset(*args, **kwargs) else: logging.info('use DistributedFilePBDataset!') return DistributedFilePBDataset(*args, **kwargs) else: if 'group_id' in kwargs or 'servers' in kwargs: logging.info('use KafkaDataset!') return KafkaDataset(*args, **kwargs) else: logging.info('use DistributedFilePBDataset!') return DistributedFilePBDataset(*args, **kwargs) else: return 
super(DatasetMetaclass, cls).__call__(*args, **kwargs) class PBDataset(metaclass=DatasetMetaclass): def __init__( self, topics_or_files: Union[str, List[str]] = '', buffer_size_or_group_id: Union[int, str] = None, input_pb_type_or_servers: Union[PbType, str] = None, output_pb_type: PbType = None, feature_pruning_type: int = FeaturePruningType.PRUNING_RAW_FEATURE, disable_iterator_save_restore: bool = True, *, has_header=True, variant_type: str = None, stream_timeout=-1, message_poll_timeout=10000, poll_batch_size: int = None, filter_empty: bool = False, configuration=None, container: str = '', shared_name: str = '', cycle_length=None, block_length=None, num_parallel_calls=None, deterministic=None, **kwargs): pass @classmethod def gen_patterns(cls, input_path: str = None, start_date: int = None, start_hour: int = None, end_date: int = None, end_hour: int = None, is_hourly: bool = False, wildcard: str = '*') -> List[str]: input_path = input_path or _get_params('input_path', None) if not input_path: return [] start_date = start_date or _get_params('start_date', None) if not start_date: return [] end_date = end_date or _get_params('end_date', None) if not end_date: end_date = datetime.today().strftime('%Y%m%d') is_hourly = is_hourly if is_hourly is not None else _get_params( 'is_hourly', False) start_hour = start_hour or _get_params('start_hour', 0) or 0 end_hour = end_hour or _get_params('end_hour', 0) or 0 wildcard = wildcard or _get_params('wildcard', '*') start = datetime.strptime(f'{start_date}:{start_hour:02d}', '%Y%m%d:%H') if is_hourly: end = datetime.strptime(f'{end_date}:{end_hour:02d}', '%Y%m%d:%H') else: end = datetime.strptime(f'{end_date}:00', '%Y%m%d:%H') delta = timedelta(hours=1) if is_hourly else timedelta(days=1) cur = start patterns = [] while cur < end: if is_hourly: pat = f"{cur.strftime('%Y%m%d/%H')}{wildcard}" else: pat = os.path.join(cur.strftime('%Y%m%d'), wildcard) patterns.append(os.path.join(input_path, pat)) cur = cur + delta return 
patterns class DynamicMatchingFilesDataset(dataset_ops.DatasetSource): """A `Dataset` that list the files according to the input patterns.""" def __init__(self, patterns: List[str]): assert patterns is not None and len(patterns) > 0 self._patterns = ops.convert_to_tensor(patterns, dtype=dtypes.string, name="patterns") variant_tensor = pb_datasource_ops.dynamic_matching_files_dataset( self._patterns) super(DynamicMatchingFilesDataset, self).__init__(variant_tensor) @property def element_spec(self): return tensor_spec.TensorSpec([], dtypes.string) class TFRecordDatasetWrapper(tf.data.TFRecordDataset): def __init__(self, file_name, compression_type=None, buffer_size=None, num_parallel_reads=None, **kwargs): super().__init__(file_name, compression_type=compression_type, buffer_size=buffer_size, num_parallel_reads=num_parallel_reads) class ParquetDataset(dataset_ops.DatasetSource): def __init__(self, file_name, output_pb_type: PbType, select_columns: List[str], select_columns_type: List[str], batch_size=512, drop_remainder=True, **kwargs): # assert isinstance(file_name, str) assert output_pb_type in [ PbType.EXAMPLE, PbType.EXAMPLEBATCH, PbType.PLAINTEXT ] assert output_pb_type != 'example_batch' or (isinstance(batch_size, int) and batch_size > 0) batch_size = 0 if output_pb_type == 'example' else batch_size assert isinstance(select_columns, list) and all( isinstance(c, str) for c in select_columns) assert isinstance(select_columns_type, list) and all( t in ["int", "fid_v1", "fid_v2", "float"] for t in select_columns_type) for feature in select_columns: add_feature(feature) if output_pb_type == PbType.EXAMPLEBATCH and batch_size > 0 and drop_remainder: get_default_parser_ctx().set('batch_size', batch_size) self._out_type = tf.string if output_pb_type == PbType.PLAINTEXT else tf.variant tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY, value=output_pb_type.to_name()) variant_tensor = pb_datasource_ops.parquet_dataset( file_name=file_name, 
output_pb_type=output_pb_type.to_name(), batch_size=batch_size, select_columns=select_columns, select_columns_type=select_columns_type, drop_remainder=drop_remainder) super().__init__(variant_tensor) @property def element_spec(self): return tensor_spec.TensorSpec([], self._out_type) @monolith_export class CompressType(Enum): UNKNOW = 0 NO = 1 SNAPPY = 2 ZSTD = 3 ZLIB = 4 GZIP = 5 @monolith_export class FilePBDataset(dataset_ops.DatasetSource): """从标准输入/pb文件中读取序列化数据, 并将其反序列化存于TF的Variant类型中. 这样做的好处是可以直接对PB对象进行过滤与修改, 不用等到parse以后. Monolith提供了一系列工具操作Variant变量, 如filter_by_fids, filter_by_value, negative_sample等 另外, InstanceReweightDataset/NegativeGenDataset 这些DataSet也可以直接作用于Variant Args: file_name (:obj:`str`): 文件名, 如果为空, 则从stdin读取数据 buffer_size (:obj:`int`): 读取文件时缓存大小, 默认100MB input_pb_type (:obj:`str`): 输入pb类型, 可以是example/example_batch/instance output_pb_type (:obj:`str`): 输入pb类型, 可以是example/instance/plaintext Raises: TypeError: 如果有任何参数与类型不匹配, 则抛TypeError ValueError: 如果有任何值与期望不匹配, 则抛ValueError """ def __init__( self, file_name: str = "", buffer_size: int = None, input_pb_type: PbType = None, output_pb_type: PbType = None, feature_pruning_type: int = FeaturePruningType.PRUNING_RAW_FEATURE, disable_iterator_save_restore: bool = True, use_snappy: bool = None, compression_type: CompressType = CompressType.UNKNOW, **kwargs): input_pb_type = input_pb_type or _get_params('data_type', PbType.INSTANCE) output_pb_type = output_pb_type or (PbType.INSTANCE if input_pb_type == PbType.INSTANCE else PbType.EXAMPLE) feature_name_list = [] feature_id_list = [] if input_pb_type in [PbType.EXAMPLEBATCH, PbType.EXAMPLE]: try: feature_list = FeatureList.parse() for feature in feature_list: name, slot = feature.feature_name, feature.slot assert None not in [name, slot] feature_name_list.append(name) feature_id_list.append(slot) except Exception as e: logging.warning('Failed to parse feature_list.conf, %s', e) self._file_name = file_name self._buffer_size = buffer_size self._input_pb_type = 
input_pb_type self._output_pb_type = output_pb_type self._out_type = tf.string if output_pb_type == PbType.PLAINTEXT else tf.variant self._has_sort_id = kwargs.get('has_sort_id', _get_params('sort_id', True)) self._kafka_dump = kwargs.get('kafka_dump', _get_params('kafka_dump', False)) logging.info('input_pb_type: %s, kafka_dump: %s, output_pb_type: %s', self._input_pb_type, self._kafka_dump, self._output_pb_type) self._kafka_dump_prefix = kwargs.get( 'kafka_dump_prefix', _get_params('kafka_dump_prefix', False)) self._lagrangex_header = kwargs.get('lagrangex_header', _get_params('lagrangex_header', False)) if disable_iterator_save_restore and isinstance(file_name, str): # This is the special case that dataset uses stdin as the input. # In this case, we should diable the ckpt save/restore. if context.default_execution_mode == context.GRAPH_MODE: ckpt_hooks.disable_iterator_save_restore() default_buffer_size = 128 * 1024 * 1024 if input_pb_type == PbType.EXAMPLEBATCH else 64 * 1024 * 1024 logging.info( f"FilePBDataset input compression_type: {compression_type} {FLAGS.dataset_input_compression_type} {use_snappy} {FLAGS.dataset_input_use_snappy}" ) if compression_type == CompressType.UNKNOW and FLAGS.dataset_input_compression_type is not None: compression_type = CompressType[ FLAGS.dataset_input_compression_type.upper()] logging.info(f"FilePBDataset change compression_type {compression_type}") logging.info(f"FilePBDataset compression_type {compression_type}") use_snappy = use_snappy or FLAGS.dataset_input_use_snappy if use_snappy is None: if isinstance(file_name, str) and file_name.endswith('.snappy'): use_snappy = True logging.info(f"FilePBDataset change use_snappy {use_snappy}") if use_snappy is None: use_snappy = False tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY, value=output_pb_type.to_name()) variant_tensor = pb_datasource_ops.pb_dataset( file_name=file_name, use_snappy=use_snappy, buffer_size=buffer_size or default_buffer_size, 
input_pb_type=input_pb_type.to_name(), output_pb_type=output_pb_type.to_name(), has_sort_id=self._has_sort_id, kafka_dump=self._kafka_dump, kafka_dump_prefix=self._kafka_dump_prefix, lagrangex_header=self._lagrangex_header, feature_pruning_type=feature_pruning_type, feature_name_list=feature_name_list, feature_id_list=feature_id_list, out_type=self._out_type, compression_type=compression_type.value, ) logging.info("Start init of the pb instance dataset base.") super().__init__(variant_tensor) @property def element_spec(self): return tensor_spec.TensorSpec([], self._out_type) class DistributedFilePBDataset(dataset_ops.DatasetSource): def __init__( self, patterns: Union[str, List[str]], buffer_size: int = None, input_pb_type: PbType = None, output_pb_type: PbType = None, feature_pruning_type: int = FeaturePruningType.PRUNING_RAW_FEATURE, exclude_fn: Callable[[tf.Tensor], bool] = None, cycle_length=2, block_length=None, num_parallel_calls=tf.data.AUTOTUNE, deterministic=None, use_parquet: bool = False, use_tfrecord: bool = False, **kwargs): if not patterns: patterns = [""] elif isinstance(patterns, str): patterns = [patterns] else: logging.info( f'patterns: len {len(patterns)}, frist is {patterns[0]}, last is {patterns[-1]}' ) patterns.sort() enable_dynamic_sharding = kwargs.get( 'enable_dynamic_sharding', _get_params('enable_dynamic_sharding', False)) logging.info(f"enable_dynamic_sharding: {enable_dynamic_sharding}") assert not ( use_parquet and use_tfrecord ), "It's not allowed to specify use_parquet=True and use_tfrecord=True simultaneously!" 
if use_parquet: map_func = lambda file_name: ParquetDataset( file_name=file_name, output_pb_type=output_pb_type, **kwargs) elif use_tfrecord: map_func = lambda file_name: tf.data.TFRecordDataset(filenames= [file_name]) else: map_func = lambda file_name: FilePBDataset( file_name=file_name, buffer_size=buffer_size, input_pb_type=input_pb_type, output_pb_type=output_pb_type, feature_pruning_type=feature_pruning_type, disable_iterator_save_restore=not enable_dynamic_sharding, **kwargs) graph = tf.compat.v1.get_default_graph() if FLAGS.data_service_dispatcher and not hasattr(graph, 'dry_run'): files_list = DynamicMatchingFilesDataset(patterns) # files_list = tf.data.Dataset.from_tensor_slices(patterns) if exclude_fn is not None: files_list = files_list.filter(predicate=exclude_fn) dataset = files_list.interleave(map_func, cycle_length=cycle_length, block_length=block_length, num_parallel_calls=num_parallel_calls, deterministic=deterministic) elif enable_dynamic_sharding: files_list = distributed_dataset.create_dynamic_sharding_dataset(patterns) if exclude_fn is not None: files_list = files_list.filter(predicate=exclude_fn) dataset = files_list.flat_map(map_func) else: files_list = matching_files.MatchingFilesDataset(patterns) if exclude_fn is not None: files_list = files_list.filter(predicate=exclude_fn) ctx = native_task_context.get() if ctx is not None: if ctx.num_workers > 1: files_list = files_list.shard(ctx.num_workers, ctx.worker_index) else: shard_num = kwargs.get('shard_num', 1) shard_index = kwargs.get('shard_index', 0) if shard_num > 1: files_list = files_list.shard(shard_num, shard_index) cycle_length = kwargs.get('cycle_length', _get_params('max_task_num_per_worker', 4)) num_parallel_calls = kwargs.get('num_parallel_calls', _get_params('max_task_num_per_worker', 4)) block_length = kwargs.get('block_length', _get_params('block_length', 1)) dataset = files_list.interleave(map_func=map_func, cycle_length=cycle_length, block_length=block_length, 
num_parallel_calls=num_parallel_calls, deterministic=False) self._dataset = dataset super(DistributedFilePBDataset, self).__init__(variant_tensor=self._dataset._variant_tensor) @property def element_spec(self): return self._dataset.element_spec @monolith_export class InstanceReweightDataset(dataset_ops.UnaryUnchangedStructureDataset): """样本重加权, 并根据action给样本打标签, 使用方式为 dataset.instance_reweight 一个样本可能有多个action, 按`action_priority`, 找到最高优的action. 再用action找到对应的 `action:weight:label`, 让样本重复weight次(也有可能是0次, 即删除样本), 然后给样本打上label指定的标签 Args: input_dataset (:obj:`dataset`): 输入数据集 action_priority (:obj:`str`): action用int表示, 以逗号分隔的int数组, 排在前面的优先级高 reweight (:obj:`str`): 基本单元是`action:weight:label`, 可以用逗号分隔多个基本单元 1) action: 动作, 用int表示, 与业务相关, 如download, install, click, exposure等 2) weight: 权重, 用int表示, 表示样本重复的次数 3) label: 标签, 一般用1/-1表示. variant_type (:obj:`str`): 输入数据是variant类型的, 支持两种格式, instance/example Raises: TypeError: 如果有任何参数与类型不匹配, 则抛TypeError ValueError: 如果有任何值与期望不匹配, 则抛ValueError """ def __init__(self, input_dataset, action_priority: str = None, reweight: str = None, variant_type: str = 'example'): self._label_priority = action_priority self._reweight = reweight self._variant_type = variant_type actions, weights, labels = [], [], [] for item in reweight.strip().split(','): (action, weight, label) = item.strip().split(':') actions.append(int(action)) weights.append(int(weight)) labels.append(int(label)) priorities = [int(p) for p in action_priority.strip().split(',')] variant_tensor = pb_datasource_ops.instance_reweight_dataset( input=input_dataset._variant_tensor, method=0, actions=actions, weights=weights, labels=labels, priorities=priorities, variant_type=variant_type) logging.info("Start init of the pb instance dataset base.") super(InstanceReweightDataset, self).__init__(input_dataset, variant_tensor) @property def element_spec(self): return tensor_spec.TensorSpec([], dtypes.variant) @monolith_export class NegativeGenDataset(dataset_ops.UnaryUnchangedStructureDataset): 
"""负例生成. 有时, 样本中只有正例, 没有负例, 需要随机生成负例 推荐系统中的样本通常是由user侧, item侧两部分组成. 这里的做法是: - 先收集每个样本的item侧信息, 生成一个item池子 - item池子并不是平铺的, 而是按某个特征(channel_slot)分类组织的. 如果在同一个channel随机取item得到的是hard负例, 在其它channel中抽样得到的是easy负例 - 并不是一开始就生成负例, 而是要等item池子积累到一定大小才开始生成负例 Args: input_dataset (:obj:`dataset`): 输入数据集 neg_num (:obj:`int`): 为一个正例生成`neg_num`个负例 channel_feature (:obj:`string`): 用于当item分类的字段 per_channel (:obj:`bool`): 是否分类 start_num (:obj:`int`): 在item池子中积累多少个后才开始采样 max_iten_num (:obj:`int`): 每一个channel最多收集多注个item item_features: (:obj:`List[str]`): item侧的特征名列表 positive_label: 正例的label, 仅为正例生成负例 negative_label: 生成的负例的被打上的label easy_hard_ratio: (:obj:`float`): 当使用 per_channel 的时候, hard和easy负例之间的比例。取值在 0 ~ 1 之间。举例:0.8就是大致80% easy负例 Raises: TypeError: 如果有任何参数与类型不匹配, 则抛TypeError ValueError: 如果有任何值与期望不匹配, 则抛ValueError """ def __init__(self, input_dataset, neg_num: int, per_channel: bool = False, channel_feature: Union[int, str] = '', item_features: Union[List[int], List[str]] = [], start_num: int = 500, max_item_num: int = 100000, positive_label: int = 1, negative_label: int = -1, negative_action: int = -99999, positive_actions: List[int] = [], label_index: int = 0, action_priority: str = '', index_feature: Union[int, str] = '', throw_origin: bool = False, throw_origin_neg: bool = False, cache_only_pos: bool = True, cache_negative_actions: List[int] = [], real_neg_instance_weight: float = 1.0, sampled_neg_instance_weight: float = -1.0, unbias_sampled_neg: bool = True, origin_neg_in_pool_proba: float = 1.0, neg_sample_declay_factor: float = 1.0, easy_hard_ratio: float = 0.0, variant_type: str = 'example'): pool = create_item_pool(start_num=start_num, max_item_num_per_channel=max_item_num) tf.compat.v1.add_to_collection(POOL_KEY, pool) channel_feature = str(channel_feature) item_features = [str(item) for item in item_features] action_priority_items = action_priority.strip().split(',') assert len(action_priority_items) == len(set(action_priority_items)) index_feature = str(index_feature) 
assert variant_type in {'instance', 'example'} assert label_index >= 0 assert isinstance(cache_negative_actions, list) and \ all(isinstance(x, int) for x in cache_negative_actions) assert len(set(positive_actions) & set(cache_negative_actions)) == 0, \ "positive_actions and cache_negative_actions have intersection, pls check" variant_tensor = pb_datasource_ops.instance_negative_gen_dataset( input=input_dataset._variant_tensor, pool=pool, neg_num=neg_num, per_channel=per_channel, channel_feature=channel_feature, item_features=item_features, label_index=label_index, positive_label=positive_label, negative_label=negative_label, negative_action=negative_action, action_priority=action_priority, positive_actions=positive_actions, index_feature=index_feature, throw_origin=throw_origin, throw_origin_neg=throw_origin_neg, cache_only_pos=cache_only_pos, cache_negative_actions=cache_negative_actions, real_neg_instance_weight=real_neg_instance_weight, sampled_neg_instance_weight=sampled_neg_instance_weight, unbias_sampled_neg=unbias_sampled_neg, origin_neg_in_pool_proba=origin_neg_in_pool_proba, neg_sample_declay_factor=neg_sample_declay_factor, easy_hard_ratio=easy_hard_ratio, variant_type=variant_type) super(NegativeGenDataset, self).__init__(input_dataset, variant_tensor) @property def element_spec(self): return tensor_spec.TensorSpec([], dtypes.variant) def instance_reweight(self, action_priority: str, reweight: str, **kwargs): value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY) assert len(value) == 1 variant_type = value[0] assert variant_type in {"instance", "example"} return InstanceReweightDataset(self, action_priority, reweight, variant_type=variant_type) @monolith_export class CacheOneDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset): self._input_dataset = input_dataset variant_tensor = pb_datasource_ops.monolith_cache_one_dataset( input_dataset._variant_tensor) super().__init__(input_dataset, variant_tensor) @property def 
@monolith_export
class MergeFlowDataset(dataset_ops.DatasetV2):
  """Dataset that merges `input_dataset` with the datasets in `dataset_to_merge`.

  All datasets must share the same legacy output types and classes. The
  output shapes are the most specific shapes compatible with the input
  dataset and the datasets being merged.

  Args:
    input_dataset: the primary dataset; it is fed to the op as flow
      'input_ds'.
    dataset_to_merge: a non-empty sequence of datasets; the i-th one is fed
      to the op as flow 'ds_to_merge_{i+1}'.
    max_queue_size: forwarded to `merge_flow_dataset`.
    variant_type: forwarded to `merge_flow_dataset`.

  Raises:
    ValueError: if `dataset_to_merge` is empty.
    TypeError: if the datasets disagree on output types or classes.
  """

  def __init__(self,
               input_dataset,
               dataset_to_merge,
               max_queue_size: int = 1024,
               variant_type: str = 'example'):
    # Materialise once so tuples (or any other sequence) are accepted; the
    # list concatenation below would otherwise raise on non-list inputs.
    dataset_to_merge = list(dataset_to_merge)
    if not dataset_to_merge:
      # Previously this case reached nest.pack_sequence_as with a None
      # flat sequence and failed with an unrelated-looking error.
      raise ValueError("dataset_to_merge must contain at least one dataset")

    self._input_dataset = input_dataset
    self._dataset_to_merge = dataset_to_merge

    # 1) Every dataset must produce the same element types.
    output_types = dataset_ops.get_legacy_output_types(input_dataset)
    for ds in dataset_to_merge:
      ds_types = dataset_ops.get_legacy_output_types(ds)
      if output_types != ds_types:
        raise TypeError("Datasets to merge have different types %s and %s" %
                        (output_types, ds_types))

    # 2) Relax the input shapes against each merged dataset. The first
    #    dataset defines the merged shapes; later ones must agree with it.
    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    input_shapes_flatten = nest.flatten(input_shapes)
    flat_sequence = None
    for ds in dataset_to_merge:
      ds_shapes_flatten = nest.flatten(dataset_ops.get_legacy_output_shapes(ds))
      relaxed = [
          ts1.most_specific_compatible_shape(ts2)
          for (ts1, ts2) in zip(input_shapes_flatten, ds_shapes_flatten)
      ]
      if flat_sequence is None:
        flat_sequence = relaxed
      else:
        assert all(ts1 == ts2 for (ts1, ts2) in zip(flat_sequence, relaxed))
    output_shapes = nest.pack_sequence_as(input_shapes, flat_sequence)

    # 3) Every dataset must produce the same element classes.
    output_classes = dataset_ops.get_legacy_output_classes(input_dataset)
    for ds in dataset_to_merge:
      ds_classes = dataset_ops.get_legacy_output_classes(ds)
      if output_classes != ds_classes:
        raise TypeError("Datasets to merge have different classes %s and %s" %
                        (output_classes, ds_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    self._input_datasets = [input_dataset] + dataset_to_merge
    input_dataset_variant = [ds._variant_tensor for ds in self._input_datasets]
    data_flow = ['input_ds'] + [
        'ds_to_merge_{}'.format(i + 1)
        for i in range(len(self._dataset_to_merge))
    ]
    variant_tensor = pb_datasource_ops.merge_flow_dataset(
        input_dataset_variant,
        data_flow=data_flow,
        max_queue_size=max_queue_size,
        variant_type=variant_type)
    super(MergeFlowDataset, self).__init__(variant_tensor)

  def _inputs(self):
    """Return the input dataset followed by the merged datasets."""
    return self._input_datasets

  @property
  def element_spec(self):
    """Structure computed in `__init__` from types, shapes and classes."""
    return self._structure
class KafkaGen(object):
  """Iterator/generator that yields batches of Kafka message values.

  A `KafkaConsumer` and a background polling thread are created the first
  time the `consumer` property is read. The thread pushes lists of non-empty
  message values onto an internal queue; `__next__` pops them.

  Args:
    topics: a topic name or list of topic names.
    group_id: Kafka consumer group id.
    servers: bootstrap server(s) passed to `KafkaConsumer`.
    stream_timeout: milliseconds without data after which polling stops;
      -1 means never stop.
    message_poll_timeout: per-poll timeout in milliseconds.
    poll_batch_size: maximum number of records fetched by one poll.

  Raises:
    ValueError: if `stream_timeout` is smaller than -1.
  """

  def __init__(self,
               topics: List[str],
               group_id: str,
               servers: Union[str, List[str]],
               stream_timeout: int = -1,
               message_poll_timeout: int = 10000,
               poll_batch_size: int = 1024):
    if stream_timeout == -1:
      stream_timeout = sys.maxsize
    elif stream_timeout >= 0:
      # Never give up before at least one full poll has elapsed.
      stream_timeout = max(stream_timeout, message_poll_timeout)
    else:
      raise ValueError('stream_timeout must bigger then -1')

    if isinstance(topics, str):
      topics = [topics]
    self.topics, self.group_id, self.servers = topics, group_id, servers

    self._lock = RLock()
    self._stop_iteration = False  # guarded by self._lock
    self._consumer: KafkaConsumer = None  # guarded by self._lock
    self._queue = Queue(maxsize=1024)

    self.message_poll_timeout = message_poll_timeout
    self.poll_batch_size = poll_batch_size
    # Number of consecutive empty polls tolerated before the stream ends.
    self._max_stream_timeout_polls = int(stream_timeout / message_poll_timeout)
    self._stream_timeout_polls = -1

  @property
  def consumer(self):
    """Lazily create the consumer and start the polling thread once."""
    with self._lock:
      if self._consumer is None:
        self._consumer = KafkaConsumer(*self.topics,
                                       group_id=self.group_id,
                                       bootstrap_servers=self.servers)
        thread = Thread(target=self._poll)
        thread.start()
      return self._consumer

  def __iter__(self):
    return self

  def __next__(self):
    """Return the next batch of values, or stop once polling has ended.

    Raises:
      StopIteration: when the poller has stopped and the queue is drained.
    """
    # Local import: only this method needs the exception type.
    from queue import Empty

    assert self.consumer is not None
    # Queue.get() expects seconds while message_poll_timeout is expressed in
    # milliseconds (it is passed as timeout_ms to KafkaConsumer.poll).
    wait_seconds = self.message_poll_timeout / 1000.0
    while True:
      try:
        data = self._queue.get(timeout=wait_seconds)
      except Empty:
        # A timeout is not an error: fall through to the stop check instead
        # of leaking queue.Empty to the caller.
        data = None
      if data:
        return data
      with self._lock:
        # Only stop when nothing is left; the poller may have enqueued a
        # final batch right before setting the flag.
        if self._stop_iteration and self._queue.empty():
          raise StopIteration

  def __call__(self):
    # Lets the instance be used where a generator *factory* is expected.
    return self

  def _poll(self):
    """Background loop: poll Kafka and enqueue non-empty value batches."""
    while self._stream_timeout_polls < self._max_stream_timeout_polls:
      try:
        msg = self._consumer.poll(timeout_ms=self.message_poll_timeout,
                                  max_records=self.poll_batch_size,
                                  update_offsets=True)
        poll_values = []
        if msg:
          for _, records in msg.items():
            poll_values.extend(
                record.value for record in records if record.value)
        if poll_values:
          self._stream_timeout_polls = 0
          self._queue.put(poll_values)
        else:
          # Either nothing was returned or every value was empty.
          self._stream_timeout_polls += 1
      except Exception as e:
        logging.error(f'poll error: {e}')
        break

    with self._lock:
      self._consumer.close()
      self._stop_iteration = True
def create_plain_kafka_dataset(topics: List[str],
                               group_id: str,
                               servers: str,
                               stream_timeout=-1,
                               message_poll_timeout=10000,
                               poll_batch_size: int = 1024,
                               configuration=None,
                               container: str = '',
                               shared_name: str = '',
                               kafka_other_metadata: str = None):
  """Build a dataset of raw `kafka_read_next` results.

  Args:
    topics: topics passed to `kafka_resource_init`.
    group_id: if not None, added to the metadata as `group.id=...`.
    servers: if not None, added to the metadata as `bootstrap.servers=...`.
    stream_timeout: forwarded as-is to `kafka_read_next`.
    message_poll_timeout: forwarded as-is to `kafka_read_next`.
    poll_batch_size: if not None, must be a positive int; added to the
      metadata as `batch.num.messages=...`.
    configuration: optional iterable of extra `key=value` metadata entries,
      placed first.
    container: forwarded to `kafka_resource_init`.
    shared_name: forwarded to `kafka_resource_init`.
    kafka_other_metadata: optional comma-separated `key=value` entries,
      appended last.

  Returns:
    A dataset that reads with an increasing index and stops at the first
    element whose `continue_fetch` is not greater than 0.
  """
  # NOTE(review): unlike KafkaDataset/KafkaGen, stream_timeout == -1 is not
  # mapped to sys.maxsize here -- confirm the op handles -1 itself.
  kafka_config = list(configuration or [])
  if group_id is not None:
    kafka_config.append(f"group.id={group_id}")
  if servers is not None:
    kafka_config.append(f"bootstrap.servers={servers}")
  if poll_batch_size is not None:
    assert isinstance(poll_batch_size, int) and poll_batch_size > 0
    kafka_config.append(f"batch.num.messages={poll_batch_size}")
  if kafka_other_metadata:
    kafka_config.extend(kafka_other_metadata.split(','))

  kafka_resource = kafka_resource_init(topics=topics,
                                       metadata=kafka_config,
                                       container=container,
                                       shared_name=shared_name)

  def _read_next(index):
    return kafka_read_next(
        input=kafka_resource,
        index=index,
        message_poll_timeout=message_poll_timeout,
        stream_timeout=stream_timeout,
    )

  def _keep_fetching(record):
    return tf.greater(record.continue_fetch, 0)

  reads = tf.data.experimental.Counter().map(_read_next)
  return reads.apply(tf.data.experimental.take_while(_keep_fetching))
) metadata = list(configuration or []) if group_id is not None: metadata.append(f"group.id={group_id}") if servers is not None: metadata.append(f"bootstrap.servers={servers}") if poll_batch_size is None: if variant_type == "examplebatch": poll_batch_size = 16 else: poll_batch_size = 128 if poll_batch_size is not None: assert isinstance(poll_batch_size, int) and poll_batch_size > 0 metadata.append(f"batch.num.messages={poll_batch_size}") if kafka_other_metadata: kafka_other_metadata_list = kafka_other_metadata.split(',') for meta in kafka_other_metadata_list: metadata.append(meta) tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY, value=output_pb_type) resource = kafka_resource_init( topics=topics, metadata=metadata, input_pb_type=variant_type, #"", step 1 output_pb_type=output_pb_type, #"", step 2 has_sort_id=self._has_sort_id, kafka_dump=self._kafka_dump, kafka_dump_prefix=self._kafka_dump_prefix, lagrangex_header=self._lagrangex_header, container=container, shared_name=shared_name) self._resource = resource dataset = tf.data.experimental.Counter() dataset = dataset.map( lambda i: kafka_read_next_v2( #kafka_read_next step 3 input=self._resource, index=i, message_poll_timeout=message_poll_timeout, stream_timeout=stream_timeout, )) dataset = dataset.apply( tf.data.experimental.take_while( lambda v: tf.greater(v.continue_fetch, 0))) ''' dataset = dataset.map(lambda v: string_to_variant( v.message, variant_type=variant_type.lower(), has_header=has_header, lagrangex_header=self._lagrangex_header, has_sort_id=self._has_sort_id, kafka_dump=self._kafka_dump, kafka_dump_prefix=self._kafka_dump_prefix, chnids=self._chnids, datasources=self._datasources, default_datasource=self._default_datasource), num_parallel_calls=tf.data.AUTOTUNE) ''' ''' # step 4 dataset = dataset.flat_map(lambda v: tf.data.Dataset.from_tensors( string_to_variant_with_transform( v.message, input_type=variant_type.lower(), output_type=output_pb_type, has_header=has_header, 
def register_dataset(service, dataset, buffer_size=32):
  """Compress `dataset` and register it at `service`; return the dataset id.

  Args:
    service: service address string, split by `_parse_service` into
      protocol and address.
    dataset: the dataset to register.
    buffer_size: currently unused (the prefetch step that used it is
      disabled).

  Returns:
    The result of `gen_experimental_dataset_ops.register_dataset`.
  """
  protocol, address = _parse_service(service)

  policy = dataset.options().experimental_external_state_policy
  if policy is None:
    policy = ExternalStatePolicy.WARN
  logging.info('external_state_policy: %s', policy)

  # Sequential map; the AUTOTUNE variant is disabled in this code.
  compressed = dataset.map(lambda *x: compression_ops.compress(x),
                           num_parallel_calls=None)
  logging.info('num_parallel_calls: None')

  compressed = compressed._apply_options()
  return gen_experimental_dataset_ops.register_dataset(
      compressed._variant_tensor,
      address=address,
      protocol=protocol,
      external_state_policy=policy.value)
def merged_window(self: tf.data.Dataset,
                  size: int = 2,
                  drop_remainder: bool = True):
  """Group `size` consecutive elements and merge their two leading dims.

  Each window is batched to `size` elements and every dense component is
  reshaped so that its first two dimensions are folded into one; for ragged
  components `.values` is returned instead.

  Args:
    size: number of consecutive elements per window.
    drop_remainder: forwarded to both `window` and `batch`.

  Returns:
    A dataset with the same structure (single component, tuple/list or
    dict) as `self`.

  Raises:
    Exception: if `element_spec` is not a tensor spec, tuple/list or dict.
  """
  dataset = self.window(size=size, drop_remainder=drop_remainder)

  def re_shape(ts: Union[tf.Tensor, tf.RaggedTensor]):
    if not isinstance(ts, tf.Tensor):
      return ts.values

    rank = ts.shape.rank
    if rank is None or rank < 2:
      # Nothing to fold: unknown rank, scalar, or a plain vector. The
      # rank-1 case used to raise IndexError.
      return ts

    dims = ts.shape.as_list()
    tail = dims[2:]
    if any(d is None for d in tail):
      # tf.reshape rejects None inside a static shape list, so take the
      # trailing dimensions from the runtime shape.
      target = tf.concat([[-1], tf.shape(ts)[2:]], axis=0)
    else:
      lead = -1 if dims[0] is None or dims[1] is None else dims[0] * dims[1]
      target = [lead] + tail
    return tf.reshape(ts, shape=target)

  def merge(component):
    return component.batch(size,
                           drop_remainder=drop_remainder).map(map_func=re_shape)

  element_spec = self.element_spec
  # element_spec holds *specs*, so a ragged component shows up as
  # RaggedTensorSpec; tf.RaggedTensor is kept for backward compatibility.
  if isinstance(element_spec,
                (tf.TensorSpec, tf.RaggedTensorSpec, tf.RaggedTensor)):
    return dataset.flat_map(map_func=lambda window: merge(window))
  elif isinstance(element_spec, (tuple, list)):
    return dataset.flat_map(map_func=lambda *window: tf.data.Dataset.zip(
        tuple(merge(value) for value in window)))
  elif isinstance(element_spec, dict):
    return dataset.flat_map(map_func=lambda window: tf.data.Dataset.zip(
        {key: merge(value) for key, value in window.items()}))
  else:
    raise Exception(f"element_spec {element_spec} is not support!")
FLAGS.dataset_num_workers if target is None: target = FLAGS.data_service_dispatcher assert worker_idx is not None and num_worker is not None and target is not None if max_outstanding_requests is None: max_outstanding_requests = min(num_worker, 8) if FLAGS.is_local: dataset_id = register_dataset(target, self) dataset = dsvc.from_dataset_id( processing_mode="distributed_epoch", service=target, dataset_id=dataset_id, job_name=job_name, element_spec=self.element_spec, max_outstanding_requests=max_outstanding_requests) return dataset elif num_worker is None or num_worker <= 0: logging.warning(f'num_worker is {num_worker}, error') return self elif worker_idx is None or worker_idx < 0: logging.warning(f'worker_idx is {worker_idx}, error') return self try: if FLAGS.kafka_topics is not None and FLAGS.kafka_group_id is not None: return self except Exception as e: pass logging.info( f'dataset.distribute worker_idx {worker_idx}, num_worker {num_worker}, target {target}' ) tf_config = os.environ.get('TF_CONFIG') if tf_config is not None: tf_config = json.loads(tf_config) roles = set(map(lambda x: x.lower(), tf_config['cluster'])) if queue_device is None: if 'ps' in roles: queue_device = "/job:ps/task:0/device:CPU:0" elif 'worker' in roles: queue_device = "/job:worker/task:0/device:CPU:0" else: raise Exception('role error') element_spec = self.element_spec if enable_sync_training(): has_error = False try: enable_bps = int(os.getenv("MONOLITH_WITH_BYTEPS", "0")) if enable_bps: import byteps.tensorflow as hvd else: import horovod.tensorflow as hvd except (ImportError, tf.errors.NotFoundError) as e: logging.info(f'ImportError is {e}') has_error = True if has_error: dataset_id = register_dataset(target, self) else: dataset_id = tf.constant(value=1000, dtype=tf.int64, shape=tuple(), name='default_dataset_id') if hvd.rank() == 0: tf.compat.v1.add_to_collection(name="registed_dataset_id", value=register_dataset(target, self)) else: 
tf.compat.v1.add_to_collection(name="registed_dataset_id", value=dataset_id) dataset = dsvc.from_dataset_id( processing_mode="distributed_epoch", service=target, dataset_id=dataset_id, job_name=job_name, element_spec=element_spec, max_outstanding_requests=max_outstanding_requests) if window_size is not None: dataset = dataset.merged_window(size=window_size) elif tf_config is not None and 'ps' in map(lambda x: x.lower(), tf_config['cluster']): logging.info('PS/Worker mode, use queue to broadcast dataset_id') with tf.compat.v1.device(queue_device): queue = tf.compat.v1.FIFOQueue(capacity=num_worker, dtypes=[tf.int64], shared_name=f'{job_name}_queue', shapes=tuple()) if worker_idx == 0: # data service try to register dataset, if the dataset has been registed, return dataset_id drectily # that means get or register dataset. for data parallel, the data pipeline assure to be identity # here we ues queue to ensure the same data pipeline for a job dataset_id = register_dataset(target, self) stacked_dids = tf.stack(values=[dataset_id for _ in range(num_worker)], name='stacked_dids') enqueue_op = queue.enqueue_many(vals=stacked_dids) with tf.compat.v1.control_dependencies(control_inputs=[enqueue_op]): # to share pipeline, job_name must be specified dataset = dsvc.from_dataset_id( processing_mode="distributed_epoch", service=target, dataset_id=dataset_id, job_name=job_name, element_spec=element_spec, max_outstanding_requests=max_outstanding_requests) else: dataset_id = queue.dequeue() dataset = dsvc.from_dataset_id( processing_mode="distributed_epoch", service=target, dataset_id=dataset_id, job_name=job_name, element_spec=element_spec, max_outstanding_requests=max_outstanding_requests) if window_size is not None: dataset = dataset.merged_window(size=window_size) else: logging.info(f'enable_sync_training is {enable_sync_training()}') return self return dataset def transform(self, t: Transform, **kwargs): value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY) assert 
@monolith_export
class TransformDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """Sample filtering / rewriting.

  Args:
    input_dataset (:obj:`dataset`): the input dataset.
    transform (:obj:`Transform`): the rewrite to apply; it is serialized
      with `as_proto().SerializeToString()` and handed to the op as `config`.
    variant_type (:obj:`str`): format of the variant input; two formats are
      supported, instance/example.

  Raises:
    AssertionError: if `variant_type` is neither "instance" nor "example".
    TypeError: listed by the original documentation for mismatched argument
      types (presumably raised by the underlying op -- not visible here).
    ValueError: listed by the original documentation for unexpected values
      (presumably raised by the underlying op -- not visible here).
  """

  def __init__(self, input_dataset, transform: Transform, variant_type: str):
    assert variant_type in {"instance", "example"}
    self._transform = transform

    serialized_config = transform.as_proto().SerializeToString()
    variant_tensor = pb_datasource_ops.transform_dataset(
        input=input_dataset._variant_tensor,
        config=serialized_config,
        variant_type=variant_type)
    logging.info("Start init of the pb instance dataset base.")
    super(TransformDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    """Each element is a scalar variant tensor."""
    return tensor_spec.TensorSpec([], dtypes.variant)
def parse_inst_exam(tensor: tf.Tensor, out_type):
  """Parse a batch as instances or examples, depending on `out_type`.

  Args:
    tensor: the batched variant tensor to parse.
    out_type: `PbType.INSTANCE` selects `parse_instances`; anything else
      selects `parse_examples` with `fc_slot_<slot>` sparse feature names.

  Returns:
    Whatever the selected parser returns.
  """
  fidv1_features = [
      1, 2, 32, 33, 36, 38, 42, 50, 54, 56, 60, 66, 120, 150, 180, 182, 192,
      220, 333, 410, 412, 422, 446
  ]
  # Arguments shared by both parsers.
  shared_kwargs = dict(dense_features=['label'],
                       dense_feature_shapes=[2],
                       dense_feature_types=[tf.float32],
                       extra_features=['uid', 'req_time', 'item_id'],
                       extra_feature_shapes=[1, 1, 1])

  if out_type == PbType.INSTANCE:
    return parse_instances(tensor, fidv1_features, **shared_kwargs)

  sparse_names = [f'fc_slot_{slot}' for slot in fidv1_features]
  return parse_examples(tensor, sparse_features=sparse_names, **shared_kwargs)
parse_eb(tensor, output_pb_type) dataset = PBDataset(file_name=file_name, lagrangex_header=lagrangex_header, has_sort_id=has_sort_id, kafka_dump=kafka_dump, kafka_dump_prefix=kafka_dump_prefix, input_pb_type=input_pb_type, output_pb_type=output_pb_type) if input_pb_type == PbType.EXAMPLEBATCH: variant_type = 'instance' if output_pb_type == PbType.INSTANCE else 'example' dataset = dataset.instance_reweight( action_priority="2,7,0,1,3,4,5,6,8,9,10,11", reweight= "0:0:1,1:0:1,2:3:-1,3:0:1,4:0:1,5:0:1,6:0:1,7:6:1,8:0:1,9:0:1,10:0:1,11:0:-1", variant_type=variant_type) if filter_fn is not None: dataset = dataset.filter(filter_fn) dataset = dataset.batch(8, drop_remainder=True).map(parser) for feature in dataset.take(5): self.assertIn(len(feature), {26, 27}) def testExampleBatch2Instance(self): self.target(PbType.EXAMPLEBATCH, PbType.INSTANCE) def testExample2Instance(self): self.target(PbType.EXAMPLE, PbType.INSTANCE) def testInstance2Instance(self): self.target(PbType.INSTANCE, PbType.INSTANCE) def testExampleBatch(self): lagrangex_header = True has_sort_id, kafka_dump, kafka_dump_prefix = False, False, False file_name = "monolith/native_training/data/training_instance/examplebatch.data" input_pb_type, output_pb_type = PbType.EXAMPLEBATCH, PbType.EXAMPLEBATCH dataset = PBDataset(file_name=file_name, lagrangex_header=lagrangex_header, has_sort_id=has_sort_id, kafka_dump=kafka_dump, kafka_dump_prefix=kafka_dump_prefix, input_pb_type=input_pb_type, output_pb_type=output_pb_type) def parser(tensor): freatues = parse_example_batch( tensor, sparse_features=list(features.keys()), dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'item_id'], extra_feature_shapes=[1, 1, 1]) return freatues dataset = dataset.map(parser) for feature in dataset.take(5): self.assertIn(len(feature), {26, 27}) if __name__ == '__main__': tf.test.main() ================================================ FILE: 
class ExtraFidTest(tf.test.TestCase):
  """Checks the `extract_fid` op against a known value."""

  def test_parse_search(self):
    """`extract_fid(185, 4)` must return the expected feature id."""
    fid = ragged_data_ops.extract_fid(185, 4).numpy()
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(fid, 1153447759131936)
def new_instance(cls, args):
  """Instantiate `cls` from the subset of `args` that its `__init__` accepts.

  Args:
    cls: class to instantiate.
    args: mapping of candidate keyword arguments; keys that are not
      parameters of `cls.__init__` are ignored.

  Returns:
    A new `cls` instance.
  """
  accepted = inspect.signature(cls.__init__).parameters
  kwargs = {
      name: args[name]
      for name in accepted
      if name not in {'cls', 'self'} and name in args
  }
  return cls(**kwargs)


@dataclass
class Feed:
  """A `feed` entry of the feature list file."""
  feed_name: str = None
  shared: bool = None
  feature_id: int = None

  def __post_init__(self):
    # Values parsed from the config file arrive as strings; values passed
    # programmatically may already be bool. Only strings go through the
    # textual truthiness table, mirroring what Feature.__post_init__ does.
    if isinstance(self.shared, str):
      self.shared = self.shared.lower() in _BOOL_FLAGS
    else:
      self.shared = bool(self.shared)  # None -> False
    if isinstance(self.feature_id, str):
      self.feature_id = int(self.feature_id)

  @property
  def name(self):
    """Key under which this feed is registered (its `feed_name`)."""
    return self.feed_name


@dataclass
class Cache:
  """A `cache` / `cache_column` entry of the feature list file."""
  cache_column: str = None
  cache_name: str = None
  capacity: int = None
  timeout: int = None
  cache_type: str = None
  cache_key_class: str = None

  def __post_init__(self):
    # Numeric fields come from the config file as text.
    if isinstance(self.capacity, str):
      self.capacity = int(self.capacity)
    if isinstance(self.timeout, str):
      self.timeout = int(self.timeout)

  @property
  def name(self):
    """Registration key: `cache_name`, else `cache_key_class`, else a constant.

    Raises:
      Exception: if none of the three naming fields is set.
    """
    if self.cache_name is not None:
      return self.cache_name
    if self.cache_key_class is not None:
      return self.cache_key_class
    if self.cache_column is not None:
      # NOTE(review): this returns the literal 'cache_column', not the column
      # value itself - presumably intentional (single well-known key); confirm.
      return 'cache_column'
    raise Exception('no name for cache')
@dataclass
class Feature:
  """A feature entry of the feature list file.

  All fields may be given as raw strings (as read from the config file);
  `__post_init__` converts them to the annotated types.
  """
  feature_name: str = None
  depend: List[str] = None
  method: str = None
  slot: int = None
  args: List[str] = None
  feature_version: int = None
  shared: bool = False
  cache_keys: List[str] = None
  need_raw: bool = False
  feature_id: int = None
  input_optional: List[bool] = None
  feature_group: List[str] = None

  def __post_init__(self):

    def split_csv(text):
      # "a, 'b', \"c\"" -> ['a', 'b', 'c']
      return [
          item.strip().replace('"', '').replace("'", '')
          for item in text.strip().split(',')
      ]

    for attr in ('feature_group', 'depend', 'args', 'cache_keys'):
      raw = getattr(self, attr)
      if isinstance(raw, str):
        setattr(self, attr, split_csv(raw))
    if isinstance(self.input_optional, str):
      # Only the exact token 'true' counts as True here.
      self.input_optional = [
          flag == 'true' for flag in split_csv(self.input_optional)
      ]
    for attr in ('slot', 'feature_id', 'feature_version'):
      raw = getattr(self, attr)
      if isinstance(raw, str):
        setattr(self, attr, int(raw))
    for attr in ('shared', 'need_raw'):
      raw = getattr(self, attr)
      if isinstance(raw, str):
        setattr(self, attr, raw.lower() in _BOOL_FLAGS)

  def __str__(self):
    """Serialise the set fields as `key=value` pairs joined by ';'.

    Fields that are None, and bool fields that are False, are omitted.

    Raises:
      ValueError: if a field has an annotation this method cannot format.
    """
    terms = []
    for name, clz in get_type_hints(Feature).items():
      value = getattr(self, name)
      if value is None:
        continue
      if clz == str or clz == int:
        terms.append("{}={}".format(name, value))
      elif clz == bool:
        if value:
          terms.append("{}=true".format(name))
      elif getattr(clz, '_name', None) == 'List' and len(clz.__args__) == 1:
        if clz.__args__[0] == str:
          terms.append("{}={}".format(name, ','.join(value)))
        elif clz.__args__[0] == bool:
          format_value = [str(b).lower() for b in value]
          terms.append("{}={}".format(name, ','.join(format_value)))
        else:
          raise ValueError("Type Error")
      else:
        raise ValueError("Type Error")
    return ';'.join(terms)

  @staticmethod
  def _strip_prefixes(raw_name):
    """Drop a leading 'fc_' / 'f_' from every '-'-separated term, lowercased."""
    terms = []
    for term in raw_name.split('-'):
      if term.startswith('fc_'):
        term = term[3:]
      elif term.startswith('f_'):
        term = term[2:]
      terms.append(term)
    return '-'.join(terms).lower()

  @property
  def name(self):
    """`feature_name` with per-term 'fc_' / 'f_' prefixes removed, lowercased.

    The prefix test is made on each term. Testing the whole feature name
    instead would strip two characters from unprefixed terms and miss 'f_'
    terms whenever the full name starts with 'fc_'.
    """
    return self._strip_prefixes(self.feature_name)

  @property
  def depend_strip_prefix(self):
    """The `depend` names with prefixes removed; empty list if none are set."""
    return [self._strip_prefixes(dep) for dep in (self.depend or [])]
class FeatureList(object):
  """Parsed feature list file: column names, feeds, caches and features.

  Features can be looked up by slot (int) or by name (str); name lookup also
  tries the 'f_' / 'fc_' prefixed spellings.
  """
  # Guards parse() so the module-level cache is filled once per file name.
  _lock = threading.Lock()

  def __init__(self, column_name: Optional[Set[str]], feeds: Dict[str, Feed],
               caches: Dict[str, Cache], features: Dict[str, Feature]):
    self.column_name = column_name
    self.feeds = feeds
    self.caches = caches
    self.features = features
    # slot -> features sharing that slot, in insertion order.
    self.__slots = defaultdict(list)
    for feat in features.values():
      self.__slots[feat.slot].append(feat)
    add_to_collections('feature_list', self)

  def __getitem__(self, item) -> Feature:
    """Look a feature up by slot or by (possibly un-prefixed) name.

    Raises:
      Exception: if no feature matches.
    """
    if isinstance(item, int):
      if item in self.__slots:
        return self.__slots[item][0]
      raise Exception('there is no feature {}'.format(item))

    assert isinstance(item, str)
    item = item.strip()
    for candidate in (item, f'f_{item}', f'fc_{item}'):
      if candidate in self.features:
        return self.features[candidate]
    if '-' in item:
      # Combined feature: try prefixing every term, 'fc_' first.
      terms = item.split('-')
      for prefix in ('fc_', 'f_'):
        candidate = '-'.join([f'{prefix}{term}' for term in terms])
        if candidate in self.features:
          return self.features[candidate]
    raise Exception('there is no feature {}'.format(item))

  def get(self, item, default=None):
    """Like `self[item]` but returns `default` when the lookup fails."""
    try:
      return self.__getitem__(item)
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit out
      return default

  def get_with_slot(self, slot):
    """All features registered for `slot` (empty list if none)."""
    if slot in self.__slots:
      return self.__slots[slot]
    return []

  def __len__(self):
    return len(self.features)

  def __contains__(self, item):
    return (item in self.features or f'f_{item}' in self.features or
            f'fc_{item}' in self.features or item in self.__slots)

  def __iter__(self):
    return iter(self.features.values())

  @classmethod
  def parse(cls, fname: str = None, use_old_name: bool = True) -> 'FeatureList':
    """Parse a feature list file, caching the result per file name.

    Args:
      fname: path of the file; defaults to `FLAGS.feature_list`.
      use_old_name: key features by `feature_name` when True, otherwise by
        the prefix-stripped `Feature.name`.

    Returns:
      The (possibly cached) FeatureList.
    """
    fname = fname or FLAGS.feature_list
    assert fname is not None
    with cls._lock:
      if fname in _cache:
        return _cache[fname]

      column_name = None
      feeds, caches, features = {}, {}, {}
      with open(fname) as stream:
        for line in stream:
          line = line.strip()
          if len(line) == 0 or line.startswith("#"):
            continue
          if line.startswith('column_name'):
            start = len('column_name:')
            column_name = {item.strip() for item in line[start:].split(',')}
            continue
          if line.startswith('cache_column'):
            cache = Cache(cache_column=line[len('cache_column:'):].strip())
            caches[cache.name] = cache
            continue

          try:
            # A line is "k1=v1 k2=v2 ...": after splitting on '=', every
            # middle chunk holds "<value of previous key> <next key>".
            params = {}
            items = line.split('=')
            for i in range(len(items) - 1):
              if i == 0:
                key = items[i].strip()
              else:
                start = items[i].rindex(" ")
                key = items[i][start:].strip()
              if i == len(items) - 2:
                value = items[i + 1]
              else:
                end = items[i + 1].rindex(" ")
                value = items[i + 1][0:end]
              params[key] = value.strip().rstrip(',').rstrip(';').rstrip()

            if line.startswith('feed'):
              feed = new_instance(Feed, params)
              feeds[feed.name] = feed
            elif line.startswith('cache'):
              cache = new_instance(Cache, params)
              caches[cache.name] = cache
            else:
              feat = new_instance(Feature, params)
              if use_old_name:
                features[feat.feature_name] = feat
              else:
                features[feat.name] = feat
          except Exception:
            # Report the offending line, then re-raise with the original
            # traceback intact.
            logging.error('Failed to parse feature list line: %s', line)
            raise

      feat_list = cls(column_name, feeds, caches, features)
      _cache[fname] = feat_list
      return feat_list
def get_feature_name_and_slot(item) -> Tuple[str, Optional[int]]:
  """Resolve `item` (slot, feature name or feature column) to (name, slot).

  The feature list is consulted first; if it cannot be loaded or does not
  know `item`, the slot/name naming helpers are used as a fallback.
  """
  if isinstance(item, int):
    try:
      feature_list = FeatureList.parse()
      return feature_list.get(item).feature_name, item
    except Exception:  # best effort: no feature list, or slot unknown
      return get_slot_feature_name(item), item
  elif isinstance(item, str):
    try:
      feature_list = FeatureList.parse()
      assert item in feature_list
      return item, feature_list[item].slot
    except Exception:  # best effort: no feature list, or name unknown
      return item, get_slot_from_feature_name(item)
  else:
    # for FeatureColumn
    assert hasattr(item, 'feature_name') and hasattr(item, 'feature_slot')
    return item.feature_name, item.feature_slot


# Names of the features registered through add_feature().
_VALID_FNAMES = set()


def is_example_batch():
  """True when FLAGS.data_type selects the example-batch format.

  Only example batch needs column pruning; this function exists for that.
  """
  data_type = getattr(FLAGS, 'data_type', None)
  return bool(data_type) and data_type.lower() in {
      'example_batch', 'examplebatch'
  }


def add_feature(feature: Union[str, int, List[str], List[int]]):
  """Register one feature, or a list/tuple of features, as in use.

  Strings are taken as feature names; ints are slots and are converted with
  `get_slot_feature_name`.
  """
  if not isinstance(feature, (list, tuple)):
    feature = [feature]
  for element in feature:
    if isinstance(element, str):
      _VALID_FNAMES.add(element)
    else:
      assert isinstance(element, int)
      _VALID_FNAMES.add(get_slot_feature_name(element))


def add_feature_by_fids(fids: Union[int, List[int]],
                        feature_list: Optional['FeatureList'] = None):
  """Register the features that the given FIDs belong to.

  Does nothing unless the data type is example batch.

  Args:
    fids: one FID or a list of FIDs, as Python ints or numpy integers
      (possibly negative, i.e. the int64 view of an unsigned 64-bit FID).
    feature_list: feature list to search; parsed from FLAGS when omitted.

  Raises:
    Exception: if the feature list is empty/unavailable, or a FID matches
      no feature.
  """
  if not is_example_batch():
    return

  if isinstance(fids, (int, np.integer)):
    fids = [fids]
  if feature_list is None:
    # for example_batch, there is a feature_list.conf
    feature_list = FeatureList.parse()
  if not feature_list:
    raise Exception('Cannot create feature_list')

  for fid in fids:
    assert isinstance(fid, (int, np.integer))
    # Reinterpret as unsigned 64-bit. Masking an np.int64 in place cannot do
    # this (the int64 form of the mask is -1), so a FID with the top bit set
    # would stay negative and shift to a negative, never-matching slot.
    fid = int(fid) & FID_MASK
    find_feature = False
    # v1 layout: slot is whatever remains after dropping the low 54 bits.
    for feature in feature_list.get_with_slot(fid >> 54):
      if feature.feature_version is None or feature.feature_version == 1:
        add_feature(feature.feature_name)
        find_feature = True
    # v2 layout: slot is whatever remains after dropping the low 48 bits.
    for feature in feature_list.get_with_slot(fid >> 48):
      if feature.feature_version == 2:
        add_feature(feature.feature_name)
        find_feature = True
    if not find_feature:
      raise Exception(f'Cannot find feature name for fid: {fid}')


def get_valid_features() -> List[str]:
  """Snapshot of the feature names registered so far (order unspecified)."""
  return list(_VALID_FNAMES)
@monolith_export
def filter_by_fids(variant: tf.Tensor,
                   filter_fids: List[int] = None,
                   has_fids: List[int] = None,
                   select_fids: List[int] = None,
                   has_actions: List[int] = None,
                   req_time_min: int = 0,
                   select_slots: List[int] = None,
                   variant_type: str = 'instance'):
  """Filter samples by feature ID (FID), i.e. on sparse features.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    filter_fids (:obj:`List[int]`): a sample is dropped if any of its FIDs
      appears in `filter_fids`
    has_fids (:obj:`List[int]`): a sample is selected if any of its FIDs
      appears in `has_fids`
    select_fids (:obj:`List[int]`): a sample is selected only if all
      `select_fids` appear in it
    has_actions (:obj:`List[int]`): a sample is selected if any of its
      actions appears in `has_actions`
    req_time_min (:obj:`int`): minimum request time
    select_slots (:obj:`List[int]`): a sample is selected only if all
      `select_slots` appear in it
    variant_type (:obj:`str`): variant type, instance/example

  Returns:
    variant tensor, the filtered data
  """

  def as_int64(fids):
    # The op takes signed 64-bit values; reinterpret unsigned FIDs.
    if fids is None:
      return []
    return [np.uint64(fid).astype(np.int64) for fid in fids]

  filter_fids = as_int64(filter_fids)
  has_fids = as_int64(has_fids)
  select_fids = as_int64(select_fids)
  if select_slots is None:
    select_slots = []
  assert all([slot > 0 for slot in select_slots])

  if variant_type != 'instance':
    # Record the touched features so that they survive column pruning.
    for fid_group in (filter_fids, has_fids, select_fids):
      add_feature_by_fids(fid_group)

  return ragged_data_ops.set_filter(variant, filter_fids, has_fids,
                                    select_fids, has_actions or [],
                                    req_time_min, select_slots, variant_type)
@monolith_export
def filter_by_feature_value(variant: tf.Tensor,
                            field_name: str,
                            op: str,
                            operand: Union[float, int, str, List[float],
                                           List[int], List[str]],
                            field_type: str,
                            keep_empty: bool = False,
                            operand_filepath: str = None):
  """Filter samples by the value of a feature (dense feature filtering).

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    field_name (:obj:`str`): name of the feature the filter is applied to
    op (:obj:`str`): comparison operator; one of gt/ge/eq/lt/le/neq/between/
      in/not-in, the set operators all/any/diff, or startswith/endswith
    operand: value(s) to compare against, a scalar or a list. Must be None
      when `operand_filepath` is given
    field_type (:obj:`str`): must be given explicitly, one of
      int64/float/double/bytes
    keep_empty (:obj:`bool`): forwarded to the op unchanged, default False
    operand_filepath (:obj:`str`): existing file to read operands from; only
      allowed with the in/not-in operators and exclusive with `operand`

  Returns:
    variant tensor, the filtered data
  """
  assert op in {
      'gt', 'ge', 'eq', 'lt', 'le', 'neq', 'between', 'in', 'not-in', 'all',
      'any', 'diff', 'startswith', 'endswith'
  }
  # Exactly one of operand / operand_filepath has to be supplied.
  assert (operand is None and operand_filepath) or (operand is not None and
                                                    not operand_filepath)
  assert field_type in {
      'int64', 'float', 'double', 'bytes'
  }, 'You must specify field_type for feature value_filter!'

  int_operand, float_operand, string_operand = [], [], []
  if operand_filepath is None:
    operand_filepath = ''

  if operand_filepath:
    assert op in {'in', 'not-in'}
    assert (isinstance(operand_filepath, str) and
            tf.io.gfile.exists(operand_filepath))
  elif op in {'all', 'any', 'diff'}:
    assert field_type == 'int64', 'all/any/diff op only support int64 list'
    if isinstance(operand, (list, tuple)):
      assert all(isinstance(o, int) for o in operand)
      int_operand = list(operand)
    else:
      assert isinstance(operand, int)
      int_operand = [operand]
  elif field_type in {'float', 'double'}:
    if op == 'between':
      assert all(isinstance(o, (int, float)) for o in operand)
      float_operand = [float(o) for o in operand]
    else:
      float_operand = [float(operand)]
  elif field_type == 'int64':
    if op in {'in', 'not-in', 'between'}:
      assert all(isinstance(o, int) for o in operand)
      int_operand = list(operand)
    else:
      int_operand = [int(operand)]
  elif field_type == 'bytes':
    if isinstance(operand, str):
      string_operand.append(operand)
    elif isinstance(operand, (list, tuple)):
      assert all(isinstance(o, str) for o in operand)
      string_operand.extend(operand)
    else:
      raise RuntimeError("params error!")
  else:
    raise RuntimeError("params error!")

  return ragged_data_ops.feature_value_filter(
      variant,
      field_name=field_name,
      op=op,
      float_operand=float_operand,
      int_operand=int_operand,
      string_operand=string_operand,
      operand_filepath=operand_filepath,
      field_type=field_type,
      keep_empty=keep_empty)
@monolith_export
def filter_by_value(variant: tf.Tensor,
                    field_name: str,
                    op: str,
                    operand: Union[float, int, str, List[float], List[int],
                                   List[str]],
                    variant_type: str = 'instance',
                    keep_empty: bool = False,
                    operand_filepath: str = None):
  """Filter samples by the value of a LineId field.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    field_name (:obj:`str`): LineId field the filter is applied to
    op (:obj:`str`): comparison operator; one of gt/ge/eq/lt/le/neq/between/
      in/not-in, the set operators all/any/diff, or the string predicates
      startswith/endswith
    operand: value(s) to compare against. Must be None when
      `operand_filepath` is given
    variant_type (:obj:`str`): variant type, instance/example
    keep_empty (:obj:`bool`): forwarded to the op unchanged, default False
    operand_filepath (:obj:`str`): existing file to read operands from; only
      allowed with the in/not-in operators and exclusive with `operand`

  Returns:
    variant tensor, the filtered data
  """
  if variant_type != 'instance':
    add_feature('__LINE_ID__')

  assert op in {
      'gt', 'ge', 'eq', 'lt', 'le', 'neq', 'between', 'in', 'not-in', 'all',
      'any', 'diff', 'startswith', 'endswith'
  }
  fields = LineId.DESCRIPTOR.fields_by_name
  assert field_name in fields
  # Exactly one of operand / operand_filepath has to be supplied.
  assert (operand is None and operand_filepath) or (operand is not None and
                                                    not operand_filepath)
  field = fields[field_name]
  integer_cpp_types = {
      field.CPPTYPE_INT32, field.CPPTYPE_INT64, field.CPPTYPE_UINT32,
      field.CPPTYPE_UINT64
  }

  int_operand, float_operand, string_operand = [], [], []
  if operand_filepath is None:
    operand_filepath = ''

  if operand_filepath:
    assert op in {'in', 'not-in'}
    assert (isinstance(operand_filepath, str) and
            tf.io.gfile.exists(operand_filepath))
  elif field.has_options:
    # Fields carrying proto options only accept the set operators here.
    assert op in {'all', 'any', 'diff'}
    assert field.cpp_type in integer_cpp_types
    if isinstance(operand, (list, tuple)):
      assert all(isinstance(o, int) for o in operand)
      int_operand = list(operand)
    else:
      assert isinstance(operand, int)
      int_operand = [operand]
  elif field.cpp_type in {field.CPPTYPE_DOUBLE, field.CPPTYPE_FLOAT}:
    if op == 'between':
      assert all(isinstance(o, (int, float)) for o in operand)
      float_operand = [float(o) for o in operand]
    else:
      float_operand = [float(operand)]
  elif field.cpp_type in integer_cpp_types:
    if op in {'in', 'not-in', 'between'}:
      assert all(isinstance(o, int) for o in operand)
      int_operand = list(operand)
    else:
      int_operand = [int(operand)]
  elif field.cpp_type == field.CPPTYPE_STRING:
    if isinstance(operand, str):
      string_operand.append(operand)
    elif isinstance(operand, (list, tuple)):
      assert all(isinstance(o, str) for o in operand)
      string_operand.extend(operand)
    else:
      raise RuntimeError("params error!")
  else:
    raise RuntimeError("params error!")

  return ragged_data_ops.value_filter(variant,
                                      field_name=field_name,
                                      op=op,
                                      float_operand=float_operand,
                                      int_operand=int_operand,
                                      string_operand=string_operand,
                                      operand_filepath=operand_filepath,
                                      keep_empty=keep_empty,
                                      variant_type=variant_type)
@monolith_export
def add_action(
    variant: tf.Tensor,
    field_name: str,
    op: str,
    operand: Union[float, int, str, List[float], List[int], List[str]],
    action: int,
    variant_type: str = 'example',
):
  """Append `action` to LineId.actions when a LineId field passes a test.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    field_name (:obj:`str`): LineId field whose value is tested
    op (:obj:`str`): comparison operator, one of gt/ge/eq/lt/le/neq/between/in
    operand: value(s) to compare against
    action (:obj:`int`): value added to LineId.actions when the test holds
    variant_type (:obj:`str`): 'instance' or 'example'

  Returns:
    variant tensor, the rewritten data
  """
  if variant_type != 'instance':
    add_feature('__LINE_ID__')

  assert op in {'gt', 'ge', 'eq', 'lt', 'le', 'neq', 'between', 'in'}
  assert variant_type in {'instance', 'example'}
  fields = LineId.DESCRIPTOR.fields_by_name
  assert field_name in fields
  field = fields[field_name]

  int_operand, float_operand, string_operand = [], [], []
  if field.cpp_type in {field.CPPTYPE_DOUBLE, field.CPPTYPE_FLOAT}:
    if op == 'between':
      assert all(isinstance(o, (int, float)) for o in operand)
      float_operand = [float(o) for o in operand]
    else:
      float_operand = [float(operand)]
  elif field.cpp_type in {
      field.CPPTYPE_INT32, field.CPPTYPE_INT64, field.CPPTYPE_UINT32,
      field.CPPTYPE_UINT64
  }:
    if op in {'in', 'between'}:
      assert all(isinstance(o, int) for o in operand)
      int_operand = list(operand)
    else:
      int_operand = [int(operand)]
  elif field.cpp_type == field.CPPTYPE_STRING:
    if isinstance(operand, str):
      string_operand.append(operand)
    elif isinstance(operand, (list, tuple)):
      assert all(isinstance(o, str) for o in operand)
      string_operand.extend(operand)
    else:
      raise RuntimeError("params error!")
  else:
    raise RuntimeError("params error!")

  return ragged_data_ops.add_action(variant,
                                    field_name=field_name,
                                    op=op,
                                    float_operand=float_operand,
                                    int_operand=int_operand,
                                    string_operand=string_operand,
                                    actions=[action],
                                    variant_type=variant_type)
@monolith_export
def add_label(
    variant: tf.Tensor,
    config: str,
    negative_value: float,
    new_sample_rate: float,
    variant_type: str = 'example',
):
  """Generate (multi-task) labels from actions according to `config`.

  Always pair this with the `filter_by_label` filter, otherwise invalid
  samples may be fed to the trainer.

  `config` holds one `pos_actions:neg_actions:sample_rate` entry per task,
  tasks separated by ';'. Example: config='1,2:3:1.0;4::0.5' describes two
  tasks: task1 has pos_actions={1,2}, neg_actions={3}, sample_rate=1.0;
  task2 has pos_actions={4}, no neg_actions, sample_rate=0.5. Then

  - task1: a sample whose actions contain 1 or 2 is positive. Otherwise
    sampling is considered (it only triggers when sample_rate < 1.0): a
    sampled-in sample is labelled negative, a sampled-out one gets an invalid
    label; without sampling the sample is labelled negative. Here
    sample_rate=1.0, so no negative sampling happens.
  - task2: a sample whose actions contain 4 is positive. As no neg_actions
    are given, every other sample goes through negative sampling with
    sample_rate=0.5: sampled-in ones are negative, the rest get an invalid
    label.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    config (:obj:`str`): e.g. '1,2:3:1.0;4::0.5'
    negative_value (:obj:`float`): e.g. -1.0 or 0.0
    new_sample_rate (:obj:`float`): value assigned to LineId.sample_rate
    variant_type (:obj:`str`): 'instance' or 'example'

  Returns:
    variant tensor, the rewritten data
  """
  assert variant_type in {'instance', 'example'}
  if variant_type != 'instance':
    add_feature('__LINE_ID__')
  assert config, 'Please specify config and retry!'
  assert 0 < new_sample_rate <= 1.0, 'new_sample_rate should be in (0, 1.0]'

  def parse_actions(text):
    # '1,2' -> [1, 2]; empty pieces (e.g. from '' or '1,') are ignored.
    return [int(piece) for piece in text.split(',') if len(piece) > 0]

  label_conf = LabelConf()
  for task in config.split(';'):
    if len(task) == 0:
      # skip empty parts, e.g. config = '1,2:3:1.0;'
      continue
    task_conf = label_conf.conf.add()
    pos_actions, neg_actions, sample_rate = task.split(':')
    pos_actions_list = parse_actions(pos_actions)
    neg_actions_list = parse_actions(neg_actions)
    task_conf.pos_actions.extend(pos_actions_list)
    task_conf.neg_actions.extend(neg_actions_list)
    task_conf.sample_rate = float(sample_rate)

  return ragged_data_ops.add_label(variant,
                                   config=label_conf.SerializeToString(),
                                   negative_value=negative_value,
                                   sample_rate=new_sample_rate,
                                   variant_type=variant_type)
@monolith_export
def scatter_label(
    variant: tf.Tensor,
    config: str,
    variant_type: str = 'example',
):
  """Scatter the label into a multi-task label vector keyed by chnid.

  Always pair this with the `filter_by_label` filter, otherwise invalid
  samples may be fed to the trainer.

  `config` looks like 'chnid0:index0,chnid1:index1'. Example:
  config='100:3,200:1,300:4' describes 5 tasks (the largest index is 4) and
  is processed as follows:

  1. take label_value = label[0], i.e. the sample must have label.size() > 0
  2. resize the sample's label to 5 entries, all set to INVALID_LABEL
  3. if the sample's chnid is 100, label[3] = label_value
  4. else if the sample's chnid is 200, label[1] = label_value
  5. else if the sample's chnid is 300, label[4] = label_value
  6. else (chnid not in {100, 200, 300}) every entry stays INVALID_LABEL

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    config (:obj:`str`): e.g. '100:3,200:1,300:4'
    variant_type (:obj:`str`): 'instance' or 'example'

  Returns:
    variant tensor, the rewritten data
  """
  assert variant_type in {'instance', 'example'}
  if variant_type != 'instance':
    add_feature('__LABEL__')
    add_feature('__LINE_ID__')
  assert config, 'Please specify config and retry!'

  return ragged_data_ops.scatter_label(variant,
                                       config=config,
                                       variant_type=variant_type)


@monolith_export
def filter_by_label(
    variant: tf.Tensor,
    label_threshold: List[float],
    filter_equal: bool = False,
    variant_type: str = 'example',
) -> tf.Tensor:
  """Decide whether a sample is kept based on its (multi-task) labels.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    label_threshold (:obj:`List[float]`): the sample is kept if any label is
      >= the corresponding threshold, otherwise dropped. With
      label_threshold = [-100.0, 0.0]:

      - label = [-1000, -1]: dropped, no label is valid
      - label = [-1000, 0]: kept, the 2nd label is valid
      - label = [-1, -1]: kept, the 1st label is valid
      - label = [-1, 1]: kept, both labels are valid
    filter_equal (:obj:`bool`): Whether to filter when label equals to
      threshold.
    variant_type (:obj:`str`): 'instance' or 'example'

  Returns:
    The op's output tensor telling whether the sample is kept (the result of
    a TF op, not a Python bool).
  """
  assert variant_type in {'instance', 'example'}
  if variant_type != 'instance':
    add_feature('__LABEL__')
  assert len(label_threshold) > 0, 'Please specify label_threshold and retry!'

  return ragged_data_ops.filter_by_label(variant,
                                         label_threshold=label_threshold,
                                         filter_equal=filter_equal,
                                         variant_type=variant_type)
@monolith_export
def special_strategy(variant: tf.Tensor,
                     strategy_list: List[int],
                     strategy_conf: str = None,
                     variant_type: str = 'instance',
                     keep_empty_strategy=True):
  """Filter / sample by the special_strategy field of LineId.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    strategy_list (:obj:`List[int]`): list of strategies
    strategy_conf (:obj:`str`): comma separated `strategy:sample_rate` or
      `strategy:sample_rate:label` items. Used to sample positive / negative /
      all samples and to rewrite their label. Items with any other number of
      ':'-separated parts are ignored.
    variant_type (:obj:`str`): variant type, instance/example
    keep_empty_strategy (:obj:`bool`): whether samples with an empty strategy
      are kept, default True

  Returns:
    variant tensor, the filtered data
  """
  if variant_type != 'instance':
    add_feature('__LABEL__')
    add_feature('__LINE_ID__')

  special_strategies, sample_rates, labels = [], [], []
  items = strategy_conf.strip().split(',') if strategy_conf is not None else []
  for item in items:
    parts = item.strip().split(':')
    if len(parts) not in (2, 3):
      continue
    special_strategies.append(int(parts[0]))
    sample_rates.append(float(parts[1]))
    if len(parts) == 3:
      labels.append(float(parts[2]))

  assert len(special_strategies) == len(sample_rates)
  # Labels are all-or-nothing: either every item carries one or none does.
  assert len(special_strategies) == len(labels) or len(labels) == 0
  assert all(0 <= sr <= 1 for sr in sample_rates)

  return ragged_data_ops.special_strategy(
      variant,
      special_strategies=special_strategies,
      sample_rates=sample_rates,
      labels=labels,
      strategy_list=strategy_list,
      keep_empty_strategy=keep_empty_strategy,
      variant_type=variant_type)
@monolith_export
def negative_sample(variant: tf.Tensor,
                    drop_rate: float,
                    label_index: int = 0,
                    threshold: float = 0.0,
                    variant_type: str = 'instance',
                    action_priority: str = None,
                    per_action_drop_rate: str = None):
  """Negative sampling.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    drop_rate (:obj:`float`): fraction of negatives to drop, in [0, 1);
      sample_rate = 1 - drop_rate
    label_index (:obj:`int`): labels is a list; selects which entry is used
    threshold (:obj:`float`): a label greater than `threshold` is positive
    variant_type (:obj:`str`): variant type, instance/example
    action_priority (:obj:`str`): comma separated int actions, highest
      priority first
    per_action_drop_rate (:obj:`str`): comma separated `action:drop_rate`
      items. Only used when `action_priority` is given as well.

  Returns:
    variant tensor, the filtered data
  """
  if variant_type != 'instance':
    add_feature('__LABEL__')

  assert action_priority is None or isinstance(action_priority, str)
  assert per_action_drop_rate is None or isinstance(per_action_drop_rate, str)

  priority, actions, action_drop_rate = [], [], []
  if action_priority and per_action_drop_rate:
    priority = [int(token) for token in action_priority.strip().split(",")]
    for pair in per_action_drop_rate.strip().split(","):
      action, rate = pair.strip().split(":")
      actions.append(int(action))
      action_drop_rate.append(float(rate))

  return ragged_data_ops.negative_sample(variant,
                                         drop_rate=drop_rate,
                                         label_index=label_index,
                                         threshold=threshold,
                                         variant_type=variant_type,
                                         priorities=priority,
                                         actions=actions,
                                         per_action_drop_rate=action_drop_rate)
@monolith_export
def feature_combine(src1: tf.RaggedTensor, src2: tf.RaggedTensor,
                    slot: int) -> tf.RaggedTensor:
  """Cross two already-extracted sparse features.

  Args:
    src1 (:obj:`RaggedTensor`): first sparse feature, a simple or a sequence
      feature
    src2 (:obj:`RaggedTensor`): second sparse feature, a simple or a sequence
      feature
    slot (:obj:`int`): slot of the output feature

  Returns:
    RaggedTensor, the crossed feature
  """
  assert isinstance(src1, tf.RaggedTensor)
  assert isinstance(src2, tf.RaggedTensor)

  splits, values = ragged_data_ops.feature_combine(
      rt_nested_splits_src1=src1.nested_row_splits,
      rt_dense_values_src1=src1.flat_values,
      rt_nested_splits_src2=src2.nested_row_splits,
      rt_dense_values_src2=src2.flat_values,
      slot=slot,
      fid_version=2)

  if splits[0].dtype != tf.float32:
    return tf.RaggedTensor.from_nested_row_splits(flat_values=values,
                                                  nested_row_splits=splits,
                                                  validate=False)
  return tf.RaggedTensor.from_row_splits(values=values,
                                         row_splits=splits[1],
                                         validate=False)


@monolith_export
def switch_slot(ragged: tf.RaggedTensor, slot: int) -> tf.RaggedTensor:
  """Move a sparse feature to another slot.

  Args:
    ragged (:obj:`RaggedTensor`): input sparse feature, a simple or a
      sequence feature
    slot (:obj:`int`): slot of the output feature

  Returns:
    RaggedTensor, the feature after the switch
  """
  assert isinstance(ragged, tf.RaggedTensor)

  splits, values = ragged_data_ops.switch_slot(
      rt_nested_splits=ragged.nested_row_splits,
      rt_dense_values=ragged.flat_values,
      slot=slot,
      fid_version=2)

  if splits[0].dtype != tf.float32:
    return ragged.with_flat_values(values)
  return tf.RaggedTensor.from_row_splits(values=values,
                                         row_splits=splits[1],
                                         validate=False)


@monolith_export
def switch_slot_batch(variant: tf.Tensor,
                      features: Dict[str, Tuple[bool, int]],
                      variant_type: str = 'example_batch',
                      suffix: str = 'share') -> tf.Tensor:
  """Move several sparse features to other slots in one pass.

  Args:
    variant (:obj:`VariantTensor`): input features, only the pb format is
      supported for now
    features (:obj:`dict`): feature name -> (modify in place?, new slot)
    variant_type (:obj:`str`): type of the input variant, 'example' or
      'example_batch'
    suffix (:obj:`str`): forwarded to the op unchanged

  Returns:
    Variant Tensor, the features after the switch
  """
  feats = list(features.keys())
  inplaces = [inplace for inplace, _ in features.values()]
  slots = [slot for _, slot in features.values()]

  assert variant_type in {'example', 'example_batch'}
  return ragged_data_ops.switch_slot_batch(variant,
                                           features=feats,
                                           slots=slots,
                                           inplaces=inplaces,
                                           suffix=suffix,
                                           variant_type=variant_type)
@monolith_export
def label_upper_bound(variant: tf.Tensor,
                      label_upper_bounds: List[float],
                      variant_type: str = 'instance'):
  """Clip labels: a label above its upper bound is set to that bound.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    label_upper_bounds (:obj:`List[float]`): when a label is >= the
      corresponding entry, that label is set to the upper bound
    variant_type (:obj:`str`): 'instance' or 'example'

  Returns:
    variant tensor, the data with labels adjusted to the upper bounds
  """
  assert variant_type in {'instance', 'example'}
  # The message names the parameter that is actually being checked.
  assert len(
      label_upper_bounds) > 0, 'Please specify label_upper_bounds and retry!'

  return ragged_data_ops.label_upper_bound(
      variant,
      label_upper_bounds=label_upper_bounds,
      variant_type=variant_type)


@monolith_export
def label_normalization(variant: tf.Tensor,
                        norm_methods: List[str],
                        norm_values: List[float],
                        variant_type: str = 'instance'):
  """Normalize labels; each label is replaced by its normalized value.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    norm_methods (:obj:`List[str]`): normalization methods, e.g. log, scale,
      repow, scalelog
    norm_values (:obj:`List[float]`): value used by the corresponding method;
      must have the same length as `norm_methods`
    variant_type (:obj:`str`): 'instance' or 'example'

  Returns:
    variant tensor, the data with normalized labels
  """
  assert variant_type in {'instance', 'example'}
  assert len(norm_methods) == len(
      norm_values), 'norm_methods and norm_values should have the same length'

  return ragged_data_ops.label_normalization(variant,
                                             norm_methods=norm_methods,
                                             norm_values=norm_values,
                                             variant_type=variant_type)


@monolith_export
def use_field_as_label(variant: tf.Tensor,
                       field_name: str,
                       overwrite_invalid_value: bool = False,
                       label_threshold: float = 7200,
                       variant_type: str = 'instance'):
  """Use a field of line_id as the new label.

  Args:
    variant (:obj:`Tensor`): input data, must be a variant tensor
    field_name (:obj:`str`): line_id field to take the label from
    overwrite_invalid_value (:obj:`bool`): whether to overwrite the new
      value; if set, a value >= `label_threshold` is overwritten with 0
    label_threshold (:obj:`float`): threshold used for the overwrite
    variant_type (:obj:`str`): 'instance' or 'example'

  Returns:
    variant tensor, the data with the replaced label
  """
  assert variant_type in {'instance', 'example'}
  return ragged_data_ops.use_field_as_label(
      variant,
      field_name=field_name,
      overwrite_invalid_value=overwrite_invalid_value,
      label_threshold=label_threshold,
      variant_type=variant_type)
def create_item_pool(start_num: int,
                     max_item_num_per_channel: int,
                     container: str = '',
                     shared_name: str = '') -> tf.Tensor:
  """Create an item pool resource and return the op's handle."""
  assert start_num >= 0 and max_item_num_per_channel > 0
  return ragged_data_ops.ItemPoolCreate(
      start_num=start_num,
      max_item_num_per_channel=max_item_num_per_channel,
      container=container,
      shared_name=shared_name)


def item_pool_random_fill(pool: tf.Tensor) -> tf.Tensor:
  """Run the ItemPoolRandomFill op on `pool` and return its output."""
  return ragged_data_ops.ItemPoolRandomFill(ipool=pool)


def item_pool_check(pool: tf.Tensor,
                    model_path: str,
                    global_step: int,
                    nshards: int = 1,
                    buffer_size: int = 10 * 1024 * 1024) -> tf.Tensor:
  """Run the ItemPoolCheck op on `pool` and return its output."""
  return ragged_data_ops.ItemPoolCheck(ipool=pool,
                                       model_path=model_path,
                                       nshards=nshards,
                                       buffer_size=buffer_size,
                                       global_step=global_step)


def save_item_pool(pool: tf.Tensor,
                   global_step: tf.Tensor,
                   model_path: str,
                   nshards: int = 1) -> tf.Tensor:
  """Run the ItemPoolSave op on `pool` and return its output."""
  return ragged_data_ops.ItemPoolSave(ipool=pool,
                                      global_step=global_step,
                                      model_path=model_path,
                                      nshards=nshards)


def restore_item_pool(pool: tf.Tensor,
                      global_step: tf.Tensor,
                      model_path: str,
                      nshards: int = 1,
                      buffer_size: int = 10 * 1024 * 1024) -> tf.Tensor:
  """Run the ItemPoolRestore op on `pool` and return its output."""
  return ragged_data_ops.ItemPoolRestore(ipool=pool,
                                         global_step=global_step,
                                         model_path=model_path,
                                         nshards=nshards,
                                         buffer_size=buffer_size)


def fill_multi_rank_output(
    variant: tf.Tensor,
    enable_draw_as_rank: bool = False,
    enable_chnid_as_rank: bool = False,
    enable_lineid_rank_as_rank: bool = False,
    rank_num: int = 18,
    variant_type: str = 'instance',
):
  """When use_rank_multi_output flag is set.
  """
  assert variant_type in {'instance', 'example'}
  if variant_type != 'instance':
    add_feature('__LINE_ID__')

  return ragged_data_ops.fill_multi_rank_output(
      input=variant,
      variant_type=variant_type,
      enable_draw_as_rank=enable_draw_as_rank,
      enable_chnid_as_rank=enable_chnid_as_rank,
      enable_lineid_rank_as_rank=enable_lineid_rank_as_rank,
      rank_num=rank_num)


def use_f100_multi_head(
    variant: tf.Tensor,
    variant_type: str = 'instance',
):
  """When use_f100_multihead flag is set.
  """
  assert variant_type in {'instance', 'example'}
  return ragged_data_ops.use_f100_multi_head(input=variant,
                                             variant_type=variant_type)


def map_id(tensor: tf.Tensor, map_dict: Dict[int, int], default: int = -1):
  """Map ids in `tensor` through `map_dict` with the MapId op.

  Args:
    tensor: input tensor of ids.
    map_dict: non-empty mapping from source id to target id.
    default: passed to the op as `default_value`.
  """
  assert map_dict is not None and len(map_dict) > 0
  # keys() and values() iterate in the same order, so the two lists stay
  # aligned pairwise.
  from_value = list(map_dict.keys())
  to_value = list(map_dict.values())
  return ragged_data_ops.MapId(input=tensor,
                               from_value=from_value,
                               to_value=to_value,
                               default_value=default)
ragged_data_ops.multi_label_gen( variant, task_num=task_num, head_to_index=','.join(head_to_index_list), head_field=head_field, action_priority=action_priority, pos_actions=pos_actions, neg_actions=neg_actions, use_origin_label=use_origin_label, pos_label=pos_label, neg_label=neg_label, variant_type=variant_type) def string_to_variant(tensor: tf.Tensor, variant_type: str = 'example', has_header: bool = False, has_sort_id: bool = False, lagrangex_header: bool = False, kafka_dump_prefix: bool = False, kafka_dump: bool = False, chnids: List[int] = None, datasources: List[str] = None, default_datasource: str = ''): assert variant_type in { 'instance', 'example', 'examplebatch', 'example_batch' } return ragged_data_ops.string_to_variant( input=tensor, has_header=has_header, has_sort_id=has_sort_id, lagrangex_header=lagrangex_header, kafka_dump_prefix=kafka_dump_prefix, kafka_dump=kafka_dump, input_type=variant_type, chnids=chnids or [], datasources=datasources or [], default_datasource=default_datasource) #string_to_variant_with_transform example ''' dataset = dataset.flat_map(lambda v: tf.data.Dataset.from_tensors( string_to_variant_with_transform( v.message, input_type=variant_type.lower(), output_type=output_pb_type, has_header=has_header, lagrangex_header=self._lagrangex_header, has_sort_id=self._has_sort_id, kafka_dump=self._kafka_dump, kafka_dump_prefix=self._kafka_dump_prefix, chnids=self._chnids, datasources=self._datasources, default_datasource=self._default_datasource))) ''' def string_to_variant_with_transform(tensor: tf.Tensor, input_type: str = 'example', output_type: str = 'example', has_header: bool = False, has_sort_id: bool = False, lagrangex_header: bool = False, kafka_dump_prefix: bool = False, kafka_dump: bool = False, chnids: List[int] = None, datasources: List[str] = None, default_datasource: str = ''): assert input_type in {'instance', 'example', 'examplebatch', 'example_batch'} assert output_type in {'instance', 'example', 'examplebatch', 
'example_batch'} return ragged_data_ops.string_to_variant_with_transform( input=tensor, has_header=has_header, has_sort_id=has_sort_id, lagrangex_header=lagrangex_header, kafka_dump_prefix=kafka_dump_prefix, kafka_dump=kafka_dump, input_type=input_type, output_type=output_type, chnids=chnids or [], datasources=datasources or [], default_datasource=default_datasource) def variant_to_zeros(tensor: tf.Tensor): return ragged_data_ops.variant_to_zeros(tensor) def kafka_resource_init(topics: List[str], metadata: List[str], input_pb_type: str = "", output_pb_type: str = "", has_sort_id: bool = False, lagrangex_header: bool = False, kafka_dump_prefix: bool = False, kafka_dump: bool = False, container: str = '', shared_name: str = ''): return ragged_data_ops.KafkaGroupReadableInit( topics=topics, metadata=metadata, has_sort_id=has_sort_id, lagrangex_header=lagrangex_header, kafka_dump_prefix=kafka_dump_prefix, kafka_dump=kafka_dump, input_pb_type=input_pb_type, output_pb_type=output_pb_type, container=container, shared_name=shared_name) def kafka_read_next(input, index: int, message_poll_timeout: int, stream_timeout: int): return ragged_data_ops.KafkaGroupReadableNext( input=input, index=index, message_poll_timeout=message_poll_timeout, stream_timeout=stream_timeout) def kafka_read_next_v2(input, index: int, message_poll_timeout: int, stream_timeout: int): return ragged_data_ops.KafkaGroupReadableNextV2( input=input, index=index, message_poll_timeout=message_poll_timeout, stream_timeout=stream_timeout) def has_variant(input, variant_type: str = 'example'): return ragged_data_ops.HasVariant(input=input, variant_type=variant_type) def gen_fid_mask(tenosr: tf.RaggedTensor, fid: int) -> tf.Tensor: fid = np.uint64(fid).astype(np.int64) return ragged_data_ops.monolith_gen_fid_mask(tenosr.row_splits, tenosr.flat_values, fid=fid) @monolith_export def tf_example_to_example(serialized: tf.Tensor, sparse_features: Dict[str, int], dense_features: List[str], label: str, instance_weight: 
str = None): """ 将序列化的 tf.example 转换为 Monolith Example,在转换的同时,指定的 sparse_features 会被抽取成 FID Args: serialized (:obj:`Tensor`): tf.example 的序列化数据,string 类型 sparse_features (:obj:`Dict[str, int]`): sparse feature name 到 slot id 的映射, 举例:sparse_features = {"user_id": 1, "item_id": 2, "posterior_ctr": 3}, 1. "user_id" 原始类型为 int64,它将被抽取成 FID,存入 Monolith Example 的 fid_v2_list,对应 slot_id=1 2. "item_id" 原始类型为 int64,它将被抽取成 FID,存入 Monolith Example 的 fid_v2_list,对应 slot_id=2 3. "posterior_ctr" 原始类型为 float32,它将被抽取成 FID,存入 Monolith Example 的 fid_v2_list,对应 slot_id=3 dense_features (:obj:`List[str]`): 指定的这些字段将直接 Copy 到 Monolith Example 中 label (:obj:`str`): 存储在 tf.example 中的哪个字段是 label instance_weight (:obj:`str`): 存储在 tf.example 中的哪个字段是 instance_weight Returns: variant tensor: Monolith Example 格式的 variant tensor """ ## default value setting sparse_features = sparse_features or [] dense_features = dense_features or [] label = label or "" instance_weight = instance_weight or "" ## validity check intersection = set(sparse_features.keys()) & set(dense_features) assert len( intersection ) == 0, f"{intersection} occur in sparse_features and dense_features simultaneously, please investigate and retry!" assert label not in sparse_features, f"label: {label} should NOT occur in sparse_features, please investigate and retry!" assert label not in dense_features, f"label: {label} should NOT occur in dense_features, please investigate and retry!" assert instance_weight not in sparse_features, f"instance_weight: {instance_weight} should NOT occur in sparse_features, please investigate and retry!" assert instance_weight not in dense_features, f"instance_weight: {instance_weight} should NOT occur in dense_features, please investigate and retry!" slot_ids = list(sparse_features.values()) duplicates = {slot for slot in slot_ids if slot_ids.count(slot) > 1} assert len( duplicates ) == 0, f"{duplicates} have multiple sparse feature name mapping, please investigate and retry!" 
for slot_id in slot_ids: assert 0 < slot_id < 32768, "slot_id should be in [1, 32768)" ## generate feature_description proto feature_description = TFRecordFeatureDescription() for k, v in sparse_features.items(): feature_description.sparse_features[k] = v feature_description.dense_features.extend(dense_features) feature_description.label = label feature_description.instance_weight = instance_weight return ragged_data_ops.MonolithTFExampleToExample( input=serialized, feature_description=feature_description.SerializeToString()) ================================================ FILE: monolith/native_training/data/feature_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import io from absl import logging import os import uuid import random import struct import tensorflow as tf import tempfile import threading from typing import List, BinaryIO from idl.matrix.proto import proto_parser_pb2, example_pb2, feature_pb2 from monolith.native_training.data.datasets import PBDataset, PbType from monolith.native_training.data.parsers import parse_instances, \ parse_examples from monolith.native_training.model_export.data_gen_utils import lg_header, sort_header from monolith.native_training.data.feature_utils import ( add_action, add_label, feature_combine, filter_by_fids, filter_by_label, filter_by_feature_value, filter_by_value, scatter_label, switch_slot, switch_slot_batch, map_id, use_field_as_label, label_upper_bound, label_normalization, multi_label_gen, string_to_variant, variant_to_zeros, has_variant, negative_sample, gen_fid_mask) fid_v1_mask = (1 << 54) - 1 fid_v2_mask = (1 << 48) - 1 def get_fid_v1(slot: int, signautre: int): return (slot << 54) | (signautre & fid_v1_mask) def get_fid_v2(slot: int, signature: int): return (slot << 48) | (signature & fid_v2_mask) features = { 'f_spm_1': 301, 'f_spm_3': 303, 'f_spm_2': 302, 'f_spm_4': 304, 'f_user_id': 1, 'f_user_ctx_network': 61, 'f_user_id-f_page': 504, 'f_scm': 306, 'f_goods_id': 200, 'f_goods_sale_number_1000': 225, 'f_goods_praise_cnt': 229, 'f_spm': 300, 'f_page': 305, 'f_is_dup': 310, 'f_user_ctx_platform': 52, 'f_goods_title_terms': 209, 'f_goods_tags_terms': 211, 'f_user_test09_array_int32': 554, 'f_user_test15_array_float': 540, 'f_user_test14_array_bool': 543, 'f_user_test12_array_uint64': 551, 'f_user_test10_array_int64': 549 } group_slots = [ 200, 201, 202, 203, 204, 205, 206, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242 ] def parse_instance_or_example(tensor: tf.Tensor, out_type, extra_sparse_features: List[str] = None): fidv1_features = [ 1, 2, 32, 33, 36, 38, 42, 50, 54, 
56, 60, 66, 120, 150, 180, 182, 192, 220, 333, 410, 412, 422, 446 ] if out_type == PbType.INSTANCE: return parse_instances(tensor, fidv1_features, dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=[ 'uid', 'req_time', 'item_id', 'actions', 'video_finish_percent' ], extra_feature_shapes=[1, 1, 1, 2, 1]) else: return parse_examples( tensor, sparse_features=[f'fc_slot_{slot}' for slot in fidv1_features] + \ extra_sparse_features if extra_sparse_features else [], dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=[ 'uid', 'req_time', 'item_id', 'actions', 'video_finish_percent' ], extra_feature_shapes=[1, 1, 1, 3, 1]) def parse_example_batch(tensor: tf.Tensor, out_type, extra_sparse_features: List[str] = None): if out_type == PbType.INSTANCE: feature_dict = parse_instances(tensor, fidv1_features=list(features.values()), dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=[ 'uid', 'req_time', 'item_id', 'actions', 'video_finish_percent' ], extra_feature_shapes=[1, 1, 1, 3, 1]) else: feature_dict = parse_examples(tensor, sparse_features=list(features.keys()) + \ extra_sparse_features if extra_sparse_features else [], dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=[ 'uid', 'req_time', 'item_id', 'actions', 'video_finish_percent' ], extra_feature_shapes=[1, 1, 1, 3, 1]) # print(feature_dict) # feature_dict['f_page'] = switch_slot(feature_dict['f_page'], slot=306) # feature_dict['f_user_id-f_goods_tags_terms'] = feature_combine( # feature_dict['f_user_id'], feature_dict['f_goods_tags_terms'], slot=505) return feature_dict def line_id_to_feature(name, value): tmp_feature = feature_pb2.Feature() tmp_feature.name = name value_mapping = { str: lambda val: tmp_feature.bytes_value.append(val.encode()), int: tmp_feature.int64_value.append, float: tmp_feature.float_value.append, } 
if isinstance(value, list): if len(value) > 0: value_type = type(value[0]) else: value_type = type(value) append_func = value_mapping.get(value_type) if append_func: if isinstance(value, list): for val in value: append_func(val) else: append_func(value) return tmp_feature def generate_instance(labels: List[int], actions: List[int], chnid: int = None, did: str = None, fid_v1_list: List[int] = None, device_type: str = None, req_id: str = None, chnids: List[int] = None, video_play_time: float = None, write_line_id_to_feature: bool = False, shuffle_features: bool = False): instance = proto_parser_pb2.Instance() instance.fid.extend(fid_v1_list if fid_v1_list else []) instance.label.extend(labels) instance.line_id.user_id = "test_{}".format(uuid.uuid4()) instance.line_id.uid = 100 instance.line_id.sample_rate = 0.5 instance.line_id.actions.extend(actions) features = [] if chnid is not None: instance.line_id.chnid = chnid if write_line_id_to_feature: features.append(line_id_to_feature('chnid', chnid)) if did is not None: instance.line_id.did = did if write_line_id_to_feature: features.append(line_id_to_feature('did', did)) if device_type is not None: instance.line_id.device_type = device_type if req_id is not None: instance.line_id.req_id = req_id if write_line_id_to_feature: features.append(line_id_to_feature('req_id', req_id)) if chnids is not None and write_line_id_to_feature: features.append(line_id_to_feature('chnids', chnids)) if video_play_time is not None: instance.line_id.video_play_time = video_play_time if write_line_id_to_feature: features.append(line_id_to_feature('video_play_time', video_play_time)) if shuffle_features: random.shuffle(features) instance.feature.extend(features) return instance def write_instance_into_file(file: BinaryIO, instance): sort_id = str(instance.line_id.user_id) file.write(struct.pack('Q', size_binary)[0] sort_id = stream.read(size) else: sort_id = "" # proto. 
if has_kafka_dump: stream.read(size_t) # size + proto_binary # This is the proto part. size_binary = stream.read(size_t)[::-1] size = struct.unpack('>Q', size_binary)[0] proto_binary = stream.read(size) instance = proto_parser_pb2.Instance() instance.ParseFromString(proto_binary) return instance class DataOpsTest(tf.test.TestCase): def pb_dataset_target(self, input_pb_type, output_pb_type, filter_fn=None, add_action_fn=None, return_result_key='actions', num_return_items=2): if input_pb_type == PbType.INSTANCE: lagrangex_header = False has_sort_id, kafka_dump, kafka_dump_prefix = True, True, False file_name = "monolith/native_training/data/training_instance/instance.pb" elif input_pb_type == PbType.EXAMPLE: lagrangex_header = False has_sort_id, kafka_dump, kafka_dump_prefix = True, True, False file_name = "monolith/native_training/data/training_instance/example.pb" else: lagrangex_header = True has_sort_id, kafka_dump, kafka_dump_prefix = False, False, False file_name = "monolith/native_training/data/training_instance/examplebatch.data" def parser(tensor: tf.Tensor): if output_pb_type == PbType.PLAINTEXT: return parse_instance_or_example(tensor, input_pb_type) elif input_pb_type != PbType.EXAMPLEBATCH: return parse_instance_or_example(tensor, output_pb_type) else: return parse_example_batch(tensor, output_pb_type) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=lagrangex_header, has_sort_id=has_sort_id, kafka_dump=kafka_dump, kafka_dump_prefix=kafka_dump_prefix, input_pb_type=input_pb_type, output_pb_type=output_pb_type) if add_action_fn is not None: dataset = dataset.map(add_action_fn) if input_pb_type == PbType.EXAMPLEBATCH: variant_type = 'instance' if output_pb_type == PbType.INSTANCE else 'example' dataset = dataset.instance_reweight( action_priority="2,7,0,1,3,4,5,6,8,9,10,11", 
reweight= "0:0:1,1:0:1,2:3:-1,3:0:1,4:0:1,5:0:1,6:0:1,7:6:1,8:0:1,9:0:1,10:0:1,11:0:-1", variant_type=variant_type) if filter_fn is not None: dataset = dataset.filter(filter_fn) batch_size = 4 dataset = dataset.batch(batch_size, drop_remainder=True).map(parser) it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() results = list() for _ in range(num_return_items): try: element_result = sess.run(element) results.append(element_result[return_result_key]) except tf.errors.OutOfRangeError: break return results def test_input_instance_output_instance(self): actions = self.pb_dataset_target(input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE) self.assertAllEqual(actions[0], [[1, 0], [1, 0], [1, 0], [1, 0]]) self.assertAllEqual(actions[1], [[1, 0], [1, 0], [1, 0], [1, 0]]) def test_input_instance_output_instance_add_action(self): actions = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: add_action( variant, 'sample_rate', 'ge', 0, 2, variant_type='instance')) self.assertAllEqual(actions[0], [[1, 2], [1, 2], [1, 2], [1, 2]]) self.assertAllEqual(actions[1], [[1, 2], [1, 2], [1, 2], [1, 2]]) def test_input_instance_output_example(self): actions = self.pb_dataset_target(input_pb_type=PbType.INSTANCE, output_pb_type=PbType.EXAMPLE) self.assertAllEqual(actions[0], [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]) self.assertAllEqual(actions[1], [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]) def test_input_instance_output_example_add_action(self): actions = self.pb_dataset_target(input_pb_type=PbType.INSTANCE, output_pb_type=PbType.EXAMPLE, add_action_fn=lambda variant: add_action( variant, 'req_time', 'between', [1622667900, 1622667911], 2, variant_type='example')) self.assertAllEqual(actions[0], [[1, 2, 0], [1, 2, 0], [1, 2, 0], [1, 2, 0]]) self.assertAllEqual(actions[1], [[1, 2, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]) def test_input_example_output_instance(self): actions = 
self.pb_dataset_target(input_pb_type=PbType.EXAMPLE, output_pb_type=PbType.INSTANCE) self.assertAllEqual(actions[0], [[1, 0], [1, 0], [1, 0], [1, 0]]) self.assertAllEqual(actions[1], [[1, 0], [1, 0], [1, 0], [1, 0]]) def test_input_example_output_instance_add_action(self): actions = self.pb_dataset_target( input_pb_type=PbType.EXAMPLE, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: add_action(variant, 'req_time', 'in', [1622667900, 1622667911], 2, variant_type='instance')) self.assertAllEqual(actions[0], [[1, 2], [1, 2], [1, 0], [1, 2]]) self.assertAllEqual(actions[1], [[1, 2], [1, 2], [1, 2], [1, 0]]) def test_input_example_output_example(self): actions = self.pb_dataset_target(input_pb_type=PbType.EXAMPLE, output_pb_type=PbType.EXAMPLE) self.assertAllEqual(actions[0], [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]) self.assertAllEqual(actions[1], [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]) def test_input_example_output_example_add_action(self): actions = self.pb_dataset_target( input_pb_type=PbType.EXAMPLE, output_pb_type=PbType.EXAMPLE, add_action_fn=lambda variant: add_action( variant, 'uid', 'eq', 62975225690081677, 2, variant_type='example')) self.assertAllEqual(actions[0], [[1, 0, 0], [1, 0, 0], [1, 2, 0], [1, 0, 0]]) self.assertAllEqual(actions[1], [[1, 0, 0], [1, 0, 0], [1, 2, 0], [1, 0, 0]]) def test_input_example_batch_output_instance(self): actions = self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE) self.assertAllEqual(actions[0], [[2, 0, 0], [2, 0, 0], [2, 0, 0], [2, 0, 0]]) self.assertAllEqual(actions[1], [[2, 0, 0], [2, 0, 0], [2, 0, 0], [2, 0, 0]]) def test_input_example_batch_output_instance_add_action(self): actions = self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: add_action(variant, 'video_finish_percent', 'ge', 0, 3, variant_type='instance')) self.assertAllEqual(actions[0], [[2, 3, 0], [2, 3, 0], [2, 3, 0], [2, 3, 0]]) 
self.assertAllEqual(actions[1], [[2, 3, 0], [2, 3, 0], [2, 3, 0], [2, 3, 0]]) def test_input_example_batch_output_example(self): actions = self.pb_dataset_target(input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.EXAMPLE) self.assertAllEqual(actions[0], [[2, 0, 0], [2, 0, 0], [2, 0, 0], [2, 0, 0]]) self.assertAllEqual(actions[1], [[2, 0, 0], [2, 0, 0], [2, 0, 0], [2, 0, 0]]) def test_input_example_batch_output_example_add_action(self): actions = self.pb_dataset_target( input_pb_type=PbType.EXAMPLEBATCH, output_pb_type=PbType.EXAMPLE, add_action_fn=lambda variant: add_action( variant, 'video_finish_percent', 'le', 0, 3, variant_type='example') ) self.assertAllEqual(actions[0], [[2, 3, 0], [2, 3, 0], [2, 3, 0], [2, 3, 0]]) self.assertAllEqual(actions[1], [[2, 3, 0], [2, 3, 0], [2, 3, 0], [2, 3, 0]]) def test_input_instance_output_instance_add_label(self): mock_batch_num = 100 add_label_config = '1,2:3:1.0;4::0.5' def mock_instance_for_add_label(batch_num: int = 200): tmpfile = tempfile.mkstemp()[1] labels = [[], [], [], []] # for task1: 1,2 -> positive, 3 -> negative # for task2: 4 -> positive, other -> negative/invalid(depends on sampling) # instance1: task1 -> positive, task2 -> positive # instance2: task1 -> positive, task2 -> negative/invalid(depends on sampling) # instance3: task1 -> negative, task2 -> positive # instance4: task1 -> invalid, task2 -> negative/invalid(depends on sampling) actions = [[1, 2, 4], [1], [3, 4], [5]] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for label, action in zip(labels, actions): instance = generate_instance(label, action) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance_for_add_label(mock_batch_num) logging.info('file_name: %s', file_name) def parser(tensor: tf.Tensor): return parse_instances(tensor, fidv1_features=list(features.values()), dense_features=['label'], dense_feature_shapes=[2], dense_feature_types=[tf.float32], extra_features=[ 'uid', 'req_time', 
'item_id', 'actions', 'video_finish_percent' ], extra_feature_shapes=[1, 1, 1, 3, 1]) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE) dataset = dataset.map( lambda variant: add_label(variant, config=add_label_config, negative_value=-1.0, new_sample_rate=1.0, variant_type='instance')) dataset = dataset.filter(lambda variant: filter_by_label( variant, label_threshold=[-100, -100], variant_type='instance')) batch_size = 4 dataset = dataset.batch(batch_size, drop_remainder=False).map(parser) it = tf.compat.v1.data.make_one_shot_iterator(dataset) valid_instance_num = 0 for _ in range(mock_batch_num): try: element = it.get_next() element_result = sess.run(element) valid_instance_num += len(element_result['label']) except tf.errors.OutOfRangeError: break logging.info('Valid instance number: %d', valid_instance_num) self.assertAllInRange(valid_instance_num, 340, 360) os.remove(file_name) def test_input_instance_output_instance_label_upper_bound(self): labels = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: label_upper_bound( variant, label_upper_bounds=[0.5, 0.5], variant_type='instance'), return_result_key='label') self.assertAllEqual(labels[0], [[0, 0.5], [0, 0.5], [0, 0.5], [0, 0.5]]) self.assertAllEqual(labels[1], [[0, 0.5], [0, 0.5], [0, 0.5], [0, 0.5]]) def test_input_instance_output_instance_label_normalization(self): labels = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: label_normalization( variant, norm_methods=['scale', 'repow'], norm_values=[0.5, 3], variant_type='instance'), return_result_key='label') 
self.assertAllEqual(labels[0], [[0, 8], [0, 8], [0, 8], [0, 8]]) self.assertAllEqual(labels[1], [[0, 8], [0, 8], [0, 8], [0, 8]]) def test_input_examplebatch_output_instance_use_field_as_label(self): labels = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: use_field_as_label( variant, 'sample_rate', False, 0, variant_type='instance'), return_result_key='label') self.assertAllEqual(labels[0], [[1, 1], [1, 1], [1, 1], [1, 1]]) self.assertAllEqual(labels[1], [[1, 1], [1, 1], [1, 1], [1, 1]]) labels = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: use_field_as_label(label_upper_bound( variant, label_upper_bounds=[0.5, 0.5], variant_type='instance'), 'sample_rate', True, 1.1, variant_type='instance' ), return_result_key='label') # Original label is [0, 0.5], new label = max(original_label, [1, 1]) self.assertAllEqual(labels[0], [[1, 1], [1, 1], [1, 1], [1, 1]]) self.assertAllEqual(labels[1], [[1, 1], [1, 1], [1, 1], [1, 1]]) labels = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, add_action_fn=lambda variant: use_field_as_label(label_upper_bound( variant, label_upper_bounds=[0.5, 0.5], variant_type='instance'), 'sample_rate', True, 0.9, variant_type='instance' ), return_result_key='label') # Original label is [0, 0.5], new label = max(original_label, [0, 0]) self.assertAllEqual(labels[0], [[0, 0.5], [0, 0.5], [0, 0.5], [0, 0.5]]) self.assertAllEqual(labels[1], [[0, 0.5], [0, 0.5], [0, 0.5], [0, 0.5]]) def test_input_instance_output_instance_filter_by_label_equals(self): labels = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_label(variant, label_threshold=[0, 1], filter_equal=False, variant_type='instance'), return_result_key='label', num_return_items=100) self.assertEqual(len(labels), 100) 
self.assertAllEqual(labels[0], [[0, 1], [0, 1], [0, 1], [0, 1]]) self.assertAllEqual(labels[1], [[0, 1], [0, 1], [0, 1], [0, 1]]) labels = self.pb_dataset_target( input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE, filter_fn=lambda variant: filter_by_label(variant, label_threshold=[0, 1], filter_equal=True, variant_type='instance'), return_result_key='label', num_return_items=100) self.assertEqual(len(labels), 49) self.assertAllEqual(labels[0], [[0, 2], [0, 2], [0, 2], [0, 2]]) self.assertAllEqual(labels[1], [[0, 2], [0, 2], [0, 2], [0, 2]]) def test_input_instance_output_instance_scatter_label(self): mock_batch_num = 1 scatter_label_config = '100:3,200:1,300:4' def mock_instance_for_scatter_label(batch_num: int = 200): tmpfile = tempfile.mkstemp()[1] labels = [[1], [2], [3], []] actions = [[], [], [], []] chnids = [0, 100, 200, 300] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for label, action, chnid in zip(labels, actions, chnids): instance = generate_instance(label, action, chnid) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance_for_scatter_label(mock_batch_num) logging.info('file_name: %s', file_name) def parser(tensor: tf.Tensor): return parse_instances(tensor, fidv1_features=list(features.values()), dense_features=['label'], dense_feature_shapes=[5], dense_feature_types=[tf.float32], extra_features=[ 'uid', 'req_time', 'item_id', 'actions', 'video_finish_percent' ], extra_feature_shapes=[1, 1, 1, 3, 1]) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE) dataset = dataset.map(lambda variant: scatter_label( variant, config=scatter_label_config, variant_type='instance')) dataset = 
dataset.filter(lambda variant: filter_by_label( variant, label_threshold=[-100, -100, -100, -100, -100], variant_type='instance')) batch_size = 4 dataset = dataset.batch(batch_size, drop_remainder=False).map(parser) it = tf.compat.v1.data.make_one_shot_iterator(dataset) try: element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['label']), 2) self.assertAllClose( element_result['label'], [[ -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, 2.0000000e+00, -3.4028235e+38 ], [ -3.4028235e+38, 3.0000000e+00, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38 ]]) except tf.errors.OutOfRangeError: self.assertTrue(False) os.remove(file_name) def test_filter_by_bytes_value(self): mock_batch_num = 200 def mock_instance_for_filter_by_bytes_value(batch_num: int = 200): tmpfile = tempfile.mkstemp()[1] labels = [[1], [2], [3], []] actions = [[], [], [], []] req_ids = ['abckjhfjh', 'kjhfjh', 'huggfyfixyz', ''] chnids = [10, 20, 30, 40] dids = ['hello', 'world', '300', '400'] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for label, action, chnid, did, req_id in zip(labels, actions, chnids, dids, req_ids): instance = generate_instance(label, action, chnid=chnid, did=did, req_id=req_id, write_line_id_to_feature=True, shuffle_features=True) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance_for_filter_by_bytes_value(mock_batch_num) logging.info('file_name: %s', file_name) def parser(tensor: tf.Tensor): return parse_examples(tensor, sparse_features=list(features.keys()), dense_features=['label', 'req_id'], dense_feature_shapes=[5, 1], dense_feature_types=[tf.float32, tf.string], extra_features=['uid', 'req_time', 'did'], extra_feature_shapes=[1, 1, 1]) config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True def filter_fn(ts): return filter_by_value(ts, field_name='req_id', op='endswith', variant_type='example', operand=['kjhfjh', 'huggfyfi']) def 
feature_filter_fn(ts): return filter_by_feature_value(ts, field_name='req_id', op='endswith', operand=['kjhfjh', 'huggfyfi'], field_type='bytes') with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.EXAMPLE) dataset_filter = dataset.filter(filter_fn) dataset_feature_filter = dataset.filter(feature_filter_fn) batch_size = 4 dataset_filter = dataset_filter.batch(batch_size, drop_remainder=False).map(parser) dataset_feature_filter = dataset_feature_filter.batch(batch_size, drop_remainder=False).map(parser) num_parallel_calls = 4 dataset_feature_filter_parallel = dataset.map(map_func=lambda x: x, num_parallel_calls=num_parallel_calls) \ .filter(feature_filter_fn) \ .batch(100, drop_remainder=False).map(parser) try: it = tf.compat.v1.data.make_one_shot_iterator(dataset_filter) element = it.get_next() result = sess.run(element) self.assertAllEqual(len(result['req_id']), 4) self.assertAllEqual(result['req_id'], [[b'abckjhfjh'], [b'kjhfjh'], [b'abckjhfjh'], [b'kjhfjh']]) it = tf.compat.v1.data.make_one_shot_iterator(dataset_feature_filter) element = it.get_next() result = sess.run(element) self.assertAllEqual(len(result['req_id']), 4) self.assertAllEqual(result['req_id'], [[b'abckjhfjh'], [b'kjhfjh'], [b'abckjhfjh'], [b'kjhfjh']]) # test for parallelism ("cached_feature_index" in kernel impl) it = tf.compat.v1.data.make_one_shot_iterator(dataset_feature_filter_parallel) element = it.get_next() result = sess.run(element) self.assertAllEqual(len(result['req_id']), 100) self.assertAllEqual(result['req_id'], [[b'abckjhfjh'], [b'kjhfjh']] * 50) except tf.errors.OutOfRangeError: self.assertTrue(False) os.remove(file_name) def test_filter_by_float_value(self): mock_batch_num = 200 def mock_instance_for_filter_by_float_value(batch_num: int = 200): tmpfile = tempfile.mkstemp()[1] labels = [[1], [2], [3], []] actions 
= [[], [], [], []] req_ids = ['abckjhfjh', 'kjhfjh', 'huggfyfixyz', 'mbzc'] video_play_times = [1.0, 2.0, 3.0, 4.0] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for label, action, req_id, video_play_time in zip(labels, actions, req_ids, video_play_times): instance = generate_instance(label, action, req_id=req_id, video_play_time=video_play_time, write_line_id_to_feature=True, shuffle_features=True) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance_for_filter_by_float_value(mock_batch_num) logging.info('file_name: %s', file_name) def parser(tensor: tf.Tensor): return parse_examples(tensor, sparse_features=list(features.keys()), dense_features=['label', 'req_id'], dense_feature_shapes=[5, 1], dense_feature_types=[tf.float32, tf.string], extra_features=['uid', 'req_time', 'did'], extra_feature_shapes=[1, 1, 1]) config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True def feature_filter_fn(ts): return filter_by_feature_value(ts, field_name='video_play_time', op='gt', operand=2.5, field_type='float') with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.EXAMPLE) dataset_feature_filter = dataset.filter(feature_filter_fn) batch_size = 4 dataset_feature_filter = dataset_feature_filter.batch(batch_size, drop_remainder=False).map(parser) try: it = tf.compat.v1.data.make_one_shot_iterator(dataset_feature_filter) element = it.get_next() result = sess.run(element) self.assertAllEqual(len(result['req_id']), 4) self.assertAllEqual(result['req_id'], [[b'huggfyfixyz'], [b'mbzc'], [b'huggfyfixyz'], [b'mbzc']]) except tf.errors.OutOfRangeError: self.assertTrue(False) os.remove(file_name) def test_filter_by_value_not_in(self): mock_batch_num = 1 def mock_instance_for_filter_by_value(batch_num: int = 200): tmpfile = 
tempfile.mkstemp()[1] labels = [[1], [2], [3], [], []] actions = [[], [], [], [], []] chnids = [10, 20, 30, 40, 666] dids = ['hello', 'world', 'excluded', '300', '400'] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for label, action, chnid, did in zip(labels, actions, chnids, dids): instance = generate_instance(label, action, chnid, did, write_line_id_to_feature=True) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance_for_filter_by_value(mock_batch_num) logging.info('file_name: %s', file_name) # generate FilterValues serialized files tmp_filter_values_file_string = tempfile.mkstemp()[1] with tf.io.gfile.GFile(tmp_filter_values_file_string, 'w') as f: filter_values = example_pb2.FilterValues() filter_values.bytes_list.value.extend([b'hello', b'world', b'excluded']) f.write(filter_values.SerializeToString()) tmp_filter_values_file_int64 = tempfile.mkstemp()[1] with tf.io.gfile.GFile(tmp_filter_values_file_int64, 'w') as f: filter_values = example_pb2.FilterValues() filter_values.int64_list.value.extend([20, 30, 666]) f.write(filter_values.SerializeToString()) def parser(tensor: tf.Tensor): return parse_examples(tensor, sparse_features=list(features.keys()), dense_features=['label'], dense_feature_shapes=[5], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'did'], extra_feature_shapes=[1, 1, 1]) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset_base = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.EXAMPLE) dataset_filter_by_list = dataset_base.filter( lambda variant: filter_by_value(variant, field_name='did', op='not-in', operand=['hello', 'world'], variant_type='example')) dataset_filter_by_file_string = dataset_base.filter( lambda variant: 
filter_by_value(variant, field_name='did', op='not-in', operand=None, operand_filepath= tmp_filter_values_file_string, variant_type='example')) dataset_filter_by_file_int64 = dataset_base.filter( lambda variant: filter_by_value(variant, field_name='chnid', op='in', operand=None, operand_filepath= tmp_filter_values_file_int64, variant_type='example')) dataset_feature_filter_by_list = dataset_base.filter( lambda variant: filter_by_feature_value(variant, field_name='did', op='not-in', operand=['hello', 'world'], field_type='bytes')) dataset_feature_filter_by_file_string = dataset_base.filter( lambda variant: filter_by_feature_value(variant, field_name='did', op='not-in', operand=None, operand_filepath= tmp_filter_values_file_string, field_type='bytes')) dataset_feature_filter_by_file_int64 = dataset_base.filter( lambda variant: filter_by_feature_value(variant, field_name='chnid', op='in', operand=None, operand_filepath= tmp_filter_values_file_int64, field_type='int64')) batch_size = 5 dataset_filter_by_list = dataset_filter_by_list.batch( batch_size, drop_remainder=False).map(parser) dataset_filter_by_file_string = dataset_filter_by_file_string.batch( batch_size, drop_remainder=False).map(parser) dataset_filter_by_file_int64 = dataset_filter_by_file_int64.batch( batch_size, drop_remainder=False).map(parser) dataset_feature_filter_by_list = dataset_feature_filter_by_list.batch( batch_size, drop_remainder=False).map(parser) dataset_feature_filter_by_file_string = dataset_feature_filter_by_file_string.batch( batch_size, drop_remainder=False).map(parser) dataset_feature_filter_by_file_int64 = dataset_feature_filter_by_file_int64.batch( batch_size, drop_remainder=False).map(parser) try: # test for filter by not-in list it = tf.compat.v1.data.make_one_shot_iterator(dataset_filter_by_list) element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['did']), 3) self.assertAllEqual(element_result['did'], [[b'excluded'], [b'300'], 
[b'400']]) it = tf.compat.v1.data.make_one_shot_iterator(dataset_feature_filter_by_list) element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['did']), 3) self.assertAllEqual(element_result['did'], [[b'excluded'], [b'300'], [b'400']]) # test for filter by not-in file it = tf.compat.v1.data.make_one_shot_iterator( dataset_filter_by_file_string) element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['did']), 2) self.assertAllEqual(element_result['did'], [[b'300'], [b'400']]) it = tf.compat.v1.data.make_one_shot_iterator( dataset_feature_filter_by_file_string) element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['did']), 2) self.assertAllEqual(element_result['did'], [[b'300'], [b'400']]) # test for filter by in file it = tf.compat.v1.data.make_one_shot_iterator( dataset_filter_by_file_int64) element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['did']), 3) self.assertAllEqual(element_result['did'], [[b'world'], [b'excluded'], [b'400']]) it = tf.compat.v1.data.make_one_shot_iterator( dataset_feature_filter_by_file_int64) element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['did']), 3) self.assertAllEqual(element_result['did'], [[b'world'], [b'excluded'], [b'400']]) except tf.errors.OutOfRangeError: self.assertTrue(False) os.remove(file_name) os.remove(tmp_filter_values_file_string) os.remove(tmp_filter_values_file_int64) def test_filter_by_value_all(self): mock_batch_num = 1 def mock_instance_for_filter_by_value(batch_num: int = 200): tmpfile = tempfile.mkstemp()[1] labels = [[1], [2], [3], [], []] actions = [[], [], [], [], []] multi_chnids = [[10], [20, 30], [20, 30, 666], [40], [666]] dids = ['hello', 'world', 'excluded', '300', '400'] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for label, action, chnids, did in zip(labels, 
actions, multi_chnids, dids): instance = generate_instance(label, action, chnid=None, did=did, chnids=chnids, write_line_id_to_feature=True) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance_for_filter_by_value(mock_batch_num) logging.info('file_name: %s', file_name) def parser(tensor: tf.Tensor): return parse_examples(tensor, sparse_features=list(features.keys()), dense_features=['label'], dense_feature_shapes=[5], dense_feature_types=[tf.float32], extra_features=['uid', 'req_time', 'did'], extra_feature_shapes=[1, 1, 1]) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset_base = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.EXAMPLE) dataset_feature_filter_by_file_int64 = dataset_base.filter( lambda variant: filter_by_feature_value(variant, field_name='chnids', op='all', operand=[20, 30, 666], operand_filepath=None, field_type='int64')) batch_size = 5 dataset_feature_filter_by_file_int64 = dataset_feature_filter_by_file_int64.batch( batch_size, drop_remainder=False).map(parser) try: it = tf.compat.v1.data.make_one_shot_iterator( dataset_feature_filter_by_file_int64) element = it.get_next() element_result = sess.run(element) self.assertAllEqual(len(element_result['did']), 1) self.assertAllEqual(element_result['did'], [[b'excluded']]) except tf.errors.OutOfRangeError: self.assertTrue(False) os.remove(file_name) def test_map_id(self): inputs = tf.constant([123, 456, 789, 912], dtype=tf.int32) map_dict = {123: 0, 456: 1, 789: 2} config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: out_ts = map_id(tensor=inputs, map_dict=map_dict) out = sess.run(out_ts) self.assertListEqual(list(out), [0, 1, 2, 
-1]) def test_filter_by_fids(self): mock_batch_num = 1 batch_size = 4 def mock_instance(batch_num: int = 200): tmpfile = tempfile.mkstemp()[1] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for i in range(batch_size + 1): instance = generate_instance( [], [], fid_v1_list=[get_fid_v1(2, i), get_fid_v1(3, i)] if i > 0 else [get_fid_v1(2, i)]) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance(mock_batch_num) logging.info('file_name: %s', file_name) def parser(tensor: tf.Tensor): return parse_instances(tensor, fidv1_features=[2, 3]) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE) dataset = dataset.filter(lambda variant: filter_by_fids( variant, select_slots=[2, 3], variant_type='instance')) dataset = dataset.batch(batch_size, drop_remainder=False).map(parser) it = tf.compat.v1.data.make_one_shot_iterator(dataset) try: element = it.get_next() element_result = sess.run(element) self.assertAllEqual(element_result['slot_2'].values, [get_fid_v1(2, i + 1) for i in range(batch_size)]) self.assertAllEqual(element_result['slot_3'].values, [get_fid_v1(3, i + 1) for i in range(batch_size)]) except tf.errors.OutOfRangeError: self.assertTrue(False) os.remove(file_name) def test_multi_label_gen(self): mock_batch_num = 1 head_to_idx = {'ios': 3, 'wp': 1, 'android': 4, 'other': 0} def mock_instance_for_multi_label_gen(batch_num: int = 10): tmpfile = tempfile.mkstemp()[1] labels = [[1], [2], [3], [1]] actions = [[1, 2], [3], [2], [1]] chnids = [0, 100, 200, 300] device_types = ['ios', 'wp', 'android', 'ios'] with io.open(tmpfile, 'wb') as writer: for _ in range(batch_num): for label, action, chnid, device_type in 
zip(labels, actions, chnids, device_types): instance = generate_instance(label, action, chnid, device_type=device_type) write_instance_into_file(writer, instance) return tmpfile file_name = mock_instance_for_multi_label_gen(mock_batch_num) logging.info('file_name: %s', file_name) def parser(tensor: tf.Tensor): return parse_instances(tensor, dense_features=['label'], dense_feature_shapes=[5], dense_feature_types=[tf.float32], extra_features=[ 'uid', 'req_time', 'item_id', 'actions', 'device_type' ], extra_feature_shapes=[1, 1, 1, 3, 1]) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: dataset = PBDataset(file_name=file_name, lagrangex_header=False, has_sort_id=True, kafka_dump=False, kafka_dump_prefix=False, input_pb_type=PbType.INSTANCE, output_pb_type=PbType.INSTANCE) dataset = dataset.map( lambda variant: multi_label_gen(variant, head_to_index=head_to_idx, head_field='device_type', use_origin_label=False, pos_actions=[3, 2], neg_actions=[1], action_priority='4,3,2,1,0', variant_type='instance')) batch_size = 4 dataset = dataset.batch(batch_size, drop_remainder=False).map(parser) it = tf.compat.v1.data.make_one_shot_iterator(dataset) try: element = it.get_next() element_result = sess.run(element) self.assertAllClose( element_result['label'], [[ -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, 1.0000000e+00, -3.4028235e+38 ], [ -3.4028235e+38, 1.0000000e+00, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38 ], [ -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, 1.0000000e+00 ], [ -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, 0.0000000e+00, -3.4028235e+38 ]]) except tf.errors.OutOfRangeError: self.assertTrue(False) os.remove(file_name) def test_string_to_variant(self): insts = [] has_header, lg_header_flag = True, False sort_id, kafka_dump, kafka_dump_prefix = True, False, True for i in range(10): inst = 
proto_parser_pb2.Instance() inst.fid.extend([i for i in range(1, 20)]) inst.line_id.chnid = 1 inst_str = inst.SerializeToString() if lg_header_flag: header = lg_header(None) else: header = sort_header(sort_id, kafka_dump, kafka_dump_prefix) if i == 3: inst_str = b'' if has_header: data = struct.pack(f'<{len(header)}sQ{len(inst_str)}s', header, len(inst_str), inst_str) else: data = inst_str insts.append(data) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True ips = tf.constant(value=insts, dtype=tf.string, shape=(10,), name='insts') ops = string_to_variant(ips, variant_type='instance', has_header=has_header, lagrangex_header=lg_header_flag, has_sort_id=sort_id, kafka_dump=kafka_dump, kafka_dump_prefix=kafka_dump_prefix, chnids=[1, 2], datasources=["1", "2"], default_datasource='3') zeros = variant_to_zeros(ops) with self.session(config=config) as sess: element_result = sess.run(zeros) self.assertAllEqual(ips.shape, ops.shape) def test_has_variant(self): inst = proto_parser_pb2.Instance() inst.fid.extend([i for i in range(1, 20)]) inst_str = inst.SerializeToString() data = struct.pack(f' v2 it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() for _ in range(10): try: element_result = sess.run(element) for name, (inplace, slot) in ss_meta.items(): if not inplace: ragged_tensor = element_result[f'{name}_share'] for value in ragged_tensor.values: self.assertEqual(value >> 48, shared_slot) ragged_tensor = element_result[name] for value in ragged_tensor.values: self.assertNotEqual(value >> 48, shared_slot) else: ragged_tensor = element_result[name] for value in ragged_tensor.values: self.assertEqual(value >> 48, shared_slot) except tf.errors.OutOfRangeError: break def test_gen_fid_mask_int64(self): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: ragged: tf.RaggedTensor = 
tf.ragged.constant([[1, 2, 3], [3], [], [4, 5, 6]], dtype=tf.int64) mask_ts = gen_fid_mask(ragged, 3) mask = sess.run(mask_ts) exp_res = [1., 1., 0., 0.] self.assertListEqual(list(mask), exp_res) def test_gen_fid_mask_int32(self): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True with self.session(config=config) as sess: ragged: tf.RaggedTensor = tf.ragged.constant([[1, 2, 3], [3], [], [4, 5, 6]], dtype=tf.int64, row_splits_dtype=tf.int32) mask_ts = gen_fid_mask(ragged, 3) mask = sess.run(mask_ts) exp_res = [1., 1., 0., 0.] self.assertListEqual(list(mask), exp_res) def test_negative_sample_with_positive_actions(self): neg_counter = 0 filt_counter = 0 neg_counter_mismatch = 0 filt_counter_mismatch = 0 for i in range(1000): inst = proto_parser_pb2.Instance() if i % 11 == 0: inst.label.append(1) else: inst.label.append(0) if i % 5 == 0: # match action=2, drop_rate=0 inst.line_id.actions.extend([1, 2, 4, 5]) elif i % 5 == 1: # match action=3, drop_rate=1 inst.line_id.actions.extend([1, 3, 5]) elif i % 5 == 2: # match action=5, drop_rate=0.22 inst.line_id.actions.extend([5, 6]) elif i % 5 == 3: # match priority, but not in per_action_drop_rate inst.line_id.actions.extend([6]) elif i % 5 == 4: # mismatch inst.line_id.actions.extend([10, 11, 12]) inst_str = inst.SerializeToString() data = struct.pack(f" 0: cur_global_step = run_context.session.run(self._global_step_tensor) if cur_global_step > self._last_global_step + self._save_steps: logging.info("after_run start to save item_pool at step {}".format( str(cur_global_step))) run_context.session.run( self._save_op, feed_dict={self._save_global_step: cur_global_step}) self._last_global_step = cur_global_step def end(self, session): # pylint: disable=unused-argument if self._mode != tf.estimator.ModeKeys.TRAIN: return if self._save_op is not None: cur_global_step = session.run(self._global_step_tensor) if cur_global_step > self._last_global_step: logging.info("session_end 
start to save item_pool at step {}".format( str(cur_global_step))) session.run(self._save_op, feed_dict={self._save_global_step: cur_global_step}) self._last_global_step = cur_global_step ================================================ FILE: monolith/native_training/data/item_pool_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import logging import os import getpass import tensorflow as tf from tensorflow.python.framework import load_library from monolith.utils import get_libops_path from monolith.native_training.data.feature_utils import create_item_pool, \ save_item_pool, restore_item_pool, item_pool_random_fill, item_pool_check class ItemPoolTest(tf.test.TestCase): @classmethod def setUpClass(cls): cls.model_path = f"{os.environ.get('HOME')}/{getpass.getuser()}/tmp/monolith/data/test" global_step = tf.constant(1, dtype=tf.int64) pool = create_item_pool(start_num=20, max_item_num_per_channel=100, shared_name='first') pool = item_pool_random_fill(pool) pool = save_item_pool(pool, model_path=cls.model_path, global_step=global_step, nshards=2) def test_create_item_pool(self): global_step = tf.constant(1, dtype=tf.int64) pool = create_item_pool(start_num=20, max_item_num_per_channel=100, shared_name='second') pool = restore_item_pool(pool, model_path=self.model_path, global_step=global_step, nshards=2) pool = item_pool_check(pool, global_step=global_step, 
model_path=self.model_path, nshards=2) logging.info(f"model_path is {self.model_path}") if __name__ == "__main__": tf.test.main() ================================================ FILE: monolith/native_training/data/kafka_dataset_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import time import getpass from absl import flags, app from random import choice from struct import pack from absl.testing import parameterized from kafka import KafkaProducer import random import tensorflow as tf from monolith.native_training.data.parsers import parse_instances, parse_examples from monolith.native_training.model_export.data_gen_utils import gen_example, gen_instance, gen_example_batch, FeatureMeta from monolith.native_training.data.datasets import KafkaDataset, PbType from monolith.native_training.data.feature_utils import add_label, filter_by_label # flags.DEFINE_string('feature_list', None, 'string, feature_list') flags.DEFINE_bool('lagrangex_header', False, 'bool, lagrangex_header') flags.DEFINE_bool('sort_id', False, 'bool, sort_id') flags.DEFINE_bool('kafka_dump', False, 'bool, kafka_dump') flags.DEFINE_bool('kafka_dump_prefix', False, 'bool, kafka_dump_prefix') flags.DEFINE_string('topic', 'test1', 'string, topic') flags.DEFINE_string('group_id', None, 'string, group_id') flags.DEFINE_string( 'kafka_servers', 
'kafka-cnaittauujjoe7a9.kafka.volces.com:9492,kafka-cnaittauujjoe7a9.kafka.volces.com:9493,kafka-cnaittauujjoe7a9.kafka.volces.com:9494', 'string, kafka_servers') flags.DEFINE_bool('data_gen', True, 'bool, data_gen') flags.DEFINE_integer('num_batch', 3, 'bool, num_batch') other_meta = "security.protocol=sasl_ssl,enable.ssl.certificate.verification=0,sasl.mechanisms=SCRAM-SHA-256,sasl.username=hupu_stream_test1_user1,sasl.password=hupu_stream_test1_user1" FLAGS = flags.FLAGS BATCH_SIZE = 8 USE_CLICK_HEAD = False VALID_FNAMES = [ 1, 2, 3, 4, 5, 6, 7, 8, 81, 82, 83, 84, 86, 87, 88, 89, 92, 93, 110, 115, 205, 208, 209, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 526, 527, 528, 529, 530, 531, 532, 533, 534, 536, 537, 538, 540, 542, 543, 544, 549, 562, 564, 565, 567, 568, 569, 573, 576, 577, 700, 701, 707, 708, 709, 710, 711, 712, 719, 720, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 828, 829, 830, 832, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 903, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 918, 924, 925, 926, 927, 928, 929, 930, 932, 933, 934, 935, 937, 938, 939, 940, 941, 942, 944, 946, 947, 948, 949, 950, 951, 952, 954, 955, 956, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1022 ] FNAME_TO_SLOT = { f"fc_slot_{slot}": slot for i, slot in enumerate(range(10000, 10002)) } VALID_SLOTS_V2_NAMES = 
list(sorted(set(FNAME_TO_SLOT))) def start_producer(input_type): FLAGS.lagrangex_header = False FLAGS.sort_id = False FLAGS.kafka_dump = False if True or FLAGS.data_gen: producer = KafkaProducer(bootstrap_servers=FLAGS.kafka_servers, security_protocol="SASL_SSL", sasl_mechanism="SCRAM-SHA-256", sasl_plain_username="hupu_stream_test1_user1", sasl_plain_password="hupu_stream_test1_user1") dense_features = [FeatureMeta(name='label', shape=4, dtype=tf.float32)] extra_features = [ FeatureMeta(name='req_time', shape=1), FeatureMeta(name='uid', shape=1), FeatureMeta(name='sample_rate', shape=1) ] actions = [-7, -9, 75, -103, 74, 101, 102, -41] time.sleep(10) for i in range(FLAGS.num_batch): all_len = 0 if input_type == PbType.EXAMPLEBATCH: inst = gen_example_batch(sparse_features=VALID_SLOTS_V2_NAMES, dense_features=dense_features, extra_features=extra_features, actions=actions, batch_size=BATCH_SIZE) inst_str = inst.SerializeToString() fmt = f' #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/internal/relational_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; class AddActionOp : public OpKernel { public: explicit AddActionOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("field_name", &field_name_)); OP_REQUIRES_OK(context, context->GetAttr("op", &op_)); OP_REQUIRES_OK(context, 
context->GetAttr("float_operand", &float_operand_)); OP_REQUIRES_OK(context, context->GetAttr("int_operand", &int_operand_)); OP_REQUIRES_OK(context, context->GetAttr("string_operand", &string_operand_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); OP_REQUIRES_OK(context, context->GetAttr("actions", &actions_)); uint_operand_.insert(uint_operand_.end(), int_operand_.begin(), int_operand_.end()); if (!internal::VALID_OPS.count(op_)) { LOG(FATAL) << absl::StrFormat( "Invalid op: %s, please choose one from [%s]", op_, absl::StrJoin(internal::VALID_OPS, ", ")); } if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } if (actions_.empty()) { LOG(FATAL) << "Please specify 'actions' to add!"; } if (op_ == internal::IN || op_ == internal::NOT_IN) { float_operand_set_.insert(float_operand_.begin(), float_operand_.end()); int_operand_set_.insert(int_operand_.begin(), int_operand_.end()); uint_operand_set_.insert(uint_operand_.begin(), uint_operand_.end()); string_operand_set_.insert(string_operand_.begin(), string_operand_.end()); } } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); bool is_instance = variant_type_ == "instance"; if (is_instance) { Instance instance; instance.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(instance); } else { Example example; example.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(example); } const google::protobuf::Descriptor *descriptor = LineId::GetDescriptor(); const google::protobuf::Reflection *reflection = LineId::GetReflection(); const google::protobuf::FieldDescriptor *field = descriptor->FindFieldByName(field_name_); if (field == nullptr || field->is_repeated()) { 
return; } LineId &line_id = *GetLineId(output_tensor, is_instance); bool to_add_action = false; switch (field->cpp_type()) { case google::protobuf::FieldDescriptor::CppType::CPPTYPE_FLOAT: { float value = reflection->GetFloat(line_id, field); to_add_action = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, float_operand_) : internal::contains(op_, value, float_operand_set_); break; } case google::protobuf::FieldDescriptor::CppType::CPPTYPE_DOUBLE: { double value = reflection->GetDouble(line_id, field); to_add_action = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, float_operand_) : internal::contains(op_, value, float_operand_set_); break; } case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT32: { int64 value = reflection->GetInt32(line_id, field); to_add_action = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, int_operand_) : internal::contains(op_, value, int_operand_set_); break; } case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT64: { int64 value = reflection->GetInt64(line_id, field); to_add_action = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, int_operand_) : internal::contains(op_, value, int_operand_set_); break; } case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT32: { int64 value = reflection->GetUInt32(line_id, field); to_add_action = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, int_operand_) : internal::contains(op_, value, int_operand_set_); break; } case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT64: { uint64 value = reflection->GetUInt64(line_id, field); to_add_action = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, uint_operand_) : internal::contains(op_, value, uint_operand_set_); break; } case google::protobuf::FieldDescriptor::CppType::CPPTYPE_STRING: { std::string value = reflection->GetString(line_id, field); to_add_action = internal::COMPARE_OPS.count(op_) ? 
internal::compare(op_, value, string_operand_) : internal::contains(op_, value, string_operand_set_); break; } default: to_add_action = false; LOG(INFO) << "dtype is " << field->cpp_type(); break; } if (to_add_action) { for (int32 value : actions_) { line_id.mutable_actions()->Add(value); } } } private: static LineId *GetLineId(Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_line_id(); } else { return output_tensor->scalar()() .get() ->mutable_line_id(); } } std::string field_name_; std::string op_; std::vector float_operand_; std::vector int_operand_; std::vector uint_operand_; std::vector string_operand_; std::unordered_set float_operand_set_; std::unordered_set int_operand_set_; std::unordered_set uint_operand_set_; std::unordered_set string_operand_set_; std::string variant_type_; std::vector actions_; }; namespace { REGISTER_KERNEL_BUILDER(Name("AddAction").Device(DEVICE_CPU), AddActionOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/add_label_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/runtime/common/linalg_utils.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; using ::monolith::common::IsAlmostEqual; class AddLabelOp : public OpKernel { public: explicit AddLabelOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("config", &config_)); OP_REQUIRES_OK(context, context->GetAttr("negative_value", &negative_value_)); OP_REQUIRES_OK(context, context->GetAttr("sample_rate", &sample_rate_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } internal::ParseTaskConfig(config_, &task_configs_); for (size_t i = 0; i < task_configs_.size(); ++i) { LOG(INFO) << absl::StrFormat("Task #%d config: %s", i + 1, task_configs_[i].ToString()); } LOG(INFO) << absl::StrFormat("sample_rate = %.4f", sample_rate_); std::size_t seed = std::chrono::system_clock::now().time_since_epoch().count(); random_generator_.seed(seed); random_neg_sample_ = std::uniform_real_distribution(0.0, 1.0); } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = 
nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); bool is_instance = variant_type_ == "instance"; if (is_instance) { Instance instance; instance.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(instance); } else { Example example; example.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(example); } LineId *line_id = GetLineId(output_tensor, is_instance); auto label = GetLabel(output_tensor, is_instance); std::set actions(line_id->actions().begin(), line_id->actions().end()); if (!label->empty() && label->Get(0) <= 0) { label->Set(0, internal::INVALID_LABEL); } for (const auto &t : task_configs_) { bool has_pos = internal::HasIntersection(actions, t.pos_actions); bool has_neg = internal::HasIntersection(actions, t.neg_actions); if (!t.neg_actions.empty()) { // If there is given neg_actions if (!has_pos && !has_neg) { label->Add(internal::INVALID_LABEL); } else if (has_pos) { // (has_pos && !has_neg) || (has_pos && has_neg) label->Add(internal::POSITIVE_LABEL); } else { // !has_pos && has_neg if (SelectedByNegativeSampling(t)) { label->Add(negative_value_); } else { label->Add(internal::INVALID_LABEL); } } } else { // If there is no given neg_actions if (has_pos) { label->Add(internal::POSITIVE_LABEL); } else { if (SelectedByNegativeSampling(t)) { label->Add(negative_value_); } else { label->Add(internal::INVALID_LABEL); } } } } line_id->set_sample_rate(sample_rate_); } private: bool SelectedByNegativeSampling(const internal::TaskConfig &t) { return IsAlmostEqual(t.sample_rate, 1.0f) || random_neg_sample_(random_generator_) < t.sample_rate; } static LineId *GetLineId(Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_line_id(); } else { return output_tensor->scalar()() .get() ->mutable_line_id(); } } static ::google::protobuf::RepeatedField *GetLabel( Tensor *output_tensor, bool is_instance) { if 
(is_instance) { return output_tensor->scalar()() .get() ->mutable_label(); } else { return output_tensor->scalar()().get()->mutable_label(); } } float negative_value_; float sample_rate_; std::string config_; std::string variant_type_; std::vector task_configs_; std::default_random_engine random_generator_; std::uniform_real_distribution random_neg_sample_; }; namespace { REGISTER_KERNEL_BUILDER(Name("AddLabel").Device(DEVICE_CPU), AddLabelOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/cache_one_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/cache_one_dataset_kernel.h"

#include
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"

namespace tensorflow {
namespace data {
namespace monolith_tf {

// Dataset that forwards every element of `input` and appends one extra scalar
// DT_BOOL component telling whether the input is exhausted after that
// element. To know this it always keeps one element buffered ahead.
class CacheOneDatasetOp::Dataset : public DatasetBase {
 public:
  // Takes a reference on `input` and extends its output signature with the
  // trailing scalar bool component.
  Dataset(OpKernelContext* ctx, const DatasetBase* input)
      : DatasetBase(DatasetContext(ctx)), input_(input) {
    input_->Ref();
    output_dtypes_ = input->output_dtypes();
    output_dtypes_.push_back(DT_BOOL);
    output_shapes_ = input->output_shapes();
    output_shapes_.push_back({});
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique(
        Iterator::Params{this, absl::StrCat(prefix, ":: CacheOneDataset")});
  }

  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }

  const std::vector& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return "This is the customized Dataset: CacheOneDataset";
  }

  // One output element per input element, so the cardinality is unchanged.
  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 private:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, {}, output));
    return Status::OK();
  }

  class Iterator : public DatasetIterator {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator(params) {}

    Status Initialize(IteratorContext* ctx) override {
      absl::MutexLock l(&mu_);
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    // Emits the buffered element followed by a bool flag that is true when
    // the input had no further element, then refills the buffer.
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector* out_tensors,
                           bool* end_of_sequence) override {
      absl::MutexLock l(&mu_);
      if (first_element_) {
        first_element_ = false;
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, &buffered_tensors_, end_of_sequence));
        if (*end_of_sequence) {
          // Special case: the input dataset contains no data at all.
          // Latch eof_ so that any further call keeps reporting the end of
          // the sequence instead of emitting an element that only holds the
          // flag tensor.
          eof_ = true;
          return Status::OK();
        }
      }
      // We ran out of data on a previous call.
      if (eof_) {
        *end_of_sequence = true;
        return Status::OK();
      }
      *out_tensors = std::move(buffered_tensors_);
      buffered_tensors_.clear();
      // Look one element ahead; eof_ tells whether the element being returned
      // now is the last one.
      TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &buffered_tensors_, &eof_));
      Tensor eof_tensor(ctx->allocator({}), DT_BOOL, {});
      eof_tensor.scalar()() = eof_;
      out_tensors->push_back(eof_tensor);
      *end_of_sequence = false;
      return Status::OK();
    }

    std::shared_ptr CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), 1);
    }

    // Checkpointing of the iterator state is not supported.
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      return errors::Unimplemented("Not Implemented");
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      return errors::Unimplemented("Not Implemented");
    }

    absl::Mutex mu_;
    std::unique_ptr input_impl_;
    // The element fetched ahead of the one most recently returned.
    std::vector buffered_tensors_;
    bool first_element_ = true;
    bool eof_ = false;
  };

 private:
  const DatasetBase* const input_;
  DataTypeVector output_dtypes_;
  std::vector output_shapes_;
};

void CacheOneDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  *output = new Dataset(ctx, input);
}

CacheOneDatasetOp::CacheOneDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {}

namespace {
REGISTER_KERNEL_BUILDER(Name("MonolithCacheOneDataset").Device(DEVICE_CPU),
                        CacheOneDatasetOp);
}  // namespace

}  // namespace monolith_tf
}  // namespace data
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/kernels/cache_one_dataset_kernel.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_CACHE_ONE_DATASET_KERNEL_H_
#define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_CACHE_ONE_DATASET_KERNEL_H_

#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {
namespace monolith_tf {

// Unary dataset op kernel that wraps its input dataset in the nested Dataset
// type.
class CacheOneDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit CacheOneDatasetOp(OpKernelConstruction* ctx);

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override;

 private:
  class Dataset;
};

}  // namespace monolith_tf
}  // namespace data
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_DATA_KERNELS_CACHE_ONE_DATASET_KERNEL_H_

================================================
FILE: monolith/native_training/data/kernels/df_resource_kernel.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/df_resource_kernel.h" namespace tensorflow { namespace monolith_tf { using Queue = ::monolith::concurrency::Queue; Status RegisterCancellationCallback(CancellationManager* cancellation_manager, CancelCallback callback, std::function* deregister_fn) { if (cancellation_manager) { CancellationToken token = cancellation_manager->get_cancellation_token(); if (!cancellation_manager->RegisterCallback(token, std::move(callback))) { return errors::Cancelled("Operation was cancelled"); } *deregister_fn = [cancellation_manager, token]() { cancellation_manager->DeregisterCallback(token); }; } else { VLOG(1) << "Cancellation manager is not set. 
Cancellation callback will " "not be registered."; *deregister_fn = []() {}; } return Status::OK(); } class CreateQueueOp : public ResourceOpKernel { public: explicit CreateQueueOp(OpKernelConstruction* c) : ResourceOpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("max_size", &max_size_)); } ~CreateQueueOp() override {} private: Status CreateResource(QueueResource** queue) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { *queue = new QueueResource(max_size_); return Status::OK(); } int max_size_; }; REGISTER_OP("CreateQueue") .Output("handle: resource") .Attr("max_size: int") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("CreateQueue").Device(DEVICE_CPU), CreateQueueOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/df_resource_kernel.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_DF_RESOURCE_KERNEL_H_
#define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_DF_RESOURCE_KERNEL_H_

#include
#include
#include "absl/synchronization/mutex.h"
#include "monolith/native_training/runtime/concurrency/queue.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"

namespace tensorflow {
namespace monolith_tf {

// Kind of protobuf payload carried by a variant tensor.
enum class VariantType { PBInstance, PBExample };

// One queue entry: the tensors of a dataset element together with the
// end-of-sequence flag reported with it.
typedef struct {
  std::vector out_tensors;
  bool end_of_sequence;
} Item;

// A thin wrapper around ::monolith::concurrency::Queue that exposes it as a
// TensorFlow resource (ResourceBase).
// NOTE(review): thread safety is presumably provided by the wrapped queue;
// this class adds no locking of its own - confirm against queue.h.
class QueueResource : public ResourceBase {
 public:
  // Creates the underlying queue with a capacity of `max_size`.
  explicit QueueResource(size_t max_size = 100) {
    queue_ = std::make_unique<::monolith::concurrency::Queue>(max_size);
  }

  ~QueueResource() = default;

  std::string DebugString() const override { return "QueueResource"; }

  // Blocks until `item` has been enqueued. Each attempt waits up to 100 ms
  // inside try_push and then sleeps another 100 ms before retrying.
  void Push(const Item &item) {
    bool pushed = false;
    do {
      pushed = queue_->try_push(item, std::chrono::milliseconds(100));
      if (!pushed) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
      }
    } while (!pushed);
  }

  // Single enqueue attempt; `timeout` is in milliseconds. Returns whether the
  // item was enqueued.
  bool TryPush(const Item &item, int64_t timeout = 100) {
    return queue_->try_push(item, std::chrono::milliseconds(timeout));
  }

  // Blocks until an item is available and returns it. Each attempt waits up
  // to 10 ms inside try_pop and then sleeps another 10 ms before retrying.
  Item Pop() const {
    bool poped = false;
    Item item;
    do {
      poped = queue_->try_pop(item, std::chrono::milliseconds(10));
      if (!poped) {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
      }
    } while (!poped);
    return item;
  }

  // Single dequeue attempt; `timeout` is in milliseconds. On success the
  // dequeued entry is written to `item`.
  bool TryPop(Item &item, int64_t timeout = 100) const {
    return queue_->try_pop(item, std::chrono::milliseconds(timeout));
  }

  bool Empty() const { return queue_->empty(); }

 private:
  // Mutable so that the const Pop/TryPop accessors can still dequeue.
  mutable std::unique_ptr<::monolith::concurrency::Queue> queue_;
};

// Registers `callback` with `cancellation_manager` and stores the matching
// deregistration closure in `*deregister_fn`.
Status RegisterCancellationCallback(CancellationManager *cancellation_manager,
                                    CancelCallback callback,
                                    std::function *deregister_fn);

}  // namespace monolith_tf
}  // namespace tensorflow

#endif //
MONOLITH_NATIVE_TRAINING_DATA_KERNELS_DF_RESOURCE_KERNEL_H_ ================================================ FILE: monolith/native_training/data/kernels/dynamic_match_file_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "monolith/native_training/data/kernels/internal/file_match_split_provider.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/record_reader.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { namespace data { namespace monolith_tf { class DynamicMatchingFilesDatasetOp : public DatasetOpKernel { public: using DatasetOpKernel::DatasetOpKernel; void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 
const Tensor* patterns_t; OP_REQUIRES_OK(ctx, ctx->input("patterns", &patterns_t)); const auto patterns = patterns_t->flat(); size_t num_patterns = static_cast(patterns.size()); std::vector pattern_strs; pattern_strs.reserve(num_patterns); for (size_t i = 0; i < num_patterns; i++) { LOG_EVERY_N(INFO, 100) << "pattern " << patterns(i) << ", num_patterns " << num_patterns; pattern_strs.push_back(patterns(i)); } *output = new Dataset(ctx, std::move(pattern_strs)); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, std::vector patterns) : DatasetBase(DatasetContext(ctx)), patterns_(std::move(patterns)) {} std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return absl::make_unique(Iterator::Params{ this, strings::StrCat(prefix, "::DynamicMatchingFiles")}); } const DataTypeVector& output_dtypes() const override { static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); return *dtypes; } const std::vector& output_shapes() const override { static std::vector* shapes = new std::vector({{}}); return *shapes; } string DebugString() const override { return "DynamicMatchingFilesDatasetOp::Dataset"; } Status InputDatasets( std::vector* inputs) const override { return Status::OK(); } Status CheckExternalState() const override { return Status::OK(); } Status MakeSplitProvider( std::unique_ptr* split_provider) const override { split_provider->reset(new FileMatchSplitProvider(patterns_)); return Status::OK(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* patterns_node = nullptr; TF_RETURN_IF_ERROR(b->AddVector(patterns_, &patterns_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {patterns_node}, output)); return Status::OK(); } private: class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) : DatasetIterator(params) {} Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* 
end_of_sequence) override { mutex_lock l(mu_); if (!split_provider_) { LOG(INFO) << "Begin to get split_provider from ctx!"; split_provider_ = ctx->split_provider(); if (!split_provider_) { LOG(INFO) << "No split_provider in ctx, call MakeSplitProvider!"; std::unique_ptr split_provider; TF_RETURN_IF_ERROR(dataset()->MakeSplitProvider(&split_provider)); split_provider_.reset(split_provider.release()); } else { LOG(INFO) << "Got split_provider from IteratorContext"; } LOG(INFO) << "Get split_provider done!"; } if (end_of_sequence_) { *end_of_sequence = true; out_tensors->clear(); return Status::OK(); } Tensor split; Status s = split_provider_->GetNext(&split, end_of_sequence); if (errors::IsOutOfRange(s)) { out_tensors->clear(); *end_of_sequence = true; end_of_sequence_ = true; LOG(INFO) << s.error_message(); } else if (s.ok()) { *end_of_sequence = false; out_tensors->emplace_back(std::move(split)); } else { return s; } return Status::OK(); } protected: std::shared_ptr CreateNode( IteratorContext* ctx, model::Node::Args args) const override { return model::MakeSourceNode(std::move(args)); } Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { mutex_lock l(mu_); if (split_provider_ != nullptr) { split_provider_->Save( [this](std::string name) { return FullName(prefix(), name); }, writer); } return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); if (!split_provider_) { split_provider_ = ctx->split_provider(); } split_provider_->Restore( [this](std::string name) { return FullName(prefix(), name); }, reader); return Status::OK(); } private: mutex mu_; bool end_of_sequence_ = false; std::shared_ptr split_provider_; }; const std::vector patterns_; }; }; namespace { REGISTER_KERNEL_BUILDER(Name("DynamicMatchingFilesDataset").Device(DEVICE_CPU), DynamicMatchingFilesDatasetOp); } // namespace } // namespace monolith_tf } // namespace data } // namespace tensorflow 
================================================ FILE: monolith/native_training/data/kernels/extract_fid_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "absl/hash/internal/city.h" namespace tensorflow { namespace monolith_tf { class ExtractFidOp : public OpKernel { public: using OpKernel::OpKernel; using ConstFlatSplits = typename TTypes::ConstFlat; explicit ExtractFidOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("slot", &slot_)); slot_ = slot_ << 48; } void Compute(OpKernelContext* context) override { // Grab the input tensor const Tensor& input_tensor = context->input(0); auto input = input_tensor.flat(); // Create an output tensor Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output_flat = output_tensor->flat(); // Set all to its fid. 
const int N = input.size(); int64 bits_left = (1ll << 49) - 1; for (int i = 0; i < N; i++) { uint64_t tmp = input(i); int64 hash_val = absl::hash_internal::CityHash64(reinterpret_cast(&tmp), 8); output_flat(i) = (hash_val & bits_left | slot_); } } private: int64 slot_; }; namespace { REGISTER_KERNEL_BUILDER(Name("ExtractFid").Device(DEVICE_CPU), ExtractFidOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/feature_hash.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include "absl/hash/internal/city.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace monolith_tf { using NamedRawFeature = ::monolith::io::proto::NamedRawFeature; using RawFeature = ::monolith::io::proto::RawFeature; using NamedFeature = ::monolith::io::proto::NamedFeature; using Feature = ::monolith::io::proto::Feature; using Example = ::monolith::io::proto::Example; class FeatureHashOp : public OpKernel { public: using OpKernel::OpKernel; using ConstFlatSplits = typename TTypes::ConstFlat; explicit FeatureHashOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::vector names; OP_REQUIRES_OK(ctx, ctx->GetAttr("names", &names)); names_.insert(names.begin(), names.end()); } void Compute(OpKernelContext *context) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(context, context->input("input", &pb_input)); TTypes::ConstVec pb_variant_tensor = pb_input->vec(); const int batch_size = pb_variant_tensor.dimension(0); // Create an output tensor Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, pb_input->shape(), &output_tensor)); auto output_flat = output_tensor->flat(); for (int i = 0; i < batch_size; ++i) { const Example *in_pb = pb_variant_tensor(i).get(); Example out_pb; out_pb.mutable_line_id()->CopyFrom(in_pb->line_id()); out_pb.mutable_label()->CopyFrom(in_pb->label()); for (size_t i = 0; i < in_pb->named_raw_feature_size(); ++i) { const NamedRawFeature &named_raw_feature = in_pb->named_raw_feature(i); std::string name = named_raw_feature.name(); if (names_.find(name) == names_.end()) continue; NamedFeature *out_nf = out_pb.add_named_feature(); out_nf->set_id(named_raw_feature.id()); out_nf->set_name(name); raw_feature_to_feature(name, named_raw_feature.raw_feature(), 
out_nf->mutable_feature()); } output_flat(i) = std::move(out_pb); } } private: std::unordered_set names_; void raw_feature_to_feature(const std::string &name, const RawFeature &raw_feature, Feature *feature) { for (size_t i = 0; i < raw_feature.feature_size(); ++i) { const auto &rf = raw_feature.feature(i); if (rf.has_float_list()) { feature->mutable_float_list()->MergeFrom(rf.float_list()); } if (rf.has_double_list()) { feature->mutable_double_list()->MergeFrom(rf.double_list()); } if (rf.has_int64_list()) { feature->mutable_int64_list()->MergeFrom(rf.int64_list()); } if (rf.has_bytes_list()) { const auto &bytes_list = rf.bytes_list(); auto *out_list = feature->mutable_fid_v2_list(); for (size_t j = 0; j < bytes_list.value_size(); ++j) { const std::string &value = absl::StrCat(bytes_list.value(j), "-", name); int64 hash_val = absl::hash_internal::CityHash64(value.c_str(), 8); out_list->add_value(hash_val); } } if (rf.has_float_lists()) { feature->mutable_float_lists()->MergeFrom(rf.float_lists()); } if (rf.has_double_lists()) { feature->mutable_double_lists()->MergeFrom(rf.double_lists()); } if (rf.has_int64_lists()) { feature->mutable_int64_lists()->MergeFrom(rf.int64_lists()); } if (rf.has_bytes_lists()) { const auto &bytes_lists = rf.bytes_lists(); for (size_t j = 0; j < bytes_lists.list_size(); ++j) { const auto &bytes_list = bytes_lists.list(j); auto *out_list = feature->mutable_fid_v2_lists()->add_list(); for (size_t k = 0; k < bytes_list.value_size(); ++k) { const std::string &value = absl::StrCat(bytes_list.value(j), "-", name); int64 hash_val = absl::hash_internal::CityHash64(value.c_str(), 8); out_list->add_value(hash_val); } } } } } }; namespace { REGISTER_KERNEL_BUILDER(Name("FeatureHash").Device(DEVICE_CPU), FeatureHashOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.cc 
================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h" namespace tensorflow { namespace monolith_tf { Status FeatureNameMapperTfBridge::New(FeatureNameMapperTfBridge** new_bridge) { auto bridge = core::RefCountPtr( new FeatureNameMapperTfBridge()); bridge->mapper_ = std::make_unique(); *new_bridge = bridge.release(); return Status::OK(); } Status FeatureNameMapperTfBridge::RegisterValidIds( const std::vector>& valid_ids) const { try { if (mapper_->RegisterValidIds(valid_ids)) { return Status::OK(); } else { return errors::InvalidArgument("RegisterValidIds failed!"); } } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_FEATURE_NAME_MAPPER_TF_BRIDGE_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_FEATURE_NAME_MAPPER_TF_BRIDGE_H_ #include #include #include #include #include #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { namespace monolith_tf { // A feature name mapper which can be used in TF runtime. // It captures all potential exceptions and convert them into error. class FeatureNameMapperTfBridge : public ResourceBase { public: static constexpr const char* const kName = "FeatureNameMapper"; ~FeatureNameMapperTfBridge() override = default; static Status New(FeatureNameMapperTfBridge** new_bridge); Status RegisterValidIds( const std::vector>& valid_ids) const; std::string DebugString() const override { return mapper_->DebugString(); } FeatureNameMapper* GetFeatureNameMapper() const { return mapper_.get(); } private: FeatureNameMapperTfBridge() = default; std::unique_ptr mapper_; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_FEATURE_NAME_MAPPER_TF_BRIDGE_H_ ================================================ FILE: monolith/native_training/data/kernels/fill_multi_rank_output_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/hash/internal/city.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; class FillMultiRankOutputOp : public OpKernel { public: explicit FillMultiRankOutputOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("enable_draw_as_rank", &enable_draw_as_rank_)); OP_REQUIRES_OK(context, context->GetAttr("enable_chnid_as_rank", &enable_chnid_as_rank_)); OP_REQUIRES_OK(context, context->GetAttr("enable_lineid_rank_as_rank", &enable_lineid_rank_as_rank_)); if (!(enable_draw_as_rank_ || enable_chnid_as_rank_ || enable_lineid_rank_as_rank_)) { LOG(FATAL) << "At least one of enable_draw_as_rank, enable_chnid_as_rank, " "enable_lineid_rank_as_rank must be set"; } 
OP_REQUIRES_OK(context, context->GetAttr("rank_num", &rank_num_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } } void Compute(OpKernelContext *context) override { /* Parse data fields from input tensor. */ const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); bool is_instance = variant_type_ == "instance"; if (is_instance) { Instance instance; instance.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(instance); } else { Example example; example.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(example); } LineId *line_id = GetLineId(output_tensor, is_instance); auto label = GetLabel(output_tensor, is_instance); /* fill_multi_rank_output() from matrix processor: */ if (enable_draw_as_rank_) { int rank = line_id->is_draw() ? 
1 : 0; label->Add(rank); return; } if (enable_chnid_as_rank_) { int rank = 0, chnid = line_id->chnid(); if (chnid == 0 || chnid == 1) { rank = chnid; } else { rank = 2; } label->Add(rank); return; } if (enable_lineid_rank_as_rank_) { int rank = line_id->rank(); if (rank >= rank_num_) { rank = rank_num_ - 1; } label->Add(rank); return; } } private: static LineId *GetLineId(Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_line_id(); } else { return output_tensor->scalar()() .get() ->mutable_line_id(); } } static ::google::protobuf::RepeatedField *GetLabel( Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_label(); } else { return output_tensor->scalar()().get()->mutable_label(); } } static std::vector GetFids(Tensor *output_tensor, bool is_instance) { std::vector fids; if (is_instance) { auto instance = output_tensor->scalar()().get(); for (uint64_t fid : instance->fid()) { fids.push_back(fid); } } else { auto example = output_tensor->scalar()().get(); for (const auto &named_feature : example->named_feature()) { if (named_feature.feature().has_fid_v1_list()) { for (const auto &fid : named_feature.feature().fid_v1_list().value()) { fids.push_back(fid); } } } } return fids; } bool enable_draw_as_rank_; bool enable_chnid_as_rank_; bool enable_lineid_rank_as_rank_; int rank_num_; std::string variant_type_; }; namespace { REGISTER_KERNEL_BUILDER(Name("FillMultiRankOutput").Device(DEVICE_CPU), FillMultiRankOutputOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/filter_by_label_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "third_party/nlohmann/json.hpp" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; // filter_invalid_conseq_time: class FilterByLabelOp : public OpKernel { public: explicit FilterByLabelOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("label_threshold", &label_threshold_)); OP_REQUIRES_OK(context, context->GetAttr("filter_equal", &filter_equal_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } nlohmann::json j; j["label_threshold"] = label_threshold_; LOG(INFO) << absl::StrFormat("Label threshold: %s", j.dump(2)); } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor 
*output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto valid = output_tensor->scalar(); bool is_instance = variant_type_ == "instance"; auto labels = GetLabels(&input_tensor, is_instance); if (labels.size() < label_threshold_.size()) { LOG_EVERY_N_SEC(ERROR, 60) << absl::StrFormat( "Label size(=%ld) should be >= label_threshold size(=%ld), please " "investigate!", labels.size(), label_threshold_.size()); valid() = false; } else { bool has_valid_label = false; for (size_t i = 0; i < label_threshold_.size(); ++i) { if ((labels.Get(i) > label_threshold_[i]) || (labels.Get(i) == label_threshold_[i] && !filter_equal_)) { has_valid_label = true; break; } } valid() = has_valid_label; } } private: static const ::google::protobuf::RepeatedField &GetLabels( const Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()().get()->label(); } else { return output_tensor->scalar()().get()->label(); } } std::vector label_threshold_; bool filter_equal_; std::string variant_type_; }; namespace { REGISTER_KERNEL_BUILDER(Name("FilterByLabel").Device(DEVICE_CPU), FilterByLabelOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/gen_fid_mask.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace monolith_tf { template class MonolithGenFidMaskOp : public OpKernel { public: using OpKernel::OpKernel; explicit MonolithGenFidMaskOp(OpKernelConstruction *ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("fid", &fid_)); } void Compute(OpKernelContext *context) override { // Grab the input tensor const Tensor *splits, *values; OP_REQUIRES_OK(context, context->input("splits", &splits)); auto splits_flat = splits->flat(); OP_REQUIRES_OK(context, context->input("values", &values)); auto values_flat = values->flat(); // Create an output tensor Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, {splits->NumElements() - 1}, &output_tensor)); auto output_flat = output_tensor->flat(); output_flat.setZero(); for (int i = 1; i < splits->NumElements(); ++i) { int32 start = splits_flat(i - 1); int32 end = splits_flat(i); for (int j = start; j < end; ++j) { if (values_flat(j) == fid_) { output_flat(i - 1) = 1.0; break; } } } } private: int64 fid_; }; namespace { REGISTER_KERNEL_BUILDER( Name("MonolithGenFidMask").Device(DEVICE_CPU).TypeConstraint("T"), MonolithGenFidMaskOp); REGISTER_KERNEL_BUILDER( Name("MonolithGenFidMask").Device(DEVICE_CPU).TypeConstraint("T"), MonolithGenFidMaskOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/instance_reweight_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" #include "monolith/native_training/data/kernels/instance_reweight_dataset_kernel.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "third_party/nlohmann/json.hpp" namespace { const unsigned int NONEXIST_PRIORITY = 2000; const unsigned int UNKNOWN_PRIORITY = 1000; } // namespace namespace tensorflow { namespace data { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using EFeature = ::monolith::io::proto::Feature; using LineId = ::idl::matrix::proto::LineId; using Action = google::protobuf::RepeatedField; // See documentation in ../../ops/dataset_ops.cc for a high-level // description of the following op. 
// Namespace-scope definitions for the static constexpr members declared in
// InstanceReweightDatasetOp.
/* static */ constexpr const char *const
    InstanceReweightDatasetOp::kDatasetType;
/* static */ constexpr const char *const
    InstanceReweightDatasetOp::kInputDataset;
/* static */ constexpr const char *const InstanceReweightDatasetOp::kMethod;
/* static */ constexpr const char *const InstanceReweightDatasetOp::kActions;
/* static */ constexpr const char *const InstanceReweightDatasetOp::kWeights;
/* static */ constexpr const char *const InstanceReweightDatasetOp::kLabels;
/* static */ constexpr const char *const InstanceReweightDatasetOp::kPriority;
/* static */ constexpr const char *const
    InstanceReweightDatasetOp::kVariantType;

// Wraps an upstream iterator and hands out every upstream element a computed
// number of times (its "replicas"). Elements whose replica count is 0 are
// dropped. The first label of an element may be rewritten while the count is
// computed (see CalReplicas).
//
// NOTE(review): the template arguments of the std::vector, std::shared_ptr,
// std::lock_guard, absl::flat_hash_map, scalar() and get() uses below are
// missing in this copy of the file.
class InnerIterator {
 public:
  // Takes ownership of `input_impl`. `actions`, `weights` and `labels` are
  // parallel lists: entry i maps actions[i] to weights[i] (reweight) and to
  // labels[i] (relabel). `priorities` lists actions from most to least
  // important; position k gets priority k + 1.
  // NOTE(review): `weights` and `labels` are indexed with the indices of
  // `actions` without a size check in this constructor.
  InnerIterator(IteratorBase *input_impl, int instance_reweight_method,
                const std::vector &actions,
                const std::vector &weights,
                const std::vector &labels,
                const std::vector &priorities,
                std::string variant_type)
      : instance_reweight_method_(instance_reweight_method),
        variant_type_(std::string(variant_type)) {
    input_impl_.reset(input_impl);
    for (size_t i = 0; i < actions.size(); ++i) {
      reweight_[actions[i]] = weights[i];
      relabel_[actions[i]] = labels[i];
    }

    int idx = 1;
    for (const auto &value : priorities) {
      action_priority_[value] = idx++;
    }

    // NOTE(review): raw owning pointer, released in the destructor.
    tensors_ = new std::vector();
    tensors_->reserve(1);
  }

  ~InnerIterator() { delete tensors_; }

  // Produces the next output element. On success, and when the sequence has
  // not ended, `out_tensors` receives the current element.
  // NOTE(review): `end_of_sequence_` and `tensors_` are read here after
  // NextInternal() has released `mu_`.
  Status GetNext(IteratorContext *ctx, std::vector *out_tensors,
                 bool *end_of_sequence) {
    Status s = NextInternal(ctx);
    *end_of_sequence = end_of_sequence_;
    out_tensors->clear();
    if (s.ok() && !end_of_sequence_) {
      out_tensors->push_back(tensors_->back());
    }
    return s;
  }

 private:
  // Advances the replay state by one emission, under `mu_`.
  //
  // A new upstream element is fetched when nothing is pending
  // (replicas_ == 0) or all replicas of the current element have been handed
  // out (index_ == replicas_). The loop keeps fetching until an element with
  // a non-zero replica count is found or the upstream is exhausted.
  Status NextInternal(IteratorContext *ctx) {
    std::lock_guard lck(mu_);
    while ((replicas_ == 0 || index_ == replicas_) && !end_of_sequence_) {
      tensors_->clear();
      TF_RETURN_IF_ERROR(
          input_impl_->GetNext(ctx, tensors_, &end_of_sequence_));
      if (!end_of_sequence_) {
        if (variant_type_ == "instance") {
          Instance *instance = GetCurrentInstance();
          if (instance->label_size()) {
            // Only the first label takes part in reweighting.
            float *int_label = instance->mutable_label()->Mutable(0);
            replicas_ = CalReplicas(instance->line_id(), int_label);
          } else {
            replicas_ = 0;  // Drop records without labels.
            LOG_EVERY_N_SEC(ERROR, 60) << "label is empty, please investigate!";
          }
        } else if (variant_type_ == "example") {
          Example *example = GetCurrentExample();
          if (example->label_size()) {
            float *int_label = example->mutable_label()->Mutable(0);
            replicas_ = CalReplicas(example->line_id(), int_label);
          } else {
            replicas_ = 0;  // Drop records without labels.
            LOG_EVERY_N_SEC(ERROR, 60) << "label is empty, please investigate!";
          }
        } else {
          return errors::InvalidArgument(
              absl::StrCat(variant_type_, " variant_type is invalid!"));
        }
      } else {
        replicas_ = 0;
        return Status::OK();
      }
      index_ = 0;
    }
    index_++;
    return Status::OK();
  }

  // Returns the Instance stored in the most recently fetched tensor.
  inline Instance *GetCurrentInstance() {
    Variant *variant = &tensors_->back().scalar()();
    return variant->get();
  }

  // Returns the Example stored in the most recently fetched tensor.
  inline Example *GetCurrentExample() {
    Variant *variant = &tensors_->back().scalar()();
    return variant->get();
  }

  // Computes how many times the current record is to be emitted, based on
  // the record's `actions` and `pre_actions` and on the configured
  // priorities, weights and labels. May overwrite `*ins_label` (relabel) or
  // flip its sign.
  int CalReplicas(const LineId &lineid, float *ins_label) {
    // Finds the action with the smallest configured priority number.
    // If the resulting priority is NONEXIST_PRIORITY or UNKNOWN_PRIORITY,
    // `*action` is left unchanged, i.e. it keeps whatever the caller put
    // there.
    auto find_most_prior_action = [&](const Action &actions, int64_t *priority,
                                      int64_t *action) {
      *priority = NONEXIST_PRIORITY;
      if (actions.size() != 0) {
        *priority = UNKNOWN_PRIORITY;
        for (auto &act : actions) {
          auto action_iter = action_priority_.find(act);
          if (action_iter != action_priority_.end() &&
              action_iter->second < *priority) {
            *priority = action_iter->second;
            *action = act;
          }
        }
      }
    };

    // Best priority among the previous actions; `action` is not used.
    auto get_pre_priority = [&]() {
      int64_t priority, action;
      find_most_prior_action(lineid.pre_actions(), &priority, &action);
      return priority;
    };

    // Best priority among the current actions; `action` is not used.
    auto get_cur_priority = [&]() {
      int64_t priority, action;
      find_most_prior_action(lineid.actions(), &priority, &action);
      return priority;
    };

    // Looks up the configured label of the most prior action. Returns false
    // when no action has a configured priority or that action has no label.
    auto get_label = [&](const Action &actions, int64_t *label) {
      int64_t priority, action = 0;
      find_most_prior_action(actions, &priority, &action);
      if (priority != NONEXIST_PRIORITY && priority != UNKNOWN_PRIORITY) {
        auto label_iter = relabel_.find(action);
        if (label_iter != relabel_.end()) {
          *label = label_iter->second;
          return true;
        }
      }
      return false;
    };

    // return true if we can actually get the label
    auto get_pre_label = [&](int64_t *label) {
      return get_label(lineid.pre_actions(), label);
    };

    // Relabels the record when a label is configured for the current
    // actions, then reports the record's (possibly updated) label. Always
    // returns true.
    auto get_cur_label = [&](int64_t *label) {
      if (get_label(lineid.actions(), label)) {
        *ins_label = *label;
      }
      *label = *ins_label;
      return true;
    };

    // Replica count contributed by one action list. It starts at 1 for a
    // non-empty list (0 otherwise) and is combined with the weight of each
    // counted action: added when the method is 1, multiplied otherwise.
    auto get_cnt = [&](const Action &actions) {
      int64_t ins_num = actions.size() != 0 ? 1 : 0;
      int64_t priority, action = 0;
      find_most_prior_action(actions, &priority, &action);
      for (auto act : actions) {
        // Among actions that have a priority, only the most prior one counts.
        if (action_priority_.contains(act) && act != action) {
          continue;
        }
        // set the ins_num if the act need reweight.
        auto reweight_iter = reweight_.find(act);
        if (reweight_iter != reweight_.end()) {
          if (instance_reweight_method_ == 1) {
            ins_num += reweight_iter->second;
          } else {
            ins_num *= reweight_iter->second;
          }
        }
      }
      return ins_num;
    };

    auto get_pre_cnt = [&]() { return get_cnt(lineid.pre_actions()); };
    auto get_cur_cnt = [&]() { return get_cnt(lineid.actions()); };
    auto reverse_label = [&]() { *ins_label = -(*ins_label); };

    // start from here
    int64_t pre_priority = get_pre_priority();
    int64_t cur_priority = get_cur_priority();
    // A smaller number is a higher priority, so this branch is taken when the
    // current actions outrank the previous ones.
    if (pre_priority > cur_priority) {
      int64_t pre_label;  // fast emit label, it's a negative sample
      int64_t cur_label;  // the real label, can be positive or negative
      // Note: the same sample with different label (+1/-1), equal there is no
      // sample
      if (get_cur_label(&cur_label) && get_pre_label(&pre_label)) {
        // for the real sample (the second one) of fast emit
        if (pre_label != cur_label) {
          // for real positive sample
          // pre_cnt(-1) + pre_cnt(1) + cur_cnt(1) => cur_cnt(1)
          return get_pre_cnt() + get_cur_cnt();
        } else {
          // the real negative sample
          auto pre_cnt = get_pre_cnt();
          auto cur_cnt = get_cur_cnt();
          if (pre_cnt > cur_cnt) {
            reverse_label();
            // pre_cnt(-1) + (pre_cnt(1) - cur_cnt(1)) => cur_cnt(-1)
            return pre_cnt - cur_cnt;
          } else {
            // pre_cnt(-1) + (cur_cnt(-1) - pre_cnt(-1)) => cur_cnt(-1)
            return cur_cnt - pre_cnt;
          }
        }
      } else {
        // for fast emit sample (the first one, negative) or non fast emit
        // sample
        return get_cur_cnt();
      }
    } else if (pre_priority == NONEXIST_PRIORITY &&
               cur_priority == NONEXIST_PRIORITY) {
      // Neither list has any action: emit the record once, unchanged.
      return 1;
    } else {
      return 0;
    }
  }

  std::mutex mu_;                 // Taken in NextInternal().
  int index_ = 0;                 // Replicas of the current element emitted.
  int replicas_ = 0;              // Replica count of the current element.
  bool end_of_sequence_ = false;  // Set once the upstream is exhausted.
  std::vector *tensors_ = nullptr;  // Buffer for the current element.
  std::shared_ptr input_impl_ = nullptr;  // Upstream iterator.

  int instance_reweight_method_;  // 1: add weights; otherwise multiply.
  absl::flat_hash_map reweight_;         // action -> weight
  absl::flat_hash_map relabel_;          // action -> label
  absl::flat_hash_map action_priority_;  // action -> priority (1 = highest)
  std::string variant_type_;  // "instance" or "example".
};

// Dataset that applies InnerIterator's reweighting to its input dataset.
// Output dtypes and shapes are those of the input.
//
// NOTE(review): template arguments are missing throughout this class in this
// copy of the file.
class InstanceReweightDatasetOp::Dataset : public DatasetBase {
 public:
  // Keeps copies of all configuration and takes a reference on `input`,
  // which is dropped again in the destructor.
  Dataset(OpKernelContext *ctx, const DatasetBase *input,
          int instance_reweight_method, const std::vector &actions,
          const std::vector &weights, const std::vector &labels,
          const std::vector &priorities, std::string variant_type)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        instance_reweight_method_(instance_reweight_method),
        actions_(actions),
        weights_(weights),
        labels_(labels),
        priorities_(priorities),
        variant_type_(std::move(variant_type)) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  // Creates an iterator whose prefix is `prefix` + "::" + kDatasetType.
  std::unique_ptr MakeIteratorInternal(
      const string &prefix) const override {
    return absl::make_unique(
        Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)});
  }

  const DataTypeVector &output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector &output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return "This is the customized Dataset: InstanceReweight";
  }

  // Reports the single input dataset.
  Status InputDatasets(
      std::vector *inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  // Serializes this dataset as a graph node: the input dataset plus one attr
  // per configuration value.
  Status AsGraphDefInternal(SerializationContext *ctx,
                            DatasetGraphDefBuilder *b,
                            Node **output) const override {
    Node *input_graph_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

    AttrValue method_node;
    b->BuildAttrValue(instance_reweight_method_, &method_node);
    AttrValue actions_node;
    b->BuildAttrValue(actions_, &actions_node);
    AttrValue weights_node;
    b->BuildAttrValue(weights_, &weights_node);
    AttrValue labels_node;
    b->BuildAttrValue(labels_, &labels_node);
    AttrValue priorities_node;
    b->BuildAttrValue(priorities_, &priorities_node);
    AttrValue variant_type_node;
    b->BuildAttrValue(variant_type_, &variant_type_node);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this,                // dataset
                      {input_graph_node},  // inputs
                      {{kMethod, method_node},
                       {kActions, actions_node},
                       {kWeights, weights_node},
                       {kLabels, labels_node},
                       {kPriority, priorities_node},
                       {kVariantType, variant_type_node}},  // attrs
                      output));                             // Node**

    return Status::OK();
  }

 private:
  // tf.data iterator that delegates all work to an InnerIterator.
  class Iterator : public DatasetIterator {
   public:
    // NOTE(review): the parameter below reads `¶ms` in this copy of the
    // file; it looks like a mis-encoded `&params`, since the initializer uses
    // `params`.
    explicit Iterator(const Params ¶ms) : DatasetIterator(params) {}

    // Creates the upstream iterator and wraps it in an InnerIterator.
    // NOTE(review): the InnerIterator is constructed even when MakeIterator()
    // failed; the failure status is only returned afterwards.
    Status Initialize(IteratorContext *ctx) override {
      std::unique_ptr input_impl;
      Status s =
          dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl);
      LOG(INFO) << "Initialize InnerIterator ...";
      iter_ = std::make_unique(
          input_impl.release(), dataset()->instance_reweight_method_,
          dataset()->actions_, dataset()->weights_, dataset()->labels_,
          dataset()->priorities_, dataset()->variant_type_);
      return s;
    }

    Status GetNextInternal(IteratorContext *ctx,
                           std::vector *out_tensors,
                           bool *end_of_sequence) override {
      out_tensors->reserve(1);
      TF_RETURN_IF_ERROR(iter_->GetNext(ctx, out_tensors, end_of_sequence));
      return Status::OK();
    }

   protected:
    std::shared_ptr CreateNode(
        IteratorContext *ctx, model::Node::Args args) const override {
      return model::MakeUnknownRatioNode(std::move(args));
    }

    // Writes no state and returns OK.
    Status SaveInternal(SerializationContext *ctx,
                        IteratorStateWriter *writer) override {
      return Status::OK();
    }

    // Reads no state and returns OK.
    Status RestoreInternal(IteratorContext *ctx,
                           IteratorStateReader *reader) override {
      return Status::OK();
    }

   private:
    std::unique_ptr iter_;
  };

  const DatasetBase *const input_;
  int instance_reweight_method_;
  std::vector actions_;
  std::vector weights_;
  std::vector labels_;
  std::vector priorities_;
  std::string variant_type_;
};

// Reads all attributes and logs them as a single JSON object.
InstanceReweightDatasetOp::InstanceReweightDatasetOp(OpKernelConstruction *ctx)
    : UnaryDatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kMethod, &instance_reweight_method_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kActions, &actions_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kWeights, &weights_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kLabels, &labels_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kPriority, &priorities_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kVariantType, &variant_type_));

  nlohmann::json j;
  j[kMethod] = instance_reweight_method_;
  j[kActions] = actions_;
  j[kWeights] = weights_;
  j[kLabels] = labels_;
  j[kPriority] = priorities_;
  j[kVariantType] = variant_type_;
  LOG(INFO) << j.dump();
}

// Builds the Dataset for `input` from the attributes read in the constructor.
void InstanceReweightDatasetOp::MakeDataset(OpKernelContext *ctx,
                                            DatasetBase *input,
                                            DatasetBase **output) {
  *output = new Dataset(ctx, input, instance_reweight_method_, actions_,
                        weights_, labels_, priorities_, variant_type_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("InstanceReweightDataset").Device(DEVICE_CPU),
                        InstanceReweightDatasetOp);
}  // namespace
}  // namespace monolith_tf
}  // namespace data
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/kernels/instance_reweight_dataset_kernel.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INSTANCE_REWEIGHT_DATASET_KERNEL_H_ #define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INSTANCE_REWEIGHT_DATASET_KERNEL_H_ #include "tensorflow/core/framework/dataset.h" namespace tensorflow { namespace data { namespace monolith_tf { class InstanceReweightDatasetOp : public UnaryDatasetOpKernel { public: static constexpr const char* const kDatasetType = "instance_reweight"; static constexpr const char* const kInputDataset = "input_dataset"; static constexpr const char* const kMethod = "method"; static constexpr const char* const kActions = "actions"; static constexpr const char* const kWeights = "weights"; static constexpr const char* const kLabels = "labels"; static constexpr const char* const kPriority = "priorities"; static constexpr const char* const kVariantType = "variant_type"; explicit InstanceReweightDatasetOp(OpKernelConstruction* ctx); protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override; private: class Dataset; int instance_reweight_method_; std::string variant_type_; std::vector actions_; std::vector weights_; std::vector labels_; std::vector priorities_; }; } // namespace monolith_tf } // namespace data } // namespace tensorflow #endif MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INSTANCE_REWEIGHT_DATASET_KERNEL_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") 
# NOTE(review): `tf_cc_binary` is loaded but no target below uses it.
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")

# Every target here is visible to all packages under
# //monolith/native_training/data.
package(default_visibility = ["//monolith/native_training/data:__subpackages__"])

# Header-only library.
cc_library(
    name = "relational_utils",
    srcs = [],
    hdrs = ["relational_utils.h"],
    deps = [
        "@com_google_glog//:glog",
    ],
)

cc_test(
    name = "relational_utils_test",
    srcs = ["relational_utils_test.cc"],
    deps = [
        ":relational_utils",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "label_utils",
    srcs = ["label_utils.cc"],
    hdrs = ["label_utils.h"],
    deps = [
        "//monolith/native_training/data:data_op_config_cc_proto",
        "//third_party/nlohmann:json",
        "@com_google_absl//absl/strings",
        "@com_google_glog//:glog",
    ],
)

cc_test(
    name = "label_utils_test",
    srcs = ["label_utils_test.cc"],
    deps = [
        ":label_utils",
        "@com_google_googletest//:gtest_main",
    ],
)

# NOTE(review): cache_mgr.h is listed in both `srcs` and `hdrs`.
cc_library(
    name = "cache_mgr",
    srcs = [
        "cache_mgr.cc",
        "cache_mgr.h",
    ],
    hdrs = ["cache_mgr.h"],
    deps = [
        "//monolith/native_training/data/training_instance:data_reader",
        "//third_party/nlohmann:json",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)

tf_cc_test(
    name = "cache_mgr_test",
    srcs = ["cache_mgr_test.cc"],
    deps = [
        ":cache_mgr",
        "@com_google_googletest//:gtest_main",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# NOTE(review): datasource_utils.h is listed in both `srcs` and `hdrs`.
cc_library(
    name = "datasource_utils",
    srcs = [
        "datasource_utils.cc",
        "datasource_utils.h",
    ],
    hdrs = ["datasource_utils.h"],
)

tf_cc_test(
    name = "datasource_utils_test",
    srcs = ["datasource_utils_test.cc"],
    deps = [
        ":datasource_utils",
        "@com_google_googletest//:gtest_main",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "file_match_split_provider",
    srcs = ["file_match_split_provider.cc"],
    hdrs = ["file_match_split_provider.h"],
    deps = [
        "//monolith/native_training/runtime/concurrency:queue",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@com_google_glog//:glog",
    ],
)

tf_cc_test(
    name = "file_match_split_provider_test",
    srcs = ["file_match_split_provider_test.cc"],
    deps = [
        ":file_match_split_provider",
        "@com_google_googletest//:gtest_main",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# NOTE(review): all four headers are listed under `srcs` and the target has no
# `hdrs` -- confirm this is intended for targets that depend on it.
cc_library(
    name = "parquet_example_reader",
    srcs = [
        "arrow_random_access_file.h",
        "sized_random_access_file.h",
        "parquet_column_buffer.h",
        "parquet_example_reader.h",
    ],
    deps = [
        "@arrow",
    ]
)

# Header-only library.
cc_library(
    name = "uniq_hashtable",
    hdrs = ["uniq_hashtable.h"],
    deps = [
        "@com_google_absl//absl/base:core_headers",
        "@com_google_glog//:glog",
    ],
)

tf_cc_test(
    name = "uniq_hashtable_test",
    srcs = ["uniq_hashtable_test.cc"],
    deps = [
        ":uniq_hashtable",
        "@com_google_googletest//:gtest_main",
        "@org_tensorflow//tensorflow/core:test",
    ]
)

cc_library(
    name = "value_filter_by_line_id",
    hdrs = ["value_filter_by_line_id.h"],
    srcs = ["value_filter_by_line_id.cc"],
    deps = [
        ":relational_utils",
        "//idl:example_cc_proto",
        "//third_party/nlohmann:json",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@org_tensorflow//tensorflow/core/platform:env",
    ],
)

cc_library(
    name = "value_filter_by_feature",
    hdrs = ["value_filter_by_feature.h"],
    srcs = ["value_filter_by_feature.cc"],
    deps = [
        ":relational_utils",
        "//idl:example_cc_proto",
        "//third_party/nlohmann:json",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@org_tensorflow//tensorflow/core/platform:env",
    ],
)

================================================
FILE: monolith/native_training/data/kernels/internal/arrow_random_access_file.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_IO_ARROW_KERNELS_H_ #define TENSORFLOW_IO_ARROW_KERNELS_H_ #include "arrow/buffer.h" #include "arrow/io/api.h" #include "arrow/type.h" #include "parquet/windows_compatibility.h" #include "tensorflow/core/framework/op_kernel.h" // #include "tensorflow_io/core/kernels/io_stream.h" namespace tensorflow { class RandomAccessFile; namespace data { // NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap // with another PR. 
Will remove duplicate once PR merged class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile { public: explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile* file, int64 size) : file_(file), size_(size), position_(0) {} ~ArrowRandomAccessFile() {} arrow::Status Close() override { return arrow::Status::OK(); } bool closed() const override { return false; } arrow::Result Tell() const override { return position_; } arrow::Status Seek(int64_t position) override { return arrow::Status::NotImplemented("Seek"); } arrow::Result Read(int64_t nbytes, void* out) override { StringPiece result; Status status = file_->Read(position_, nbytes, &result, reinterpret_cast(out)); if (!(status.ok() || errors::IsOutOfRange(status))) { return arrow::Status::IOError(status.error_message()); } position_ += result.size(); return result.size(); } arrow::Result> Read(int64_t nbytes) override { arrow::Result> result = arrow::AllocateResizableBuffer(nbytes); ARROW_RETURN_NOT_OK(result); std::shared_ptr buffer = std::move(result).ValueUnsafe(); ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, Read(nbytes, buffer->mutable_data())); RETURN_NOT_OK(buffer->Resize(bytes_read)); return buffer; } arrow::Result GetSize() override { return size_; } bool supports_zero_copy() const override { return false; } arrow::Result ReadAt(int64_t position, int64_t nbytes, void* out) override { StringPiece result; Status status = file_->Read(position, nbytes, &result, reinterpret_cast(out)); if (!(status.ok() || errors::IsOutOfRange(status))) { return arrow::Status::IOError(status.error_message()); } return result.size(); } arrow::Result> ReadAt( int64_t position, int64_t nbytes) override { string buffer; buffer.resize(nbytes); StringPiece result; Status status = file_->Read(position, nbytes, &result, reinterpret_cast(&buffer[0])); if (!(status.ok() || errors::IsOutOfRange(status))) { return arrow::Status::IOError(status.error_message()); } buffer.resize(result.size()); return 
arrow::Buffer::FromString(std::move(buffer)); } private: tensorflow::RandomAccessFile* file_; int64 size_; int64 position_; }; } // namespace data } // namespace tensorflow #endif // TENSORFLOW_IO_ARROW_KERNELS_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/cache_mgr.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/internal/cache_mgr.h" #include #include #include #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "third_party/nlohmann/json.hpp" using json = nlohmann::json; using LineId = ::idl::matrix::proto::LineId; using Action = google::protobuf::RepeatedField; using Example = ::monolith::io::proto::Example; using EFeature = ::monolith::io::proto::NamedFeature; using ChannelCache = ::monolith::io::proto::ChannelCache; namespace tensorflow { namespace monolith_tf { namespace internal { std::shared_ptr MakeItemFeaturesFromProto( const ::monolith::io::proto::FeatureData& feature_data) { std::shared_ptr item_feature_ptr = std::make_shared(); item_feature_ptr->item_id = feature_data.gid(); for (const auto& fid : feature_data.fids()) { item_feature_ptr->fids.push_back(fid); } for (const auto& fc : feature_data.feature_columns()) { item_feature_ptr->example_features.emplace(fc.name(), fc); } return item_feature_ptr; } bool 
ItemFeatures::Equal(const ItemFeatures& other) const { if (item_id != other.item_id) { return false; } if (fids.size() != other.fids.size()) { return false; } else { std::unordered_set this_fids(fids.begin(), fids.end()); std::unordered_set other_fids(other.fids.begin(), other.fids.end()); if (this_fids.size() != other_fids.size()) { return false; } else { std::set intersection; std::set_intersection(this_fids.begin(), this_fids.end(), other_fids.begin(), other_fids.end(), std::inserter(intersection, intersection.begin())); if (this_fids.size() != intersection.size()) { return false; } } } if (example_features.size() != other.example_features.size()) { for (const auto& it : example_features) { if (other.example_features.count(it.first) == 0) { return false; } else { auto this_feat = it.second.SerializeAsString(); auto other_feat = other.example_features.at(it.first).SerializeAsString(); if (this_feat != other_feat) { return false; } } } } return true; } CacheWithGid::CacheWithGid(int max_item_num, int start_num) : start_num_(start_num), max_item_num_(max_item_num) {} void CacheWithGid::Push(uint64_t item_id, std::shared_ptr item, int64_t origin_cnt, int64_t sample_cnt) { auto it = data_.find(item_id); if (it == data_.end()) { data_queue_.emplace_back(item_id); data_.emplace(item_id, item); } auto iit = stats_.find(item_id); if (iit == stats_.end()) { auto stats_ptr = std::make_shared(); stats_ptr->origin_cnt = origin_cnt; stats_ptr->sample_cnt = sample_cnt; stats_.emplace(item_id, stats_ptr); } else { iit->second->origin_cnt += origin_cnt; iit->second->sample_cnt += sample_cnt; } if ((int64_t)data_queue_.size() > max_item_num_) { uint64_t item_id = data_queue_.front(); data_.erase(item_id); stats_.erase(item_id); data_queue_.pop_front(); } } std::shared_ptr CacheWithGid::RandomSelectOne( double* freq_factor, double* time_factor) const { if ((int64_t)data_queue_.size() <= start_num_) { return nullptr; } thread_local std::mt19937 gen((std::random_device())()); size_t 
index = gen() % data_queue_.size(); uint64_t item_id = data_queue_[index]; auto it = data_.find(item_id); if (it != data_.end()) { *freq_factor = 1.0 / ++(stats_[item_id]->sample_cnt); *time_factor = (index + 1.0) / data_queue_.size(); return it->second; } else { LOG_EVERY_N_SEC(ERROR, 1) << "item_id " << item_id << "in queue but not in map"; } return nullptr; } void CacheWithGid::ToProto(ChannelCache* proto) const { for (auto it : data_) { auto* feature_data = proto->add_feature_datas(); feature_data->set_gid(it.first); // gid for (const auto& fid : it.second->fids) { feature_data->add_fids(fid); } const auto& example_features = it.second->example_features; // std::shared_ptr for (const auto& fc_it : example_features) { auto* feature_columns = feature_data->add_feature_columns(); feature_columns->CopyFrom(fc_it.second); } const auto& stats = stats_[it.first]; feature_data->set_origin_cnt(stats->origin_cnt); feature_data->set_sample_cnt(stats->sample_cnt); } LOG_EVERY_N(INFO, 1000) << "save size " << data_queue_.size() << " " << data_.size(); } void CacheWithGid::FromProto(const ChannelCache& proto) { data_queue_.clear(); data_.clear(); for (int i = 0; i < proto.feature_datas_size(); ++i) { const auto& feature_data = proto.feature_datas(i); auto gid = feature_data.gid(); data_queue_.emplace_back(gid); auto group_feature_ptr = std::make_shared(); group_feature_ptr->item_id = gid; for (const auto& fid : feature_data.fids()) { group_feature_ptr->fids.push_back(fid); } for (const auto& fc : feature_data.feature_columns()) { group_feature_ptr->example_features.emplace(fc.name(), fc); } data_.emplace(gid, group_feature_ptr); std::shared_ptr stats = std::make_shared(); stats->origin_cnt = feature_data.origin_cnt(); stats->sample_cnt = feature_data.sample_cnt(); stats_[gid] = stats; } LOG_EVERY_N(INFO, 1000) << "restore size " << data_queue_.size() << " " << data_.size(); } bool CacheWithGid::Equal(const CacheWithGid& other) const { if (start_num_ != other.start_num_) { 
return false; } if (max_item_num_ != other.max_item_num_) { return false; } if (stats_.size() != other.stats_.size()) { return false; } else { for (const auto& it : stats_) { auto oit = other.stats_.find(it.first); if (oit == other.stats_.end()) { return false; } else { if (it.second->Equal(*oit->second.get())) { return false; } } } } if (data_.size() != other.data_.size()) { return false; } else { for (const auto& it : data_) { if (other.data_.count(it.first) == 0) { return false; } else { return it.second->Equal(*other.data_.at(it.first).get()); } } } return true; } CacheManager::CacheManager(int max_item_num_per_channel, int start_num) : start_num_(start_num), max_item_num_per_channel_(max_item_num_per_channel) {} std::shared_ptr CacheManager::RandomSelectOne( uint64_t channel_id, double* freq_factor, double* time_factor) const { auto it = channel_cache_.find(channel_id); if (it != channel_cache_.end()) { return it->second.RandomSelectOne(freq_factor, time_factor); } return nullptr; } void CacheManager::Push(uint64_t channel_id, uint64_t item_id, const std::shared_ptr& item, int64_t origin_cnt, int64_t sample_cnt) { auto it = channel_cache_.find(channel_id); if (it == channel_cache_.end()) { LOG(INFO) << "Create channel(" << channel_id << ") in ItemPoolResource CacheManager"; auto ret = channel_cache_.emplace( channel_id, CacheWithGid(max_item_num_per_channel_, start_num_)); it = ret.first; } it->second.Push(item_id, item, origin_cnt, sample_cnt); } absl::flat_hash_map& CacheManager::GetCache() { return channel_cache_; } void CacheManager::SampleChannelID(uint64_t* channel_id) { std::vector channel_ids; std::vector cache_size; if (channel_cache_.size() >= 2) { for (auto iter = channel_cache_.begin(); iter != channel_cache_.end(); ++iter) { channel_ids.emplace_back(iter->first); cache_size.emplace_back(iter->second.Size()); } std::discrete_distribution discrete_dist(cache_size.begin(), cache_size.end()); std::mt19937 gen(std::random_device{}()); *channel_id = 
channel_ids[discrete_dist(gen)]; } } } // namespace internal } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/cache_mgr.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_CACHE_MGR_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_CACHE_MGR_H_ #include #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { namespace monolith_tf { namespace internal { struct GroupStat { uint32_t origin_cnt = 0; uint32_t sample_cnt = 0; inline bool operator==(const GroupStat &rhs) const { return origin_cnt == rhs.origin_cnt && sample_cnt == rhs.sample_cnt; } inline bool Equal(const GroupStat &rhs) const { return origin_cnt == rhs.origin_cnt && sample_cnt == rhs.sample_cnt; } }; struct ItemFeatures { uint64_t item_id; std::vector fids; absl::flat_hash_map example_features; bool Equal(const ItemFeatures &other) const; }; std::shared_ptr MakeItemFeaturesFromProto( const ::monolith::io::proto::FeatureData &feature_data); class CacheWithGid { public: explicit CacheWithGid(int max_item_num, int start_num = 0); void Push(uint64_t item_id, std::shared_ptr item, 
int64_t origin_cnt = 1, int64_t sample_cnt = 0); std::shared_ptr RandomSelectOne( double *freq_factor, double *time_factor) const; void ToProto(::monolith::io::proto::ChannelCache *proto) const; void FromProto(const ::monolith::io::proto::ChannelCache &proto); bool Equal(const CacheWithGid &other) const; inline int Size() const { return data_queue_.size(); } private: int start_num_; int max_item_num_; absl::flat_hash_map> data_; mutable absl::flat_hash_map> stats_; std::deque data_queue_; }; class CacheManager { public: explicit CacheManager(int max_item_num_per_channel, int start_num = 0); std::shared_ptr RandomSelectOne( uint64_t channel_id, double *freq_factor, double *time_factor) const; void Push(uint64_t channel_id, uint64_t item_id, const std::shared_ptr &item, int64_t origin_cnt = 1, int64_t sample_cnt = 0); absl::flat_hash_map &GetCache(); void SampleChannelID(uint64_t* channel_id); private: int start_num_; int max_item_num_per_channel_; absl::flat_hash_map channel_cache_; }; } // namespace internal } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_CACHE_MGR_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/cache_mgr_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/internal/cache_mgr.h" #include #include "gtest/gtest.h" #include "absl/strings/str_cat.h" using NamedFeature = ::monolith::io::proto::NamedFeature; using ChannelCache = ::monolith::io::proto::ChannelCache; namespace tensorflow { namespace monolith_tf { namespace internal { namespace { static constexpr uint64_t MASK = (1L << 48) - 1; void gen_named_feature(NamedFeature *nf) { int slot = std::rand() % 1024; nf->set_name(absl::StrCat("fc_", slot)); auto *fid_v2_list = nf->mutable_feature()->mutable_fid_v2_list(); int num_fids = std::abs(std::rand() % 20) + 1; for (int i = 0; i < num_fids; ++i) { uint64_t fid = ((uint64_t)slot << 48) | ((std::rand() % 100000) & MASK); fid_v2_list->add_value(fid); } } void gen_item_features(ItemFeatures *item) { int num_feats = std::abs(std::rand() % 20) + 1; for (int i = 0; i < num_feats; ++i) { NamedFeature nf; gen_named_feature(&nf); if (!item->example_features.contains(nf.name())) { item->example_features.insert({nf.name(), nf}); } } } void fill_cache_with_gid(CacheWithGid *cwg) { for (int i = 0; i < 80; ++i) { std::shared_ptr item = std::make_shared(); gen_item_features(item.get()); int gid = std::abs(std::rand() % 1024) + 1; cwg->Push(gid, item); } } TEST(CACHE_MGR, CacheWithGid) { CacheWithGid cwg(100, 20); fill_cache_with_gid(&cwg); ChannelCache cache; cwg.ToProto(&cache); CacheWithGid cwg2(100, 20); cwg2.FromProto(cache); } TEST(CACHE_MGR, CacheManager) { CacheManager cm(1000, 20); for (int i = 0; i < 50; ++i) { const std::shared_ptr item = std::make_shared(); gen_item_features(item.get()); int gid = std::abs(std::rand() % 1024) + 1; cm.Push(1, gid, item); } EXPECT_EQ(cm.GetCache().size(), 1); } } // namespace } // namespace internal } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/datasource_utils.cc ================================================ // Copyright 2022 ByteDance 
and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/internal/datasource_utils.h" namespace tensorflow { namespace monolith_tf { namespace internal { int32_t java_hash_code(const std::string &data_flow_name) { int32_t h = 0; if (h == 0 && data_flow_name.length() > 0) { for (uint32_t i = 0; i < data_flow_name.length(); ++i) { h = 31 * h + data_flow_name.at(i); } } return h; } } // namespace internal } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/datasource_utils.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_DATASOURCE_UTILS_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_DATASOURCE_UTILS_H_ #include namespace tensorflow { namespace monolith_tf { namespace internal { int32_t java_hash_code(const std::string &data_flow_name); } // namespace internal } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_DATASOURCE_UTILS_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/datasource_utils_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/internal/datasource_utils.h" #include #include "gtest/gtest.h" namespace tensorflow { namespace monolith_tf { namespace internal { namespace { TEST(DatasourceUtils, JavaHashCode) { int32_t code = java_hash_code("datasource_inst"); EXPECT_EQ(code, -1487072768); } } // namespace } // namespace internal } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/file_match_split_provider.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/internal/file_match_split_provider.h" #include #include #include #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/default/logging.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" static constexpr char kCurrentPat[] = "current_pattern"; static constexpr char kCurrentFile[] = "current_file"; static constexpr char kQueueContent[] = "queue_content"; using std::chrono::milliseconds; namespace tensorflow { namespace data { namespace monolith_tf { Status FileMatchSplitProvider::GetNext(Tensor *split, bool *end_of_splits) { mutex_lock l(mu_); if (!feeder_) { TF_RETURN_IF_ERROR(EnsureFeederInitialized()); LOG(INFO) << "EnsureFeederInitialized Done!"; } *end_of_splits = false; *split = Tensor(DT_STRING, TensorShape{}); if (canceled_) { *end_of_splits = true; return errors::Cancelled( "FileMatchSplitProvider canceled, get an end_of_splits!"); } std::string item; while (!results_.try_pop(item, milliseconds(10))) { if (finished_feed_ && results_.empty()) { *end_of_splits = true; std::string info = absl::StrCat( "finished_feed is ", finished_feed_.load(), ", and results empty is ", results_.empty(), ", get an end_of_splits!"); LOG_EVERY_N_SEC(INFO, 300) << info; return errors::OutOfRange(info); } } split->scalar()() = 
item; return Status::OK(); } Status FileMatchSplitProvider::Reset() { mutex_lock l(mu_); // ensure feeder thread join canceled_ = true; finished_feed_ = true; feeder_ = nullptr; // clear queue std::string item; while (!results_.empty()) { results_.try_pop(item, milliseconds(1)); } canceled_ = false; finished_feed_ = false; current_pat_ = ""; current_file_ = ""; TF_RETURN_IF_ERROR(EnsureFeederInitialized()); return Status::OK(); } Status FileMatchSplitProvider::Save( std::function key_name_fn, IteratorStateWriter *writer) { TF_RETURN_IF_ERROR( writer->WriteScalar(key_name_fn(kCurrentPat), current_pat_)); TF_RETURN_IF_ERROR( writer->WriteScalar(key_name_fn(kCurrentFile), current_file_)); std::vector content; while (!results_.empty()) { std::string item; results_.pop(item); content.push_back(item); } TF_RETURN_IF_ERROR( writer->WriteScalar(key_name_fn(kQueueContent), absl::StrJoin(content.begin(), content.end(), ","))); return Status::OK(); } Status FileMatchSplitProvider::Restore( std::function key_name_fn, IteratorStateReader *reader) { canceled_ = false; finished_feed_ = false; tstring current_pat, current_file, content_str; TF_RETURN_IF_ERROR( reader->ReadScalar(key_name_fn(kCurrentPat), ¤t_pat)); current_pat_ = std::string(current_pat); TF_RETURN_IF_ERROR( reader->ReadScalar(key_name_fn(kCurrentFile), ¤t_file)); current_file_ = std::string(current_file); TF_RETURN_IF_ERROR( reader->ReadScalar(key_name_fn(kQueueContent), &content_str)); std::vector content_list = absl::StrSplit(absl::string_view(content_str), ','); for (const std::string &item : content_list) { results_.push(item); } return Status::OK(); } Status FileMatchSplitProvider::EnsureFeederInitialized() { finished_feed_ = false; feeder_ = absl::WrapUnique(Env::Default()->StartThread( {}, "file-match-split-provider-feeder", [this]() { FeederThread(); })); return Status::OK(); } void FileMatchSplitProvider::FeederThread() { LOG(INFO) << "thread file-match-split-provider-feeder started!"; int max_retry = 5, 
current_try = 0; const auto timeout = milliseconds(10); Env *env = Env::Default(); // find the start point int start = 0; if (!current_pat_.empty()) { for (const std::string &pattern : patterns_) { start++; if (pattern == current_pat_) { break; } } LOG(INFO) << "current_pat is " << current_pat_ << ", start at " << start - 1; } if (start >= patterns_.size() && patterns_.back() != current_pat_) { LOG(WARNING) << "Cannot find " << current_pat_ << " in patterns, skip!"; current_pat_ = ""; start = 0; } // finish the files in current_pat_ if any std::vector matched_files; if (!current_pat_.empty()) { current_try = 0; while (!env->GetMatchingPaths(current_pat_, &matched_files).ok()) { if (canceled_) return; current_try++; matched_files.clear(); std::this_thread::sleep_for(milliseconds(1000)); if (current_try >= max_retry) { LOG(INFO) << "GetMatchingPaths for pattern " << current_pat_ << " fail, retry!"; break; } } int idx = 0; if (!current_file_.empty()) { for (const std::string &file : matched_files) { if (file != current_file_) { idx++; } else { break; } } LOG(INFO) << "current_file is " << current_file_ << ", start at " << idx; } int num_files = 0; for (size_t i = idx; i < matched_files.size(); ++i) { current_file_ = matched_files[i]; while (!results_.try_push(current_file_, timeout)) { if (canceled_) return; std::this_thread::sleep_for(timeout); } num_files++; } LOG_EVERY_N(INFO, 100) << "Pattern " << current_pat_ << " has matched " << num_files << "/" << matched_files.size() << " files"; } // for the patterns after current_pat_ for (size_t i = start; i < patterns_.size(); ++i) { if (canceled_) return; matched_files.clear(); current_pat_ = patterns_[i]; current_try = 0; while (!env->GetMatchingPaths(current_pat_, &matched_files).ok()) { if (canceled_) return; current_try++; matched_files.clear(); std::this_thread::sleep_for(milliseconds(1000)); if (current_try >= max_retry) { LOG(INFO) << "GetMatchingPaths for pattern " << current_pat_ << " fail, retry!"; } } int 
num_files = 0; for (const std::string &file : matched_files) { current_file_ = file; while (!results_.try_push(current_file_, timeout)) { if (canceled_) return; std::this_thread::sleep_for(timeout); } num_files++; } LOG_EVERY_N(INFO, 100) << "Pattern " << current_pat_ << " has matched " << num_files << "/" << matched_files.size() << " files"; } finished_feed_ = true; LOG(INFO) << "thread file-match-split-provider-feeder finished!"; } } // namespace monolith_tf } // namespace data } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/file_match_split_provider.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_FILE_MATCH_SPLIT_PROVIDER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_FILE_MATCH_SPLIT_PROVIDER_H_ #include #include "monolith/native_training/runtime/concurrency/queue.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { namespace data { namespace monolith_tf { // SplitProvider that expands a list of file patterns on a background feeder thread and serves each matched file path as one split.
class FileMatchSplitProvider : public SplitProvider { public: explicit FileMatchSplitProvider(const std::vector& patterns, int queue_size = 1024) : canceled_(false), finished_feed_(false), patterns_(patterns), results_(queue_size) {} Status GetNext(Tensor* split, bool* end_of_splits) override; Status Reset() override; Status Save(std::function full_name, IteratorStateWriter* writer) override; Status Restore(std::function full_name, IteratorStateReader* reader) override; private: mutex mu_; std::atomic canceled_; std::atomic finished_feed_; std::string current_pat_ = ""; std::string current_file_ = ""; const std::vector patterns_; ::monolith::concurrency::Queue results_; std::unique_ptr feeder_; Status EnsureFeederInitialized(); void FeederThread(); }; } // namespace monolith_tf } // namespace data } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_FILE_MATCH_SPLIT_PROVIDER_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/file_match_split_provider_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/internal/file_match_split_provider.h" #include #include #include #include #include "absl/strings/str_cat.h" #include "gtest/gtest.h" using std::chrono::milliseconds; namespace tensorflow { namespace data { namespace monolith_tf { namespace { TEST(FileMatchSplitProvider, Create) { char tmp[256]; getcwd(tmp, 256); std::vector patterns = { absl::StrCat(tmp, "/monolith/native_training/data/kernels/*.h"), absl::StrCat(tmp, "/monolith/native_training/data/kernels/*.cc")}; FileMatchSplitProvider split_provider(patterns); Tensor split; bool end_of_splits = false; int cnt = 0; while (!end_of_splits) { Status s = split_provider.GetNext(&split, &end_of_splits); if (!s.ok() || end_of_splits) { return; } else { cnt++; LOG(INFO) << split.scalar()(); } } EXPECT_GE(cnt, 0); } TEST(FileMatchSplitProvider, Reset) { char tmp[256]; getcwd(tmp, 256); std::vector patterns = { absl::StrCat(tmp, "/monolith/native_training/data/kernels/internal/*")}; FileMatchSplitProvider split_provider(patterns); Status s; Tensor split; bool end_of_splits = false; s.Update(split_provider.GetNext(&split, &end_of_splits)); split_provider.Reset(); int cnt = 0; end_of_splits = false; while (!end_of_splits) { s.Update(split_provider.GetNext(&split, &end_of_splits)); if (!s.ok() || end_of_splits) { return; } else { cnt++; LOG(INFO) << split.scalar()(); } } EXPECT_GE(cnt, 1); } } // namespace } // namespace monolith_tf } // namespace data } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/label_utils.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/internal/label_utils.h" #include #include #include #include #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "glog/logging.h" namespace tensorflow { namespace monolith_tf { namespace internal { using LabelConf = ::monolith::native_training::data::config::LabelConf; bool HasIntersection(const std::set &lhs, const std::set &rhs) { std::set intersection; std::set_intersection(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), std::inserter(intersection, intersection.begin())); return !intersection.empty(); } bool ParseTaskConfig(const std::string &config, std::vector *task_configs) { task_configs->clear(); LabelConf label_conf; if (!label_conf.ParseFromString(config)) { LOG(FATAL) << "Parse label config error: " << config; } CHECK_GT(label_conf.conf_size(), 0); task_configs->reserve(label_conf.conf_size()); for (const auto &t : label_conf.conf()) { // pos_actions : neg_actions : sample_rate std::set pos_actions, neg_actions; CHECK(!t.pos_actions().empty()); pos_actions.insert(t.pos_actions().begin(), t.pos_actions().end()); if (!t.neg_actions().empty()) { neg_actions.insert(t.neg_actions().begin(), t.neg_actions().end()); } CHECK(!HasIntersection(pos_actions, neg_actions)); float sample_rate = t.sample_rate(); CHECK_GE(sample_rate, 0); CHECK_LE(sample_rate, 1.0); task_configs->push_back({pos_actions, neg_actions, sample_rate}); } return true; } } // namespace internal } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: 
monolith/native_training/data/kernels/internal/label_utils.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_LABEL_UTILS_H_
#define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_LABEL_UTILS_H_

#include
#include
#include
#include

#include "monolith/native_training/data/data_op_config.pb.h"
#include "third_party/nlohmann/json.hpp"

namespace tensorflow {
namespace monolith_tf {
namespace internal {

// Label constants. INVALID_LABEL is the lowest() value of the numeric_limits
// specialization used below; POSITIVE_LABEL is 1.0.
constexpr float INVALID_LABEL = std::numeric_limits::lowest();
constexpr float POSITIVE_LABEL = 1.0;

// Per-task label configuration: the positive and negative action sets and the
// sample rate, as produced by ParseTaskConfig.
struct TaskConfig {
  std::set pos_actions;
  std::set neg_actions;
  float sample_rate;

  // Renders the three fields as a JSON object, pretty-printed with an indent
  // of 2.
  std::string ToString() const {
    nlohmann::json j;
    j["pos_actions"] = pos_actions;
    j["neg_actions"] = neg_actions;
    j["sample_rate"] = sample_rate;
    return j.dump(2);
  }
};

// Defined in label_utils.cc.
bool HasIntersection(const std::set &lhs, const std::set &rhs);

// Defined in label_utils.cc.
bool ParseTaskConfig(const std::string &config,
                     std::vector *task_configs);

}  // namespace internal
}  // namespace monolith_tf
}  // namespace tensorflow
#endif  // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_LABEL_UTILS_H_

================================================
FILE: monolith/native_training/data/kernels/internal/label_utils_test.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "monolith/native_training/data/kernels/internal/label_utils.h"

#include
#include "gtest/gtest.h"

namespace tensorflow {
namespace monolith_tf {
namespace internal {
namespace {

using LabelConf = ::monolith::native_training::data::config::LabelConf;

// Overlapping sets (sharing the element 3) intersect; an empty set intersects
// with nothing.
TEST(LabelUtils, HasIntersection) {
  std::set lhs = {1, 2, 3}, rhs1 = {3, 4, 5}, rhs2 = {};
  EXPECT_TRUE(HasIntersection(lhs, rhs1));
  EXPECT_FALSE(HasIntersection(lhs, rhs2));
}

// Builds a three-task LabelConf, serializes it, and checks that
// ParseTaskConfig reproduces each task's action sets and sample rate.
TEST(LabelUtils, ParseTaskConfigBasic) {
  LabelConf label_conf;
  // Task 0: two positive actions, one negative action, sample rate 0.5.
  auto *task_conf = label_conf.add_conf();
  task_conf->add_pos_actions(-7);
  task_conf->add_pos_actions(-9);
  task_conf->add_neg_actions(-41);
  task_conf->set_sample_rate(0.5f);

  // Task 1: three positive actions, one negative action, sample rate 1.0.
  task_conf = label_conf.add_conf();
  task_conf->add_pos_actions(75);
  task_conf->add_pos_actions(-103);
  task_conf->add_pos_actions(74);
  task_conf->add_neg_actions(-41);
  task_conf->set_sample_rate(1.0f);

  // Task 2: positive actions only, sample rate 1.0.
  task_conf = label_conf.add_conf();
  task_conf->add_pos_actions(101);
  task_conf->add_pos_actions(102);
  task_conf->set_sample_rate(1.0f);

  std::string config;
  label_conf.SerializeToString(&config);

  std::vector task_configs;
  // The boolean result is not inspected; only the parsed output is checked.
  ParseTaskConfig(config, &task_configs);
  EXPECT_EQ(task_configs.size(), 3);

  std::set pos_actions0 = {-7, -9}, pos_actions1 = {-103, 74, 75},
           pos_actions2 = {101, 102};
  std::set neg_actions0 = {-41}, neg_actions1 = {-41}, neg_actions2 = {};

  EXPECT_EQ(task_configs[0].pos_actions, pos_actions0);
  EXPECT_EQ(task_configs[1].pos_actions, pos_actions1);
  EXPECT_EQ(task_configs[2].pos_actions, pos_actions2);

  EXPECT_EQ(task_configs[0].neg_actions, neg_actions0);
  EXPECT_EQ(task_configs[1].neg_actions, neg_actions1);
  EXPECT_EQ(task_configs[2].neg_actions, neg_actions2);

  // Exact comparison is fine here: 0.5 and 1.0 are exactly representable.
  EXPECT_EQ(task_configs[0].sample_rate, 0.5);
  EXPECT_EQ(task_configs[1].sample_rate, 1.0);
  EXPECT_EQ(task_configs[2].sample_rate, 1.0);
}

}  // namespace
}  // namespace internal
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/kernels/internal/parquet_column_buffer.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PARQUET_COLUMN_BUFFER_H_ #define PARQUET_COLUMN_BUFFER_H_ #include "parquet/api/reader.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { namespace data { class ColumnBuffer { public: explicit ColumnBuffer(std::shared_ptr& col_reader) : col_reader_(col_reader), buffer_limit_(0), values_limit_(0), values_p_(0), levels_p_(0) { max_definition_level_ = col_reader->descr()->max_definition_level(); max_repetition_level_ = col_reader->descr()->max_repetition_level(); } virtual ~ColumnBuffer() = default; protected: std::shared_ptr col_reader_; std::unique_ptr def_levels_buffer_; std::unique_ptr rep_levels_buffer_; int64_t buffer_limit_; int64_t values_limit_; int64_t values_p_; int64_t levels_p_; const int64_t BUFFER_SIZE = 256; int16_t max_definition_level_; int16_t max_repetition_level_; }; template class TypedColumnBuffer : public ColumnBuffer { public: typedef typename DType::c_type T; explicit TypedColumnBuffer(std::shared_ptr& col_reader) : ColumnBuffer(col_reader) { typed_col_reader_ = static_cast*>(col_reader.get()); value_buffer_.reset(new T[BUFFER_SIZE]); def_levels_buffer_.reset(new int16_t[BUFFER_SIZE]); rep_levels_buffer_.reset(new int16_t[BUFFER_SIZE]); } Status GetNextValues(std::vector& values) { T value; int16_t def_value, rep_level; bool is_null; if (max_repetition_level_ == 0 && max_definition_level_ == 0) { TF_RETURN_IF_ERROR( ReadNextValue(&value, &def_value, &rep_level, &is_null)); values.push_back(value); } else if (max_repetition_level_ == 0 && max_definition_level_ > 0) { TF_RETURN_IF_ERROR( ReadNextValue(&value, &def_value, &rep_level, &is_null)); if (!is_null) { values.push_back(value); } } else { do { ReadNextValue(&value, &def_value, &rep_level, &is_null); // debug use // LOG(INFO) << "In GetNextValues " << def_value << " " << rep_level << // " " << is_null; if (is_null) { break; } values.push_back(value); } while (HasNextRepeatedValue()); } return Status::OK(); } 
bool HasNextRepeatedValue() { if (levels_p_ >= buffer_limit_) { Status status = ReadBuffer(); if (!status.ok()) { return false; } } return rep_levels_buffer_[levels_p_] ? true : false; } Status ReadNextValue(T* out, int16_t* def_level, int16_t* rep_level, bool* is_null) { if (max_repetition_level_ == 0 && max_definition_level_ == 0) { // required column while (values_p_ >= values_limit_) { TF_RETURN_IF_ERROR(ReadBuffer()); } *is_null = false; *out = value_buffer_[values_p_++]; } else if (max_repetition_level_ == 0 && max_definition_level_ > 0) { // optional column while (levels_p_ >= buffer_limit_) { TF_RETURN_IF_ERROR(ReadBuffer()); } *def_level = def_levels_buffer_[levels_p_]; if (*def_level == 0) { *is_null = true; } else { *is_null = false; if (values_p_ >= values_limit_) { return errors::InvalidArgument("No extra values in buffer."); } *out = value_buffer_[values_p_++]; } levels_p_++; } else { // repeated column while (levels_p_ >= buffer_limit_) { TF_RETURN_IF_ERROR(ReadBuffer()); } *def_level = def_levels_buffer_[levels_p_]; *rep_level = rep_levels_buffer_[levels_p_]; if ((*def_level != max_definition_level_) || (*def_level == 0 && *rep_level == 0)) { *is_null = true; } else { *is_null = false; if (values_p_ >= values_limit_) { return errors::InvalidArgument("No extra values in buffer."); } *out = value_buffer_[values_p_++]; } levels_p_++; } return Status::OK(); } Status ReadBuffer() { profiler::TraceMe activity([]() { return "ParquetDatasetOp::ReadBuffer"; }); if (!typed_col_reader_->HasNext()) { return errors::OutOfRange("Column values all consumed, out of range"); } int64_t values_read; int64_t levels_read = typed_col_reader_->ReadBatch( BUFFER_SIZE, def_levels_buffer_.get(), rep_levels_buffer_.get(), value_buffer_.get(), &values_read); buffer_limit_ = levels_read; values_limit_ = values_read; values_p_ = 0; levels_p_ = 0; return Status::OK(); } private: parquet::TypedColumnReader* typed_col_reader_; std::unique_ptr value_buffer_; }; } // namespace data } 
// namespace tensorflow
#endif  // PARQUET_COLUMN_BUFFER_H_

================================================
FILE: monolith/native_training/data/kernels/internal/parquet_example_reader.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PARQUET_EXAMPLE_READER_H_
#define PARQUET_EXAMPLE_READER_H_

#include
#include

#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
#include "idl/matrix/proto/example.pb.h"
#include "monolith/native_training/data/kernels/internal/arrow_random_access_file.h"
#include "monolith/native_training/data/kernels/internal/parquet_column_buffer.h"
#include "monolith/native_training/data/kernels/internal/sized_random_access_file.h"
#include "parquet/api/reader.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
namespace data {

using idl::matrix::proto::LineId;
using monolith::io::proto::BytesList;
using monolith::io::proto::DoubleList;
using monolith::io::proto::Example;
using monolith::io::proto::ExampleBatch;
using monolith::io::proto::Feature;
using monolith::io::proto::FidList;
using monolith::io::proto::FloatList;
using monolith::io::proto::Int64List;
using monolith::io::proto::NamedFeature;
using monolith::io::proto::NamedFeatureList;

// Feature kind requested for a selected column; filled in by
// ParquetExampleReader::SetSelectedCols from the user-supplied type strings
// ("int", "fid_v1", "fid_v2", "float", "bytes").
enum ParsedDataType { INT = 1, FIDV1 = 2, FIDV2 = 3, FLOAT = 4, BYTES = 5 };

// Reads a Parquet file row by row (GetNextExample) or in batches
// (GetNextExampleBatch), converting the selected columns into Example /
// ExampleBatch protos. Init() must succeed before any other call.
class ParquetExampleReader {
 public:
  explicit ParquetExampleReader(Env* env) : env_(env) {}
  virtual ~ParquetExampleReader() {}

  // Opens `file_name`, records the schema's column names, validates the
  // selected columns against `selected_col_types`, positions the reader on the
  // first row group and resolves which selected columns map to LineId fields.
  // Returns the first error from any of these steps.
  // NOTE(review): the member file_name_ is never assigned in the visible code.
  Status Init(std::string file_name,
              const std::vector& selected_col_names,
              const std::vector& selected_col_types) {
    // open parquet file, and hold handler
    file_.reset(new SizedRandomAccessFile(env_, file_name, nullptr, 0));
    TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
    parquet_file_.reset(new ArrowRandomAccessFile(file_.get(), file_size_));
    parquet_reader_ =
        parquet::ParquetFileReader::Open(std::move(parquet_file_));
    parquet_metadata_ = parquet_reader_->metadata();

    // register columns names
    // Only the first dot-separated component of each column path is kept as
    // the column name.
    columns_.clear();
    for (int i = 0; i < parquet_metadata_->num_columns(); i++) {
      const std::string& full_col_name =
          parquet_metadata_->schema()->Column(i)->path().get()->ToDotString();
      std::vector split_result = absl::StrSplit(full_col_name, ".");
      if (split_result.empty()) {
        LOG(WARNING) << "Split column full name " << full_col_name
                     << ", get empty result, will skip this column.";
        continue;
      }
      std::string col_name = {split_result[0].data(), split_result[0].size()};
      columns_.push_back(col_name);
      columns_index_map_[col_name] = i;
      col_pure_name_map_[i] = col_name;
    }

    // LOG(INFO) << "parquet file schema: ";
    // for (int i = 0; i < parquet_metadata_->num_columns(); i++) {
    //   parquet::Type::type col_type =
    //       parquet_metadata_->schema()->Column(i)->physical_type();
    //   LOG(INFO) << "column " << i << " : " << columns_[i] << " : "
    //             << ColTypeToString(col_type);
    // }
    // LOG(INFO) << "end of schema";
    LOG(INFO) << "parquet file columns: " << parquet_metadata_->num_columns();
    LOG(INFO) << "parquet file rows: " << parquet_metadata_->num_rows();

    // select column, and check col type
    selected_col_ids_.clear();
    selected_col_feature_type_.clear();
    TF_RETURN_IF_ERROR(SetSelectedCols(selected_col_names, selected_col_types));

    // init global iter_ and row_group related variables
    // The -1 sentinels plus the reset reader make NextRowGroup() take its
    // first-initialization path.
    iter_ = 0;
    row_group_offset_ = -1;
    row_group_id_ = -1;
    row_group_reader_.reset();
    TF_RETURN_IF_ERROR(NextRowGroup());

    // init line_id descriptor
    descriptor_ = ::idl::matrix::proto::LineId::GetDescriptor();
    reflection_ = ::idl::matrix::proto::LineId::GetReflection();
    for (size_t i = 0; i < selected_col_ids_.size(); i++) {
      int64_t col_id = selected_col_ids_[i];
      std::string col_name = col_pure_name_map_[col_id];
      // nullptr is stored for columns that do not name a LineId field.
      const google::protobuf::FieldDescriptor* field_descriptor =
          GetLineIdFieldByName(col_name);
      line_id_discriptor_map_[col_id] = field_descriptor;
    }

    LOG(INFO) << "Init of ParquetReader Success. file_name = " << file_name;
    return Status::OK();
  }

  // Returns a printable name for a Parquet physical type ("UNKNOWN" for any
  // type not listed).
  static const char* ColTypeToString(parquet::Type::type type) {
    switch (type) {
      case parquet::Type::BOOLEAN:
        return "BOOLEAN";
      case parquet::Type::INT32:
        return "INT32";
      case parquet::Type::INT64:
        return "INT64";
      case parquet::Type::FLOAT:
        return "FLOAT";
      case parquet::Type::DOUBLE:
        return "DOUBLE";
      case parquet::Type::BYTE_ARRAY:
        return "BYTE_ARRAY";
      case parquet::Type::FIXED_LEN_BYTE_ARRAY:
        return "FIXED_LEN_BYTE_ARRAY";
      default:
        return "UNKNOWN";
    }
  }

  // Resolves `selected_col_names` to schema column ids (appended to
  // selected_col_ids_) and checks each requested feature type against the
  // column's physical type, appending the result to
  // selected_col_feature_type_. Returns InvalidArgument on a size mismatch, an
  // unknown or duplicated column name, or an incompatible type.
  Status SetSelectedCols(const std::vector& selected_col_names,
                         const std::vector& selected_col_types) {
    // check size equal
    if (selected_col_names.size() != selected_col_types.size()) {
      return errors::InvalidArgument(
          "list selected_col_names should have the same size as list "
          "selected_col_types");
    }
    // check column names valid, and not duplicated
    std::unordered_set selected_col_id_set;
    for (const std::string& col_name : selected_col_names) {
      auto it = columns_index_map_.find(col_name);
      if (it == columns_index_map_.end()) {
        return errors::InvalidArgument("column name: ", col_name,
                                       " not in paruquet schema");
      }
      selected_col_ids_.push_back(it->second);
      selected_col_id_set.insert(it->second);
    }
    if (selected_col_ids_.size() != selected_col_id_set.size()) {
      return errors::InvalidArgument(
          "selected_col_names have duplicate columns");
    }
    // check whether selected_col_types are valid, and fill enum values in
    // selected_col_feature_type_
    for (uint64_t i = 0; i < selected_col_ids_.size(); i++) {
      uint64_t col_id = selected_col_ids_[i];
      const std::string& feature_type = selected_col_types[i];
      const std::string& col_name = selected_col_names[i];
      parquet::Type::type col_type =
          parquet_metadata_->schema()->Column(col_id)->physical_type();
      switch (col_type) {
        case parquet::Type::INT32:
          if (feature_type == "int") {
            selected_col_feature_type_.push_back(ParsedDataType::INT);
          } else {
            return errors::InvalidArgument(
                "invalid selected_col_types, col_name = ", col_name,
                ", feature type should be int");
          }
          break;
        case parquet::Type::INT64:
          if (feature_type == "int") {
            selected_col_feature_type_.push_back(ParsedDataType::INT);
          } else if (feature_type == "fid_v1") {
            selected_col_feature_type_.push_back(ParsedDataType::FIDV1);
          } else if (feature_type == "fid_v2") {
            selected_col_feature_type_.push_back(ParsedDataType::FIDV2);
          } else {
            return errors::InvalidArgument(
                "invalid selected_col_types, col_name = ", col_name,
                ", feature type should in [int, fid_v1, fid_v2]");
          }
          break;
        case parquet::Type::FLOAT:
        case parquet::Type::DOUBLE:
          if (feature_type == "float") {
            selected_col_feature_type_.push_back(ParsedDataType::FLOAT);
          } else {
            return errors::InvalidArgument(
                "invalid selected_col_types, col_name = ", col_name,
                ", feature type should be float");
          }
          break;
        case parquet::Type::BYTE_ARRAY:
          if (feature_type == "bytes") {
            selected_col_feature_type_.push_back(ParsedDataType::BYTES);
          } else {
            return errors::InvalidArgument(
                "invalid selected_col_types, col_name = ", col_name,
                ", feature type should be bytes");
          }
          break;
        default:
          return errors::InvalidArgument(
              "invalid column parquet type col_name = ", col_name,
              "parquet type is", ColTypeToString(col_type));
      }
    }
    return Status::OK();
  }

  // Maps a column name of the form __UPPER_CASE__ to the LineId field whose
  // name is the lower-cased text between the double underscores. Returns
  // nullptr when the name does not have that form; otherwise returns whatever
  // FindFieldByName yields for the lower-cased name.
  const google::protobuf::FieldDescriptor* GetLineIdFieldByName(
      std::string name) {
    static std::regex reg("__[A-Z_]+__");
    bool is_match = std::regex_match(name, reg);
    if (!is_match) {
      return nullptr;
    }
    // Strip the leading and trailing "__".
    std::string subname = name.substr(2, name.length() - 4);
    std::string lower_subname = absl::AsciiStrToLower(subname);
    return descriptor_->FindFieldByName(lower_subname);
  }

  // Fills `example` from the next row and advances iter_ by one.
  // Per selected column: "__LABEL__" goes to the label, columns matching a
  // LineId field are written into line_id through reflection, and everything
  // else becomes a named feature. Returns OutOfRange at end of file.
  // NOTE(review): the dynamic_cast results used here and in the Fill* /
  // GetSingleValue helpers are dereferenced without a null check.
  Status GetNextExample(Example& example) {
    if (IsEOF()) {
      return errors::OutOfRange("GetNextExample out of range, iter = ", iter_);
    }
    // Advance to the row group that contains row iter_.
    while (iter_ >=
           row_group_offset_ + row_group_reader_->metadata()->num_rows()) {
      TF_RETURN_IF_ERROR(NextRowGroup());
    }
    for (size_t i = 0; i < selected_col_ids_.size(); i++) {
      int64_t col_id = selected_col_ids_[i];
      parquet::Type::type col_type =
          parquet_metadata_->schema()->Column(col_id)->physical_type();
      std::string col_name = col_pure_name_map_[col_id];

      // if column is __LABEL__
      if (col_name == "__LABEL__") {
        if (col_type == parquet::Type::INT32) {
          TF_RETURN_IF_ERROR(FillLabel(i, example));
        } else if (col_type == parquet::Type::INT64) {
          TF_RETURN_IF_ERROR(FillLabel(i, example));
        } else if (col_type == parquet::Type::FLOAT) {
          TF_RETURN_IF_ERROR(FillLabel(i, example));
        } else if (col_type == parquet::Type::DOUBLE) {
          TF_RETURN_IF_ERROR(FillLabel(i, example));
        } else {
          LOG(FATAL)
              << "In parquet: __LABEL__ column has wrong type, pls check";
        }
        continue;
      }

      // if column in line_id
      auto it = line_id_discriptor_map_.find(col_id);
      const google::protobuf::FieldDescriptor* line_field =
          (it != line_id_discriptor_map_.end()) ? it->second : nullptr;
      if (line_field != nullptr) {
        if (line_field->is_repeated()) {
          // TODO(libo.bob): will support it later
          LOG(FATAL) << "Not support repeated line_id field now.";
        }
        // The column's physical type must match the proto field's C++ type;
        // a mismatch is fatal.
        switch (line_field->cpp_type()) {
          case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT32: {
            CHECK(col_type == parquet::Type::INT32)
                << "column: " << col_name
                << " should have the same type as line_id field";
            reflection_->SetInt32(example.mutable_line_id(), line_field,
                                  GetSingleValue(i));
            break;
          }
          case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT64: {
            CHECK(col_type == parquet::Type::INT64)
                << "column: " << col_name
                << " should have the same type as line_id field";
            reflection_->SetInt64(example.mutable_line_id(), line_field,
                                  GetSingleValue(i));
            break;
          }
          case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT32: {
            CHECK(col_type == parquet::Type::INT32)
                << "column: " << col_name
                << " should have the same type as line_id field";
            reflection_->SetUInt32(example.mutable_line_id(), line_field,
                                   GetSingleValue(i));
            break;
          }
          case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT64: {
            CHECK(col_type == parquet::Type::INT64)
                << "column: " << col_name
                << " should have the same type as line_id field";
            reflection_->SetUInt64(example.mutable_line_id(), line_field,
                                   GetSingleValue(i));
            break;
          }
          case google::protobuf::FieldDescriptor::CppType::CPPTYPE_FLOAT: {
            CHECK(col_type == parquet::Type::FLOAT)
                << "column: " << col_name
                << " should have the same type as line_id field";
            reflection_->SetFloat(example.mutable_line_id(), line_field,
                                  GetSingleValue(i));
            break;
          }
          case google::protobuf::FieldDescriptor::CppType::CPPTYPE_DOUBLE: {
            CHECK(col_type == parquet::Type::DOUBLE)
                << "column: " << col_name
                << " should have the same type as line_id field";
            reflection_->SetDouble(example.mutable_line_id(), line_field,
                                   GetSingleValue(i));
            break;
          }
          case google::protobuf::FieldDescriptor::CppType::CPPTYPE_STRING: {
            CHECK(col_type == parquet::Type::BYTE_ARRAY)
                << "column: " << col_name
                << " should have the same type as line_id field";
            std::string value = ByteArrayToString(GetSingleValue(i));
            reflection_->SetString(example.mutable_line_id(), line_field,
                                   value);
            break;
          }
          default:
            LOG(FATAL) << "not support line_id type for column " << col_name;
        }
        continue;
      }

      // Ordinary feature column.
      // NOTE(review): the id set here is col_id + 10000, whereas
      // GetNextExampleBatch uses col_id unchanged -- confirm this is intended.
      NamedFeature* named_feature = example.add_named_feature();
      named_feature->set_id(col_id + 10000);
      named_feature->set_name(col_name);
      Feature* feature = named_feature->mutable_feature();
      switch (col_type) {
        case parquet::Type::INT32: {
          TF_RETURN_IF_ERROR(FillValueList(
              i, feature->mutable_int64_list()));
          break;
        }
        case parquet::Type::INT64: {
          if (selected_col_feature_type_[i] == ParsedDataType::INT) {
            TF_RETURN_IF_ERROR(FillValueList(
                i, feature->mutable_int64_list()));
          } else if (selected_col_feature_type_[i] == ParsedDataType::FIDV1) {
            TF_RETURN_IF_ERROR(FillValueList(
                i, feature->mutable_fid_v1_list()));
          } else {
            TF_RETURN_IF_ERROR(FillValueList(
                i, feature->mutable_fid_v2_list()));
          }
          break;
        }
        case parquet::Type::FLOAT: {
          TF_RETURN_IF_ERROR(FillValueList(
              i, feature->mutable_float_list()));
          break;
        }
        case parquet::Type::DOUBLE: {
          // DOUBLE columns are written into the float list as well.
          TF_RETURN_IF_ERROR(FillValueList(
              i, feature->mutable_float_list()));
          break;
        }
        case parquet::Type::BYTE_ARRAY: {
          TypedColumnBuffer* typed_col_buf =
              dynamic_cast*>(
                  col_buffers_[i].get());
          std::vector values;
          TF_RETURN_IF_ERROR(typed_col_buf->GetNextValues(values));
          BytesList* bytes_list = feature->mutable_bytes_list();
          for (const parquet::ByteArray& value : values) {
            bytes_list->add_value(ByteArrayToString(value));
          }
          break;
        }
        default:
          return errors::InvalidArgument("not support column type");
      }
    }
    iter_++;
    return Status::OK();
  }

  // Fills `example_batch` with up to `batch_size` rows (fewer when the file
  // ends first), one NamedFeatureList per selected column, and advances iter_
  // by the number of rows read. Returns OutOfRange at end of file.
  Status GetNextExampleBatch(ExampleBatch& example_batch, int64_t batch_size) {
    if (IsEOF()) {
      return errors::OutOfRange("GetNextExampleBatch out of range, iter = ",
                                iter_);
    }

    // create namedfeaturelist(s)
    {
      profiler::TraceMe activity(
          []() { return "ParquetDataset::CreateNamedFeatureLists"; });
      for (size_t i = 0; i < selected_col_ids_.size(); i++) {
        int64_t col_id = selected_col_ids_[i];
        const std::string& col_name = col_pure_name_map_[col_id];
        NamedFeatureList* named_feature_list =
            example_batch.add_named_feature_list();
        named_feature_list->set_id(col_id);
        named_feature_list->set_name(col_name);
      }
    }

    // calculate batch_size
    // Clamp to the rows remaining in the file.
    int64_t rows_to_read_left =
        iter_ + batch_size >= parquet_metadata_->num_rows()
            ? parquet_metadata_->num_rows() - iter_
            : batch_size;
    example_batch.set_batch_size(rows_to_read_left);

    // read features for each column
    while (rows_to_read_left > 0) {
      // if need to go to next row group
      while (iter_ >=
             row_group_offset_ + row_group_reader_->metadata()->num_rows()) {
        TF_RETURN_IF_ERROR(NextRowGroup());
      }
      // calculate max rows can read in current row group
      int64_t rows_in_row_group =
          iter_ + rows_to_read_left >=
                  row_group_offset_ +
                      row_group_reader_->metadata()->num_rows()
              ? row_group_offset_ +
                    row_group_reader_->metadata()->num_rows() - iter_
              : rows_to_read_left;
      rows_to_read_left -= rows_in_row_group;

      // read from current row group
      // Column-major: all rows of one column are read before the next column.
      for (size_t i = 0; i < selected_col_ids_.size(); i++) {
        profiler::TraceMe activity(
            []() { return "ParquetDataset::ReadOneColumnWithBatchSize"; });
        int64_t col_id = selected_col_ids_[i];
        parquet::Type::type col_type =
            parquet_metadata_->schema()->Column(col_id)->physical_type();
        NamedFeatureList* named_feature_list =
            example_batch.mutable_named_feature_list(i);
        for (int64_t ft = 0; ft < rows_in_row_group; ft++) {
          Feature* feature = named_feature_list->add_feature();
          switch (col_type) {
            case parquet::Type::INT32: {
              TF_RETURN_IF_ERROR(FillValueList(
                  i, feature->mutable_int64_list()));
              break;
            }
            case parquet::Type::INT64: {
              if (selected_col_feature_type_[i] == ParsedDataType::INT) {
                TF_RETURN_IF_ERROR(FillValueList(
                    i, feature->mutable_int64_list()));
              } else if (selected_col_feature_type_[i] ==
                         ParsedDataType::FIDV1) {
                TF_RETURN_IF_ERROR(FillValueList(
                    i, feature->mutable_fid_v1_list()));
              } else {
                TF_RETURN_IF_ERROR(FillValueList(
                    i, feature->mutable_fid_v2_list()));
              }
              break;
            }
            case parquet::Type::FLOAT: {
              TF_RETURN_IF_ERROR(FillValueList(
                  i, feature->mutable_float_list()));
              break;
            }
            case parquet::Type::DOUBLE: {
              TF_RETURN_IF_ERROR(FillValueList(
                  i, feature->mutable_float_list()));
              break;
            }
            case parquet::Type::BYTE_ARRAY: {
              TypedColumnBuffer* typed_col_buf =
                  dynamic_cast*>(
                      col_buffers_[i].get());
              std::vector values;
              TF_RETURN_IF_ERROR(typed_col_buf->GetNextValues(values));
              BytesList* bytes_list = feature->mutable_bytes_list();
              for (const parquet::ByteArrayType::c_type& value : values) {
                bytes_list->add_value(ByteArrayToString(value));
              }
              break;
            }
            default:
              return errors::InvalidArgument("not support column type");
          }
        }
      }
      iter_ += rows_in_row_group;
    }
    return Status::OK();
  }

  // Reads the next row of column buffer `col_buffer_id` and appends every
  // value to `value_list`. On failure the current stack trace is logged and
  // the error is returned.
  template
  Status FillValueList(int64_t col_buffer_id, PB_TLIST* value_list) {
    TypedColumnBuffer* typed_col_buf =
        dynamic_cast*>(
            col_buffers_[col_buffer_id].get());
    std::vector values;
    Status status = typed_col_buf->GetNextValues(values);
    if (!status.ok()) {
      std::string stack_trace = CurrentStackTrace();
      LOG(INFO) << stack_trace;
      return status;
    }
    for (const typename PARQUET_TYPE::c_type& value : values) {
      value_list->add_value(value);
    }
    return Status::OK();
  }

  // Reads the next row of column buffer `col_buffer_id` and appends every
  // value, after a static_cast, to example's label. On failure the current
  // stack trace is logged and the error is returned.
  template
  Status FillLabel(int64_t col_buffer_id, Example& example) {
    TypedColumnBuffer* typed_col_buf =
        dynamic_cast*>(
            col_buffers_[col_buffer_id].get());
    std::vector values;
    Status status = typed_col_buf->GetNextValues(values);
    if (!status.ok()) {
      std::string stack_trace = CurrentStackTrace();
      LOG(INFO) << stack_trace;
      return status;
    }
    for (const typename PARQUET_TYPE::c_type& value : values) {
      example.mutable_label()->Add(static_cast(value));
    }
    return Status::OK();
  }

  // Reads the next row of column buffer `col_buffer_id` and returns its single
  // value. A read error, or a row that does not hold exactly one value, is
  // fatal.
  template
  typename PARQUET_TYPE::c_type GetSingleValue(int64_t col_buffer_id) {
    TypedColumnBuffer* typed_col_buf =
        dynamic_cast*>(
            col_buffers_[col_buffer_id].get());
    std::vector values;
    Status status = typed_col_buf->GetNextValues(values);
    if (!status.ok()) {
      std::string stack_trace = CurrentStackTrace();
      LOG(FATAL) << stack_trace;
    }
    if (values.size() != 1) {
      LOG(FATAL) << "Parquet column id = " << col_buffer_id
                 << ", should have single value for one row, but got "
                 << values.size();
    }
    return values[0];
  }

  // Moves to the next row group (the first one on the initial call) and
  // rebuilds col_buffers_ with one typed buffer per selected column. Returns
  // OutOfRange when there is no further row group and InvalidArgument for an
  // unsupported physical type.
  Status NextRowGroup() {
    profiler::TraceMe activity([]() { return "ParquetDataset::NextRowGroup"; });
    if (row_group_id_ + 1 >= parquet_metadata_->num_row_groups()) {
      return errors::OutOfRange("row group out of range");
    }
    if (!row_group_reader_) {
      // first initialize
      row_group_reader_ = parquet_reader_->RowGroup(0);
      row_group_id_ = 0;
      row_group_offset_ = 0;
    } else {
      // The offset must be advanced before the reader is replaced, because it
      // uses the row count of the group being left.
      row_group_offset_ =
          row_group_offset_ + row_group_reader_->metadata()->num_rows();
      row_group_id_++;
      row_group_reader_ = parquet_reader_->RowGroup(row_group_id_);
    }
    // update col_buffers
    col_buffers_.clear();
    for (uint64_t col_id : selected_col_ids_) {
      std::shared_ptr column_reader =
          row_group_reader_->Column(col_id);
      switch (parquet_metadata_->schema()->Column(col_id)->physical_type()) {
        case parquet::Type::INT32:
          col_buffers_.emplace_back(
              new TypedColumnBuffer(column_reader));
          break;
        case parquet::Type::INT64:
          col_buffers_.emplace_back(
              new TypedColumnBuffer(column_reader));
          break;
        case parquet::Type::FLOAT:
          col_buffers_.emplace_back(
              new TypedColumnBuffer(column_reader));
          break;
        case parquet::Type::DOUBLE:
          col_buffers_.emplace_back(
              new TypedColumnBuffer(column_reader));
          break;
        case parquet::Type::BYTE_ARRAY:
          col_buffers_.emplace_back(
              new TypedColumnBuffer(column_reader));
          break;
        default:
          return errors::InvalidArgument("not support column type");
      }
    }
    return Status::OK();
  }

  // True once iter_ has reached the file's total row count.
  bool IsEOF() {
    // LOG(INFO) << "iter_ = " << iter_;
    return iter_ >= parquet_metadata_->num_rows();
  }

 private:
  Env* env_;
  std::unique_ptr file_;
  uint64 file_size_;
  // Moved into ParquetFileReader::Open by Init(), so empty afterwards.
  std::unique_ptr parquet_file_;
  std::string file_name_;

  std::shared_ptr<::parquet::ParquetFileReader> parquet_reader_;
  std::shared_ptr<::parquet::FileMetaData> parquet_metadata_;

  // Schema bookkeeping filled by Init(): column names, name -> column index,
  // and column index -> name.
  std::vector columns_;
  std::unordered_map columns_index_map_;
  std::unordered_map col_pure_name_map_;
  // Parallel vectors, one entry per selected column.
  std::vector selected_col_ids_;
  std::vector selected_col_feature_type_;

  // iter_ and row_group variables
  // iter_ is the index of the next row to read; row_group_offset_ is the index
  // of the first row of the current row group.
  int64_t iter_;
  std::shared_ptr row_group_reader_;
  int64_t row_group_id_;
  int64_t row_group_offset_;
  // One buffer per selected column, rebuilt by NextRowGroup().
  std::vector> col_buffers_;

  // line_id related
  const google::protobuf::Descriptor* descriptor_;
  const google::protobuf::Reflection* reflection_;
  // Column id -> LineId field, or nullptr when the column is not a LineId
  // field.
  std::unordered_map
      line_id_discriptor_map_;
};

}  // namespace data
}  // namespace tensorflow

#endif  // PARQUET_EXAMPLE_READER_H_

================================================
FILE: monolith/native_training/data/kernels/internal/relational_utils.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_RELATIONAL_UTILS_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_RELATIONAL_UTILS_H_ #include #include #include #include #include "glog/logging.h" namespace tensorflow { namespace monolith_tf { namespace internal { static const std::string GT = "gt"; static const std::string GE = "ge"; static const std::string EQ = "eq"; static const std::string LT = "lt"; static const std::string LE = "le"; static const std::string NEQ = "neq"; static const std::string BETWEEN = "between"; static const std::string IN = "in"; static const std::string NOT_IN = "not-in"; static const std::unordered_set VALID_OPS = { GT, GE, EQ, LT, LE, NEQ, BETWEEN, IN, NOT_IN}; static const std::unordered_set COMPARE_OPS = {GT, GE, EQ, LT, LE, NEQ, BETWEEN}; template bool compare(const std::string& op, const T1& value, const std::vector& operands) { if (op == GT) { return value > operands[0]; } else if (op == GE) { return value >= operands[0]; } else if (op == EQ) { return value == operands[0]; } else if (op == LT) { return value < operands[0]; } else if (op == LE) { return value <= operands[0]; } else if (op == NEQ) { return value != operands[0]; } else if (op == BETWEEN) { return value >= operands[0] && value < operands[1]; } else { LOG(FATAL) << "Invalid op: " << op; return false; } } template bool contains(const std::string& op, const T1& value, const std::unordered_set& operand_set) { if (op == IN) { return operand_set.count(value); } else if (op == NOT_IN) { return !operand_set.count(value); } else { LOG(FATAL) << "Invalid op: " << op; return false; } } } // namespace internal } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_RELATIONAL_UTILS_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/relational_utils_test.cc ================================================ // Copyright 2022 
ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/internal/relational_utils.h" #include #include #include #include "gmock/gmock.h" #include "gtest/gtest.h" namespace tensorflow { namespace monolith_tf { namespace internal { namespace { TEST(RelationalUtils, GT) { EXPECT_TRUE(compare(GT, 1, {0})); EXPECT_TRUE(!compare(GT, 1, {1})); EXPECT_TRUE(!compare(GT, -1, {0LL})); EXPECT_TRUE(compare(GT, std::string("1"), {"0"})); EXPECT_TRUE(!compare(GT, std::string("1"), {"1"})); } TEST(RelationalUtils, GE) { EXPECT_TRUE(compare(GE, 1, {0})); EXPECT_TRUE(compare(GE, 1, {1})); EXPECT_TRUE(!compare(GE, -1, {0LL})); EXPECT_TRUE(compare(GE, std::string("1"), {"0"})); EXPECT_TRUE(compare(GE, std::string("1"), {"1"})); } TEST(RelationalUtils, EQ) { EXPECT_TRUE(!compare(EQ, 1, {0})); EXPECT_TRUE(compare(EQ, 1, {1})); EXPECT_TRUE(!compare(EQ, -1, {0LL})); EXPECT_TRUE(!compare(EQ, std::string("1"), {"0"})); EXPECT_TRUE(compare(EQ, std::string("1"), {"1"})); } TEST(RelationalUtils, LT) { EXPECT_TRUE(!compare(LT, 1, {0})); EXPECT_TRUE(!compare(LT, 1, {1})); EXPECT_TRUE(compare(LT, -1, {0LL})); EXPECT_TRUE(!compare(LT, std::string("1"), {"0"})); EXPECT_TRUE(!compare(LT, std::string("1"), {"1"})); } TEST(RelationalUtils, LE) { EXPECT_TRUE(!compare(LE, 1, {0})); EXPECT_TRUE(compare(LE, 1, {1})); EXPECT_TRUE(compare(LE, -1, {0LL})); EXPECT_TRUE(!compare(LE, std::string("1"), {"0"})); EXPECT_TRUE(compare(LE, 
std::string("1"), {"1"})); } TEST(RelationalUtils, NEQ) { EXPECT_TRUE(compare(NEQ, 1, {0})); EXPECT_TRUE(!compare(NEQ, 1, {1})); EXPECT_TRUE(compare(NEQ, -1, {0LL})); EXPECT_TRUE(compare(NEQ, std::string("1"), {"0"})); EXPECT_TRUE(!compare(NEQ, std::string("1"), {"1"})); } TEST(RelationalUtils, BETWEEN) { EXPECT_TRUE(!compare(BETWEEN, 1, {0, 1})); EXPECT_TRUE(compare(BETWEEN, 1, {1, 2})); EXPECT_TRUE(!compare(BETWEEN, -1, {0LL, 1LL})); EXPECT_TRUE(!compare(BETWEEN, std::string("1"), {"0", "1"})); EXPECT_TRUE(compare(BETWEEN, std::string("1"), {"1", "2"})); } TEST(RelationalUtils, IN) { EXPECT_TRUE(!contains(IN, 1, {0})); EXPECT_TRUE(contains(IN, 1, {1, 2})); EXPECT_TRUE(!contains(IN, -1, {0LL, 1LL})); EXPECT_TRUE(!contains(IN, std::string("1"), {"0"})); EXPECT_TRUE(contains(IN, std::string("1"), {"1", "2"})); } TEST(RelationalUtils, NOT_IN) { EXPECT_TRUE(contains(NOT_IN, 1, {0})); EXPECT_TRUE(!contains(NOT_IN, 1, {1, 2})); EXPECT_TRUE(contains(NOT_IN, -1, {0LL, 1LL})); EXPECT_TRUE(contains(NOT_IN, std::string("1"), {"0"})); EXPECT_TRUE(!contains(NOT_IN, std::string("1"), {"1", "2"})); } } // namespace } // namespace internal } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/sized_random_access_file.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_IO_CORE_KERNELS_STREAM_H_ #define TENSORFLOW_IO_CORE_KERNELS_STREAM_H_ #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/lib/io/random_inputstream.h" namespace tensorflow { namespace data { // Note: This SizedRandomAccessFile should only lives within Compute() // of the kernel as buffer could be released by outside. class SizedRandomAccessFile : public tensorflow::RandomAccessFile { public: SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size) : file_(nullptr), size_(optional_memory_size), buff_((const char*)(optional_memory_buff)), size_status_(Status::OK()) { if (size_ == 0) { size_status_ = env->GetFileSize(filename, &size_); if (size_status_.ok()) { size_status_ = env->NewRandomAccessFile(filename, &file_); } } } virtual ~SizedRandomAccessFile() {} Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { if (file_.get() != nullptr) { return file_.get()->Read(offset, n, result, scratch); } size_t bytes_to_read = 0; if (offset < size_) { bytes_to_read = (offset + n < size_) ? 
n : (size_ - offset); } if (bytes_to_read > 0) { memcpy(scratch, &buff_[offset], bytes_to_read); } *result = StringPiece(scratch, bytes_to_read); if (bytes_to_read < n) { return errors::OutOfRange("EOF reached"); } return Status::OK(); } Status GetFileSize(uint64* size) { if (size_status_.ok()) { *size = size_; } return size_status_; } private: std::unique_ptr file_; uint64 size_; const char* buff_; Status size_status_; }; } // namespace data } // namespace tensorflow #endif // TENSORFLOW_IO_CORE_KERNELS_STREAM_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/uniq_hashtable.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_UNIQ_HASHTABLE_H_ #define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_UNIQ_HASHTABLE_H_ #include #include #include #include #include #include #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/raw_coding.h" #include "absl/base/macros.h" #include "glog/logging.h" #define MONOLITH_INLINE __attribute__((always_inline)) namespace tensorflow { namespace monolith_tf { class UniqHashTable { using FID = uint64_t; static constexpr uint32_t MIN_BUCKET_CAP = 1u << 10; static constexpr FID EMPTY_FID = static_cast(-1); static constexpr uint32_t ILLEGAL_BUCKET = std::numeric_limits::max(); static constexpr float LOAD_FACTOR = 0.75; struct HTItem { HTItem() = default; HTItem(const FID fid, uint32_t req_id, uint32_t uniq_idx) : fid(fid), req_id(req_id), uniq_idx(uniq_idx) {} HTItem(const HTItem& other) = default; HTItem& operator=(const HTItem& other) = default; bool operator==(const HTItem& other) const { return fid == other.fid && req_id == other.req_id; } bool IsEmpty(uint32_t cur_req_id) const { return req_id != cur_req_id || fid == EMPTY_FID;} FID fid = EMPTY_FID; uint32_t req_id = 0; uint32_t uniq_idx = 0; }; struct HTIdx { HTIdx() = default; HTIdx(uint32_t item_pos, uint32_t insert_pos) : item_pos(item_pos), insert_pos(insert_pos) {} uint32_t item_pos = ILLEGAL_BUCKET; uint32_t insert_pos = ILLEGAL_BUCKET; }; using HTItemPtr = HTItem*; public: UniqHashTable() : capacity_(MIN_BUCKET_CAP), num_elements_(0), cur_req_id_(1) { prob_table_ = CreateProbTable(MIN_BUCKET_CAP); bucket_idx_mask_ = MIN_BUCKET_CAP - 1; expand_threshold_ = static_cast(capacity_ * LOAD_FACTOR); } ~UniqHashTable() { DeleteProbTable(prob_table_, capacity_); } uint32_t UniqFid(const FID fid, const uint32_t uniq_idx) { DCHECK_NE(fid, EMPTY_FID); auto idx = FindPosition(fid, prob_table_, bucket_idx_mask_, cur_req_id_); if (idx.item_pos != ILLEGAL_BUCKET) { DCHECK(idx.insert_pos == ILLEGAL_BUCKET); return 
prob_table_[idx.item_pos].uniq_idx; } else { DCHECK(idx.insert_pos != ILLEGAL_BUCKET); DCHECK_LE(idx.insert_pos, bucket_idx_mask_); prob_table_[idx.insert_pos] = HTItem(fid, cur_req_id_, uniq_idx); num_elements_++; } MaybeExpand(); return uniq_idx; } void Reset() { num_elements_ = 0; if (++cur_req_id_ == 0) { FillTableWithEmptyItem(prob_table_, capacity_); } // std::cerr << "abcd " << cur_req_id_ << std::endl; } size_t Size() { return static_cast(num_elements_); } size_t Capacity() { return static_cast(capacity_); } private: static MONOLITH_INLINE HTIdx FindPosition(const FID fid, HTItemPtr prob_table, uint32_t bucket_idx_mask, uint32_t req_id) { uint32_t bucket_idx = FidHash(fid) & bucket_idx_mask; while (true) { const auto& item = prob_table[bucket_idx]; if (item == HTItem(fid, req_id, 0)) { return HTIdx(bucket_idx, ILLEGAL_BUCKET); } else if (item.IsEmpty(req_id)) { return HTIdx(ILLEGAL_BUCKET, bucket_idx); } bucket_idx = (bucket_idx + 1) & bucket_idx_mask; } } static MONOLITH_INLINE bool TestEqual(const HTItem& item, const FID fid, uint32_t req_id) { return item.fid == fid && item.req_id == req_id; } static MONOLITH_INLINE HTItemPtr CreateProbTable(uint32_t capacity) { auto* prob_table = malloc(sizeof(HTItem) * capacity); FillTableWithEmptyItem(reinterpret_cast(prob_table), capacity); return reinterpret_cast(prob_table); } static MONOLITH_INLINE void FillTableWithEmptyItem(HTItemPtr prob_table, uint32_t capacity) { DCHECK(!!prob_table); static HTItem empty_item(EMPTY_FID, 0, 0); std::uninitialized_fill_n(prob_table, capacity, empty_item); } void DeleteProbTable(HTItemPtr prob_table, uint32_t capacity) { DCHECK(!!prob_table); if (!std::is_trivial::value) { for (uint32_t i = 0; i < capacity; ++i) { prob_table[i].~HTItem(); } } free(prob_table); } void MaybeExpand() { if (GOOGLE_PREDICT_TRUE(num_elements_ <= expand_threshold_)) { return; } auto new_capacity = capacity_ << 1; auto new_bucket_idx_mask = new_capacity - 1; auto new_prob_table = 
CreateProbTable(new_capacity); for (uint32_t i = 0; i < capacity_; ++i) { if (prob_table_[i].IsEmpty(cur_req_id_)) { continue; } uint32_t bucket_idx = FidHash(prob_table_[i].fid) & new_bucket_idx_mask; while (!new_prob_table[bucket_idx].IsEmpty(cur_req_id_)) { bucket_idx = (bucket_idx + 1) & new_bucket_idx_mask; } new_prob_table[bucket_idx] = prob_table_[i]; } DeleteProbTable(prob_table_, capacity_); prob_table_ = new_prob_table; capacity_ = new_capacity; expand_threshold_ = static_cast(capacity_ * LOAD_FACTOR); bucket_idx_mask_ = capacity_ - 1; } static MONOLITH_INLINE uint32_t FidHash(const FID fid) { return Hash(reinterpret_cast(&fid), sizeof(FID), 0); } // Copy from tensorflow/tsl/lib/io/cache.cc // Question: 这里怎么引用比较规范? static uint32_t Hash(const char* data, size_t n, uint32_t seed) { // Similar to murmur hash const uint32_t m = 0xc6a4a793; const uint32_t r = 24; const char* limit = data + n; uint32_t h = seed ^ (n * m); // Pick up four bytes at a time while (data + 4 <= limit) { uint32_t w = tensorflow::core::DecodeFixed32(data); data += 4; h += w; h *= m; h ^= (h >> 16); } // Pick up remaining bytes switch (limit - data) { case 3: h += static_cast(data[2]) << 16; ABSL_FALLTHROUGH_INTENDED; case 2: h += static_cast(data[1]) << 8; ABSL_FALLTHROUGH_INTENDED; case 1: h += static_cast(data[0]); h *= m; h ^= (h >> r); break; } return h; } HTItemPtr prob_table_ = nullptr; uint32_t capacity_ = 0; // must be 2 power; uint32_t expand_threshold_ = 0; uint32_t bucket_idx_mask_ = 0; uint32_t num_elements_ = 0; uint32_t cur_req_id_ = 0; TF_DISALLOW_COPY_AND_ASSIGN(UniqHashTable); }; class MultiShardUniqHashTable { using FID = uint64_t; public: MultiShardUniqHashTable() = default; ~MultiShardUniqHashTable() = default; void init(UniqHashTable *uniq_hashtable) { uniq_hashtable_ = uniq_hashtable; } size_t uniq_fid(const FID fid, int shard) { DCHECK_LT(shard, fid_lists_.size()); // store all shards' uniq indices in a single uniq_hashtable_ auto uniq_idx = 
uniq_hashtable_->UniqFid(fid, fid_lists_[shard].size()); if (uniq_idx == fid_lists_[shard].size()) { fid_lists_[shard].push_back(fid); } return uniq_idx; } int fid_num(size_t shard) const { return static_cast(fid_lists_[shard].size()); } std::vector& fid_list(int shard) { DCHECK_LT(shard, fid_lists_.size()); return fid_lists_[shard]; } void reset() { uniq_hashtable_->Reset(); } void resize(size_t shard_num) { fid_lists_.resize(shard_num); } void reserve(size_t fid_num) { for (auto& fid_list : fid_lists_) { fid_list.reserve(fid_num); } } private: UniqHashTable* uniq_hashtable_; std::vector> fid_lists_; }; } // namespace monolith_tf } // namespace tensorflow #undef MONOLITH_INLINE #endif // MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_UNIQ_HASHTABLE_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/uniq_hashtable_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/internal/uniq_hashtable.h"

#include <cstdint>
#include <cstdlib>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_cat.h"
#include "gtest/gtest.h"

using FID = uint64_t;

namespace tensorflow {
namespace monolith_tf {
namespace {

// Fixture that feeds random FIDs through a UniqHashTable and cross-checks the
// returned unique indices against a std::unordered_map reference.
class UniqHashTableTest {
 public:
  UniqHashTable uniq_hashtable;
  std::vector<FID> fids;
  int fid_num_ = 0;
  int fid_range_ = 0;

  UniqHashTableTest(int fid_num, int fid_range) {
    fid_num_ = fid_num;
    fid_range_ = fid_range;
  }

  // Regenerates `fid_num_` random FIDs drawn from `fid_range_` distinct
  // values, so duplicates are expected.
  void Reset() {
    // Size the vector once instead of on every loop iteration.
    fids.resize(static_cast<size_t>(fid_num_));
    for (int fi = 0; fi < fid_num_; ++fi) {
      fids[fi] = (rand() % fid_range_) << 8;
    }
  }

  // Inserts all FIDs and verifies that repeated FIDs map to their first index
  // and that new FIDs receive consecutive indices.
  void Check() {
    Reset();
    std::unordered_map<FID, size_t> result;
    std::vector<FID> uniq_fids;
    for (const auto& fid : fids) {
      auto uniq_idx1 = result.size();
      auto uniq_idx2 = uniq_hashtable.UniqFid(fid, uniq_hashtable.Size());
      auto iter = result.find(fid);
      if (iter != result.end()) {
        EXPECT_EQ(iter->second, uniq_idx2)
            << "size: " << uniq_hashtable.Size();
      } else {
        result[fid] = uniq_idx1;
        EXPECT_EQ(uniq_idx1, uniq_idx2);
        EXPECT_EQ(uniq_idx2 + 1, uniq_hashtable.Size());
        uniq_fids.push_back(fid);
      }
    }
    EXPECT_EQ(uniq_fids.size(), result.size());
    EXPECT_EQ(result.size(), uniq_hashtable.Size());

    // check no repetition
    result.clear();
    for (const auto& fid : uniq_fids) {
      EXPECT_EQ(result.count(fid), 0);
      result.insert({fid, 0});
    }
  }
};

TEST(UniqHashTableTest, Small) {
  size_t fid_num = 1e3;
  size_t fid_range = 1e2;
  UniqHashTableTest test(fid_num, fid_range);
  test.Check();
}

TEST(UniqHashTableTest, Medium) {
  size_t fid_num = 1e5;
  size_t fid_range = 1e4;
  UniqHashTableTest test(fid_num, fid_range);
  test.Check();
}

TEST(UniqHashTableTest, Reset) {
  size_t fid_num = 1e5;
  size_t fid_range = 1e4;
  UniqHashTableTest test(fid_num, fid_range);
  test.Check();
  test.uniq_hashtable.Reset();
  test.Check();
}

TEST(UniqHashTableTest, ReqId) {
  size_t fid_num = 1e3;
  size_t fid_range = 1e2;
  UniqHashTableTest test(fid_num, fid_range);
  test.Check();
  // Reset 2^32 times so the 32-bit request id wraps all the way around.
  for (uint64_t i = 0;
       i < static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1;
       i++) {
    test.uniq_hashtable.Reset();
  }
  test.Check();
}

}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/kernels/internal/value_filter_by_feature.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "monolith/native_training/data/kernels/internal/value_filter_by_feature.h"

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "idl/matrix/proto/example.pb.h"
#include "monolith/native_training/data/kernels/internal/relational_utils.h"
#include "third_party/nlohmann/json.hpp"

namespace tensorflow {
namespace monolith_tf {
namespace internal {

using EFeature = ::monolith::io::proto::Feature;
using FilterValues = ::monolith::io::proto::FilterValues;

std::unordered_set<std::string> FeatureValueFilter::VALID_SET_OPS = {
    "any", "all", "diff", "startswith", "endswith"};

FeatureValueFilter::FeatureValueFilter(std::string field_name,
                                       std::string field_type, std::string op,
                                       std::vector<float> float_operand,
                                       std::vector<int64> int_operand,
                                       std::vector<std::string> string_operand,
                                       std::string operand_filepath,
                                       bool keep_empty)
    : field_name_(std::move(field_name)),
      field_type_(std::move(field_type)),
      op_(std::move(op)),
      feature_index_valid_score_(1.0),
      float_operand_(std::move(float_operand)),
      int_operand_(std::move(int_operand)),
      string_operand_(std::move(string_operand)),
operand_filepath_(std::move(operand_filepath)), keep_empty_(keep_empty) { if (!internal::VALID_OPS.count(op_) && !VALID_SET_OPS.count(op_)) { std::string valid_ops_str = absl::StrJoin(internal::VALID_OPS, ", "); std::string valid_set_ops_str = absl::StrJoin(VALID_SET_OPS, ", "); LOG(FATAL) << absl::StrFormat( "Invalid op: %s, please choose one from [%s] or [%s]", op_, valid_ops_str, valid_set_ops_str); } nlohmann::json j; j["field_name"] = field_name_; j["field_type"] = field_type_; j["op"] = op_; j["float_operand_count"] = float_operand_.size(); j["int_operand_count"] = int_operand_.size(); j["string_operand_count"] = string_operand_.size(); j["operand_filepath"] = operand_filepath_; int64_t limit = 1000; if (float_operand_.size() <= limit) { j["float_operand"] = float_operand_; } else { std::vector values(float_operand_.begin(), float_operand_.begin() + limit); j["float_operand_first_1000"] = values; } if (int_operand_.size() <= limit) { j["int_operand"] = int_operand_; } else { std::vector values(int_operand_.begin(), int_operand_.begin() + limit); j["int_operand_first_1000"] = values; } if (string_operand_.size() <= limit) { j["string_operand"] = string_operand_; } else { std::vector values(string_operand_.begin(), string_operand_.begin() + limit); j["string_operand_first_1000"] = values; } LOG(INFO) << j.dump(2); if ((op_ == internal::IN || op_ == internal::NOT_IN) && operand_filepath_.empty()) { float_operand_set_.insert(float_operand_.begin(), float_operand_.end()); int_operand_set_.insert(int_operand_.begin(), int_operand_.end()); string_operand_set_.insert(string_operand_.begin(), string_operand_.end()); } } Status FeatureValueFilter::EnsureLoadFilterValues(tensorflow::Env* env) { absl::MutexLock l(&load_filter_values_mu_); if (load_filter_values_finished_ || operand_filepath_.empty()) { return Status::OK(); } std::string filter_values_serialized; TF_RETURN_IF_ERROR( ReadFileToString(env, operand_filepath_, &filter_values_serialized)); FilterValues 
filter_values; if (!filter_values.ParseFromString(filter_values_serialized)) { return errors::InvalidArgument( "Unable to parse filter values, please make sure it is " "serialized version of message:FilterValues."); } switch (static_cast(filter_values.type_case())) { case FilterValues::TypeCase::kFloatList: { if (field_type_ != "float") { return errors::InvalidArgument( "Filter values' type(float) should be the same with field type(", field_type_, ")"); } float_operand_set_.insert(filter_values.float_list().value().begin(), filter_values.float_list().value().end()); break; } case FilterValues::TypeCase::kInt64List: { if (field_type_ != "int64") { return errors::InvalidArgument( "Filter values' type(int64) should be the same with field type(", field_type_, ")"); } int_operand_set_.insert(filter_values.int64_list().value().begin(), filter_values.int64_list().value().end()); break; } case FilterValues::TypeCase::kBytesList: { if (field_type_ != "bytes") { return errors::InvalidArgument( "Filter values' type(bytes) should be the same with field type(", field_type_, ")"); } string_operand_set_.insert(filter_values.bytes_list().value().begin(), filter_values.bytes_list().value().end()); break; } case FilterValues::TypeCase::TYPE_NOT_SET: return errors::InvalidArgument("FilterValue TYPE_NOT_SET, field type(", field_type_, ")"); default: return errors::InvalidArgument( "Invalid field type for feature value filter, field_type: ", field_type_, " FilterValues: ", filter_values.ShortDebugString()); } load_filter_values_finished_ = true; return Status::OK(); } bool FeatureValueFilter::CheckFeatureIndex(const Example& example, int* feature_index) { find_feature_index_mu_.ReaderLock(); bool result = true; if (cached_feature_index_ == -1 || cached_feature_index_ >= example.named_feature_size()) { result = false; } else { const auto& feature = example.named_feature(cached_feature_index_); if (feature.name() != field_name_) { result = false; } } if (result) { *feature_index = 
cached_feature_index_; } find_feature_index_mu_.ReaderUnlock(); return result; } bool FeatureValueFilter::IsInstanceOfInterest(tensorflow::Env* env, const Example& example) { bool output = false; int feature_index = -1; if (!CheckFeatureIndex(example, &feature_index)) { for (int i = 0; i < example.named_feature_size(); i++) { const auto& feature = example.named_feature(i); if (feature.name() == field_name_) { feature_index = i; } } if (feature_index != -1) { absl::MutexLock l(&find_feature_index_mu_); cached_feature_index_ = feature_index; } double score = feature_index_valid_score_.load(); score = 0.99 * score; feature_index_valid_score_.store(score); if (score < 0.7) { LOG_EVERY_N_SEC(ERROR, 15) << "Potential performance problem! feature index valid score: " << score; } } else { double score = feature_index_valid_score_.load(); feature_index_valid_score_.store(0.99 * score + 0.01); } LOG_EVERY_N_SEC(INFO, 120) << "Feature index valid score (performance related): " << feature_index_valid_score_.load(); if (feature_index == -1 && !keep_empty_) { output = false; LOG_EVERY_N_SEC(ERROR, 15) << "Feature not found!" << " field name: " << field_name_; return output; } const auto& feature = example.named_feature(feature_index).feature(); const auto& type_case = feature.type_case(); // op是in/not_in,且feature是单值类型的场景 if ((op_ == internal::IN || op_ == internal::NOT_IN) && !operand_filepath_.empty() && (type_case == EFeature::TypeCase::kFloatList || type_case == EFeature::TypeCase::kDoubleList || type_case == EFeature::TypeCase::kInt64List || type_case == EFeature::TypeCase::kBytesList)) { TF_CHECK_OK(EnsureLoadFilterValues(env)); } switch (static_cast(type_case)) { case EFeature::TypeCase::TYPE_NOT_SET: { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: feature not set!" << " field name: " << field_name_; break; } case EFeature::TypeCase::kFloatValue: { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: float value is not " "supported, please use float list!" 
<< " field name: " << field_name_; break; } case EFeature::TypeCase::kDoubleValue: { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: double value is not " "supported, please use double list!" << " field name: " << field_name_; break; } case EFeature::TypeCase::kInt64Value: { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: int64 value is not " "supported, please use double list!" << " field name: " << field_name_; break; } case EFeature::TypeCase::kBytesValue: { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: bytes value is not " "supported, please use bytes list!" << " field name: " << field_name_; break; } default: break; } std::vector values; switch (static_cast(type_case)) { case EFeature::TypeCase::kFloatList: { if (field_type_ != "float") { LOG_EVERY_N_SEC(ERROR, 15) << "Field type not match: field name: " << field_name_ << " field type: " << field_type_ << " but feature has float value."; break; } if (feature.float_list().value_size() == 1) { float value = feature.float_list().value(0); output = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, float_operand_) : internal::contains(op_, value, float_operand_set_); return output; } else if (feature.float_list().value_size() > 1) { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: float list with multiple elements is not " "supported, please investigate and retry!" << " field name: " << field_name_; } break; } case EFeature::TypeCase::kDoubleList: { if (field_type_ != "double") { LOG_EVERY_N_SEC(ERROR, 15) << "Field type not match: field name: " << field_name_ << " field type: " << field_type_ << " but feature has double value."; break; } if (feature.double_list().value_size() == 1) { double value = feature.double_list().value(0); output = internal::COMPARE_OPS.count(op_) ? 
internal::compare(op_, value, float_operand_) : internal::contains(op_, value, float_operand_set_); return output; } else if (feature.double_list().value_size() > 1) { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: double_list with multiple elements is not " "supported, please investigate and retry!" << " field name: " << field_name_; } break; } case EFeature::TypeCase::kInt64List: { if (field_type_ != "int64") { LOG_EVERY_N_SEC(ERROR, 15) << "Field type not match: field name: " << field_name_ << " field type: " << field_type_ << " but feature has int64 value."; break; } if (VALID_SET_OPS.count(op_)) { for (const auto& value : feature.int64_list().value()) { values.push_back(value); } } else { if (feature.int64_list().value_size() == 1) { int64 value = feature.int64_list().value(0); output = internal::COMPARE_OPS.count(op_) ? internal::compare(op_, value, int_operand_) : internal::contains(op_, value, int_operand_set_); return output; } else if (feature.double_list().value_size() > 1) { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: int64_list with multiple elements when not " "using set_ops is not supported, please investigate and retry!" << " field name: " << field_name_ << " op: " << op_; } } break; } case EFeature::TypeCase::kBytesList: { if (field_type_ != "bytes") { LOG_EVERY_N_SEC(ERROR, 15) << "Field type not match: field name: " << field_name_ << " field type: " << field_type_ << " but feature has bytes value."; break; } if (feature.bytes_list().value_size() == 1) { std::string value = feature.bytes_list().value(0); output = false; if (op_ == "startswith") { for (const std::string& operand : string_operand_) { if (value.find(operand) == 0) { output = true; break; } } } else if (op_ == "endswith") { for (const std::string& operand : string_operand_) { if (operand.size() <= value.size()) { bool found = std::equal(operand.rbegin(), operand.rend(), value.rbegin()); if (found) { output = true; break; } } } } else { output = internal::COMPARE_OPS.count(op_) ? 
internal::compare(op_, value, string_operand_) : internal::contains(op_, value, string_operand_set_); } return output; } else if (feature.bytes_list().value_size() > 1) { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid data: bytes_list with multiple elements is not " "supported, please investigate and retry!" << " field name: " << field_name_; } break; } default: { output = false; const auto descriptor = EFeature::GetDescriptor(); const auto reflection = EFeature::GetReflection(); const auto oneof_descriptor = descriptor->FindOneofByName("type"); std::string feature_dtype = ""; if (oneof_descriptor != nullptr) { const auto field_descriptor = reflection->GetOneofFieldDescriptor(feature, oneof_descriptor); if (field_descriptor != nullptr) { if (field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { // 处理嵌套消息类型 const auto nested_descriptor = field_descriptor->message_type(); if (nested_descriptor != nullptr) { feature_dtype = nested_descriptor->name(); } } else if (field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_ENUM) { // 处理枚举类型 const auto enum_descriptor = field_descriptor->enum_type(); if (enum_descriptor != nullptr) { feature_dtype = enum_descriptor->name(); } } else { feature_dtype = field_descriptor->type_name(); } } } LOG(INFO) << "feature not match, feature dtype is: " << feature_dtype << ", supposed field type is: " << field_type_ << " type case: " << int(type_case); break; } } if (values.size() > 0) { output = cmp(values); } else { output = keep_empty_; } return output; } } // namespace internal } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/internal/value_filter_by_feature.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_FEATURE_VALUE_FILTER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_FEATURE_VALUE_FILTER_H_ #include #include #include "absl/synchronization/mutex.h" #include "idl/matrix/proto/example.pb.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { namespace monolith_tf { namespace internal { using Example = ::monolith::io::proto::Example; class FeatureValueFilter { public: FeatureValueFilter(std::string field_name, std::string field_type, std::string op, std::vector float_operand, std::vector int_operand, std::vector string_operand, std::string operand_filepath, bool keep_empty); bool IsInstanceOfInterest(tensorflow::Env* env, const Example& example); static std::unordered_set VALID_SET_OPS; private: Status EnsureLoadFilterValues(tensorflow::Env* env); bool CheckFeatureIndex(const Example& example, int* feature_index); bool cmp(const std::vector& values) { std::set intersection; std::set_intersection(values.begin(), values.end(), int_operand_.begin(), int_operand_.end(), std::inserter(intersection, intersection.begin())); if (op_ == "any") { return intersection.size() > 0; } else if (op_ == "all") { return intersection.size() == int_operand_.size(); } else if (op_ == "diff") { return intersection.size() == 0; } else { LOG_EVERY_N_SEC(ERROR, 15) << "Invalid op for int64_list feature: " << op_; return false; } } private: mutable absl::Mutex load_filter_values_mu_; bool load_filter_values_finished_ ABSL_GUARDED_BY(load_filter_values_mu_) = false; std::string 
field_name_; std::string field_type_; mutable absl::Mutex find_feature_index_mu_; int cached_feature_index_ ABSL_GUARDED_BY(find_feature_index_mu_) = -1; std::atomic feature_index_valid_score_; std::string op_; // gt, ge, eq, lt, le, neq, between std::vector float_operand_; std::vector int_operand_; std::vector string_operand_; std::unordered_set float_operand_set_; std::unordered_set int_operand_set_; std::unordered_set string_operand_set_; std::string operand_filepath_; bool keep_empty_ = false; }; } // namespace internal } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_FEATURE_VALUE_FILTER_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/value_filter_by_line_id.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/internal/value_filter_by_line_id.h"

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "idl/matrix/proto/example.pb.h"
#include "monolith/native_training/data/kernels/internal/relational_utils.h"
#include "third_party/nlohmann/json.hpp"

namespace tensorflow {
namespace monolith_tf {
namespace internal {

// NOTE(review): template arguments of the std:: containers in this file are
// not visible in this view; those declarations are left exactly as found.

// Builds a filter on the LineId field called `field_name`.
//
// Resolves the field descriptor by name, mirrors the int operands into
// uint_operand_, aborts (LOG(FATAL)) on an op that is neither in
// internal::VALID_OPS nor one of the set/string ops listed below, logs the
// configuration as JSON (at most the first 1000 operands of each kind), and,
// for in / not-in without an operand file, pre-populates the lookup sets.
LineIdValueFilter::LineIdValueFilter(std::string field_name, std::string op,
                                     std::vector float_operand,
                                     std::vector int_operand,
                                     std::vector string_operand,
                                     std::string operand_filepath,
                                     bool keep_empty)
    : field_name_(std::move(field_name)),
      op_(std::move(op)),
      float_operand_(std::move(float_operand)),
      int_operand_(std::move(int_operand)),
      string_operand_(std::move(string_operand)),
      operand_filepath_(std::move(operand_filepath)),
      keep_empty_(keep_empty) {
  const auto descriptor = ::idl::matrix::proto::LineId::GetDescriptor();
  // May be null when the name is unknown; IsInstanceOfInterest checks for it.
  field_ = descriptor->FindFieldByName(field_name_);
  uint_operand_.insert(uint_operand_.end(), int_operand_.begin(),
                       int_operand_.end());

  std::unordered_set valid_set_ops = {"any", "all", "diff", "startswith",
                                      "endswith"};
  if (!internal::VALID_OPS.count(op_) && !valid_set_ops.count(op_)) {
    std::string valid_ops_str = absl::StrJoin(internal::VALID_OPS, ", ");
    std::string valid_set_ops_str = absl::StrJoin(valid_set_ops, ", ");
    LOG(FATAL) << absl::StrFormat(
        "Invalid op: %s, please choose one from [%s] or [%s]", op_,
        valid_ops_str, valid_set_ops_str);
  }

  // Dump the configuration for debugging; long operand lists are truncated.
  nlohmann::json j;
  j["field_name"] = field_name_;
  j["op"] = op_;
  j["float_operand_count"] = float_operand_.size();
  j["int_operand_count"] = int_operand_.size();
  j["string_operand_count"] = string_operand_.size();
  j["operand_filepath"] = operand_filepath_;
  int64_t limit = 1000;
  if (float_operand_.size() <= limit) {
    j["float_operand"] = float_operand_;
  } else {
    std::vector values(float_operand_.begin(), float_operand_.begin() + limit);
    j["float_operand_first_1000"] = values;
  }

  if (int_operand_.size() <= limit) {
    j["int_operand"] = int_operand_;
  } else {
    std::vector values(int_operand_.begin(), int_operand_.begin() + limit);
    j["int_operand_first_1000"] = values;
  }

  if (string_operand_.size() <= limit) {
    j["string_operand"] = string_operand_;
  } else {
    std::vector values(string_operand_.begin(),
                       string_operand_.begin() + limit);
    j["string_operand_first_1000"] = values;
  }

  LOG(INFO) << j.dump(2);

  // With an operand file the sets are filled lazily by
  // EnsureLoadFilterValues instead.
  if ((op_ == internal::IN || op_ == internal::NOT_IN) &&
      operand_filepath_.empty()) {
    float_operand_set_.insert(float_operand_.begin(), float_operand_.end());
    int_operand_set_.insert(int_operand_.begin(), int_operand_.end());
    uint_operand_set_.insert(uint_operand_.begin(), uint_operand_.end());
    string_operand_set_.insert(string_operand_.begin(), string_operand_.end());
  }
}

// Loads the operand sets from `operand_filepath_` once.
//
// The file must hold a serialized FilterValues message whose populated list
// matches the C++ type of the filtered field. Returns OK immediately when the
// values were already loaded or no file is configured; returns
// InvalidArgument on a parse failure, a type mismatch, or an unsupported
// field type. Must only be called when field_ is non-null.
Status LineIdValueFilter::EnsureLoadFilterValues(tensorflow::Env *env) {
  absl::MutexLock l(&mu_);
  if (load_filter_values_finished_ || operand_filepath_.empty()) {
    return Status::OK();
  }

  std::string filter_values_serialized;
  TF_RETURN_IF_ERROR(
      ReadFileToString(env, operand_filepath_, &filter_values_serialized));
  ::monolith::io::proto::FilterValues filter_values;
  if (!filter_values.ParseFromString(filter_values_serialized)) {
    return errors::InvalidArgument(
        "Unable to parse filter values, please make sure it is "
        "serialized version of message:FilterValues.");
  }

  auto field = field_;
  switch (field->cpp_type()) {
    case google::protobuf::FieldDescriptor::CppType::CPPTYPE_FLOAT:
    case google::protobuf::FieldDescriptor::CppType::CPPTYPE_DOUBLE: {
      if (!filter_values.has_float_list()) {
        return errors::InvalidArgument(
            "Filter values' type should be the same with field type.");
      }
      float_operand_set_.insert(filter_values.float_list().value().begin(),
                                filter_values.float_list().value().end());
      break;
    }
    // uint32 is grouped with the signed types on purpose:
    // IsInstanceOfInterest widens uint32 field values to int64 and looks them
    // up in int_operand_set_, so that is the set that has to be filled.
    // Loading them into uint_operand_set_ (as before) left the set that is
    // actually consulted empty.
    case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT32:
    case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT64:
    case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT32: {
      if (!filter_values.has_int64_list()) {
        return errors::InvalidArgument(
            "Filter values' type should be the same with field type.");
      }
      int_operand_set_.insert(filter_values.int64_list().value().begin(),
                              filter_values.int64_list().value().end());
      break;
    }
    case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT64: {
      if (!filter_values.has_int64_list()) {
        return errors::InvalidArgument(
            "Filter values' type should be the same with field type.");
      }
      uint_operand_set_.insert(filter_values.int64_list().value().begin(),
                               filter_values.int64_list().value().end());
      break;
    }
    case google::protobuf::FieldDescriptor::CppType::CPPTYPE_STRING: {
      if (!filter_values.has_bytes_list()) {
        return errors::InvalidArgument(
            "Filter values' type should be the same with field type.");
      }
      string_operand_set_.insert(filter_values.bytes_list().value().begin(),
                                 filter_values.bytes_list().value().end());
      break;
    }
    default: {
      return errors::InvalidArgument("Invalid field type for filter.");
    }
  }

  load_filter_values_finished_ = true;
  return Status::OK();
}

// Returns true when `line_id` passes the filter.
//
// * Unknown field (field_ == nullptr): always false.
// * Singular field: the value is read through reflection and evaluated with
//   internal::compare (ops in COMPARE_OPS) or internal::contains (otherwise);
//   string fields additionally support "startswith" / "endswith" against any
//   operand. For in / not-in with an operand file the operands are loaded
//   first, and a load failure aborts the process (TF_CHECK_OK).
// * Repeated integer field: the elements are collected and evaluated with
//   cmp(); when nothing was collected the result is keep_empty_.
bool LineIdValueFilter::IsInstanceOfInterest(
    tensorflow::Env *env, const ::idl::matrix::proto::LineId &line_id) {
  bool output = false;
  const auto reflection = ::idl::matrix::proto::LineId::GetReflection();
  auto field = field_;
  if (field == nullptr) {
    output = false;
    return output;
  }

  if (!field->is_repeated()) {
    if ((op_ == internal::IN || op_ == internal::NOT_IN) &&
        !operand_filepath_.empty()) {
      TF_CHECK_OK(EnsureLoadFilterValues(env));
    }

    switch (field->cpp_type()) {
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_FLOAT: {
        float value = reflection->GetFloat(line_id, field);
        output = internal::COMPARE_OPS.count(op_)
                     ? internal::compare(op_, value, float_operand_)
                     : internal::contains(op_, value, float_operand_set_);
        break;
      }
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_DOUBLE: {
        double value = reflection->GetDouble(line_id, field);
        output = internal::COMPARE_OPS.count(op_)
                     ? internal::compare(op_, value, float_operand_)
                     : internal::contains(op_, value, float_operand_set_);
        break;
      }
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT32: {
        int64 value = reflection->GetInt32(line_id, field);
        output = internal::COMPARE_OPS.count(op_)
                     ? internal::compare(op_, value, int_operand_)
                     : internal::contains(op_, value, int_operand_set_);
        break;
      }
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT64: {
        int64 value = reflection->GetInt64(line_id, field);
        output = internal::COMPARE_OPS.count(op_)
                     ? internal::compare(op_, value, int_operand_)
                     : internal::contains(op_, value, int_operand_set_);
        break;
      }
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT32: {
        // Widened to int64 and matched against the signed operands.
        int64 value = reflection->GetUInt32(line_id, field);
        output = internal::COMPARE_OPS.count(op_)
                     ? internal::compare(op_, value, int_operand_)
                     : internal::contains(op_, value, int_operand_set_);
        break;
      }
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT64: {
        uint64 value = reflection->GetUInt64(line_id, field);
        output = internal::COMPARE_OPS.count(op_)
                     ? internal::compare(op_, value, uint_operand_)
                     : internal::contains(op_, value, uint_operand_set_);
        break;
      }
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_STRING: {
        std::string value = reflection->GetString(line_id, field);
        output = false;
        if (op_ == "startswith") {
          // True when any operand is a prefix of the value.
          for (const std::string &operand : string_operand_) {
            if (value.find(operand) == 0) {
              output = true;
              break;
            }
          }
        } else if (op_ == "endswith") {
          // True when any operand is a suffix of the value; the size guard
          // keeps std::equal from reading past the start of `value`.
          for (const std::string &operand : string_operand_) {
            if (operand.size() <= value.size()) {
              bool found = std::equal(operand.rbegin(), operand.rend(),
                                      value.rbegin());
              if (found) {
                output = true;
                break;
              }
            }
          }
        } else {
          output = internal::COMPARE_OPS.count(op_)
                       ? internal::compare(op_, value, string_operand_)
                       : internal::contains(op_, value, string_operand_set_);
        }
        break;
      }
      default:
        output = false;
        LOG(INFO) << "dtype is " << field->cpp_type();
        break;
    }
  } else {
    const int field_size = reflection->FieldSize(line_id, field);
    std::vector values;
    switch (field->cpp_type()) {
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT32:
        for (int i = 0; i < field_size; ++i) {
          values.push_back(reflection->GetRepeatedInt32(line_id, field, i));
        }
        break;
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_INT64:
        for (int i = 0; i < field_size; ++i) {
          values.push_back(reflection->GetRepeatedInt64(line_id, field, i));
        }
        break;
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT32:
        for (int i = 0; i < field_size; ++i) {
          values.push_back(reflection->GetRepeatedUInt32(line_id, field, i));
        }
        break;
      case google::protobuf::FieldDescriptor::CppType::CPPTYPE_UINT64:
        for (int i = 0; i < field_size; ++i) {
          values.push_back(reflection->GetRepeatedUInt64(line_id, field, i));
        }
        break;
      default:
        // Non-integer repeated fields contribute no values.
        LOG(INFO) << "dtype is " << field->cpp_type();
        break;
    }

    if (values.size() > 0) {
      output = cmp(values);
    } else {
      output = keep_empty_;
    }
  }

  return output;
}

}  // namespace internal
}  // namespace monolith_tf
}  // namespace tensorflow
================================================ FILE: monolith/native_training/data/kernels/internal/value_filter_by_line_id.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_LINE_ID_VALUE_FILTER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_LINE_ID_VALUE_FILTER_H_ #include #include "absl/synchronization/mutex.h" #include "idl/matrix/proto/line_id.pb.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { namespace monolith_tf { namespace internal { class LineIdValueFilter { public: LineIdValueFilter(std::string field_name, std::string op, std::vector float_operand, std::vector int_operand, std::vector string_operand, std::string operand_filepath, bool keep_empty); bool IsInstanceOfInterest(tensorflow::Env* env, const ::idl::matrix::proto::LineId& line_id); private: Status EnsureLoadFilterValues(tensorflow::Env* env); bool cmp(const std::vector& values) { std::set intersection; std::set_intersection(values.begin(), values.end(), int_operand_.begin(), int_operand_.end(), std::inserter(intersection, intersection.begin())); if (op_ == "any") { return intersection.size() > 0; } else if (op_ == "all") { return intersection.size() == int_operand_.size(); } else if (op_ == "diff") { return intersection.size() == 0; } else { LOG(FATAL) << "Invalid op: " << op_; return false; } } 
private: mutable absl::Mutex mu_; bool load_filter_values_finished_ ABSL_GUARDED_BY(mu_) = false; const google::protobuf::FieldDescriptor* field_; std::string field_name_; std::string op_; // gt, ge, eq, lt, le, neq, between bool keep_empty_ = false; std::string operand_filepath_; std::vector float_operand_; std::vector int_operand_; std::vector uint_operand_; std::vector string_operand_; std::unordered_set float_operand_set_; std::unordered_set int_operand_set_; std::unordered_set uint_operand_set_; std::unordered_set string_operand_set_; }; } // namespace internal } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_INTERNAL_LINE_ID_VALUE_FILTER_H_ ================================================ FILE: monolith/native_training/data/kernels/internal/value_filter_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/internal/value_filter_by_line_id.h"
#include "monolith/native_training/data/kernels/internal/value_filter_by_feature.h"

#include
#include
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "idl/matrix/proto/line_id.pb.h"

namespace tensorflow {
namespace monolith_tf {
namespace internal {
namespace {

using ::idl::matrix::proto::LineId;

// Scalar integer field (uid == 2): relational and membership ops.
TEST(LineIdValueFilter, Int) {
  LineId sample;
  sample.set_uid(2);
  tensorflow::Env* default_env = tensorflow::Env::Default();

  LineIdValueFilter equals_two("uid", "eq", {}, {2}, {}, "", false);
  EXPECT_TRUE(equals_two.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter differs_from_two("uid", "neq", {}, {2}, {}, "", false);
  EXPECT_FALSE(differs_from_two.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter at_most_three("uid", "le", {}, {3}, {}, "", false);
  EXPECT_TRUE(at_most_three.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter at_least_one("uid", "ge", {}, {1}, {}, "", false);
  EXPECT_TRUE(at_least_one.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter within_one_and_three("uid", "between", {}, {1, 3}, {}, "",
                                         false);
  EXPECT_TRUE(within_one_and_three.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter member_of("uid", "in", {}, {1, 2, 3}, {}, "", false);
  EXPECT_TRUE(member_of.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter not_member_of("uid", "not-in", {}, {1, 2, 3}, {}, "",
                                  false);
  EXPECT_FALSE(not_member_of.IsInstanceOfInterest(default_env, sample));
}

// Repeated integer field (actions == [2, 3]): set ops any / all / diff.
TEST(LineIdValueFilter, IntArray) {
  LineId sample;
  sample.mutable_actions()->Add(2);
  sample.mutable_actions()->Add(3);
  tensorflow::Env* default_env = tensorflow::Env::Default();

  LineIdValueFilter any_overlapping("actions", "any", {}, {1, 2}, {}, "",
                                    false);
  EXPECT_TRUE(any_overlapping.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter any_disjoint("actions", "any", {}, {1, 4}, {}, "", false);
  EXPECT_FALSE(any_disjoint.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter all_present("actions", "all", {}, {2, 3}, {}, "", false);
  EXPECT_TRUE(all_present.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter all_one_missing("actions", "all", {}, {2, 3, 4}, {}, "",
                                    false);
  EXPECT_FALSE(all_one_missing.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter diff_disjoint("actions", "diff", {}, {1, 4}, {}, "",
                                  false);
  EXPECT_TRUE(diff_disjoint.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter diff_overlapping("actions", "diff", {}, {1, 2, 4}, {}, "",
                                     false);
  EXPECT_FALSE(diff_overlapping.IsInstanceOfInterest(default_env, sample));
}

// Scalar float field (q_pred == 2.0): relational and membership ops.
TEST(LineIdValueFilter, Float) {
  LineId sample;
  sample.set_q_pred(2.0f);
  tensorflow::Env* default_env = tensorflow::Env::Default();

  LineIdValueFilter equals_two("q_pred", "eq", {2.0f}, {}, {}, "", false);
  EXPECT_TRUE(equals_two.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter differs_from_two("q_pred", "neq", {2.0f}, {}, {}, "",
                                     false);
  EXPECT_FALSE(differs_from_two.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter at_most_three("q_pred", "le", {3.0f}, {}, {}, "", false);
  EXPECT_TRUE(at_most_three.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter at_least_one("q_pred", "ge", {1.0f}, {}, {}, "", false);
  EXPECT_TRUE(at_least_one.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter within_one_and_three("q_pred", "between", {1.0f, 3.0f}, {},
                                         {}, "", false);
  EXPECT_TRUE(within_one_and_three.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter member_of("q_pred", "in", {1.0f, 2.0f, 3.0f}, {}, {}, "",
                              false);
  EXPECT_TRUE(member_of.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter not_member_of("q_pred", "not-in", {1.0f, 2.0f, 3.0f}, {},
                                  {}, "", false);
  EXPECT_FALSE(not_member_of.IsInstanceOfInterest(default_env, sample));
}

// Scalar string field (vid == "hello"): relational, membership, and
// prefix / suffix ops.
TEST(LineIdValueFilter, String) {
  LineId sample;
  sample.set_vid("hello");
  tensorflow::Env* default_env = tensorflow::Env::Default();

  LineIdValueFilter equals_hello("vid", "eq", {}, {}, {"hello"}, "", false);
  EXPECT_TRUE(equals_hello.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter differs_from_hello("vid", "neq", {}, {}, {"hello"}, "",
                                       false);
  EXPECT_FALSE(differs_from_hello.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter at_most_hello1("vid", "le", {}, {}, {"hello1"}, "", false);
  EXPECT_TRUE(at_most_hello1.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter at_least_hell("vid", "ge", {}, {}, {"hell"}, "", false);
  EXPECT_TRUE(at_least_hell.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter within_hell_and_hello1("vid", "between", {}, {},
                                           {"hell", "hello1"}, "", false);
  EXPECT_TRUE(within_hell_and_hello1.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter member_of("vid", "in", {}, {}, {"hello", "world"}, "",
                              false);
  EXPECT_TRUE(member_of.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter not_member_of("vid", "not-in", {}, {}, {"hello", "world"},
                                  "", false);
  EXPECT_FALSE(not_member_of.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter has_prefix("vid", "startswith", {}, {}, {"hell"}, "",
                               false);
  EXPECT_TRUE(has_prefix.IsInstanceOfInterest(default_env, sample));

  LineIdValueFilter has_suffix("vid", "endswith", {}, {}, {"llo"}, "", false);
  EXPECT_TRUE(has_suffix.IsInstanceOfInterest(default_env, sample));
}

}  // namespace
}  // namespace internal
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/kernels/item_pool_kernels.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include "monolith/native_training/data/kernels/item_pool_kernels.h" #include "tensorflow/core/lib/io/record_reader.h" #include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/path.h" #include "absl/random/random.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/threadpool.h" #include "third_party/nlohmann/json.hpp" using json = nlohmann::json; using NamedFeature = ::monolith::io::proto::NamedFeature; using ChannelCache = ::monolith::io::proto::ChannelCache; static const std::string FILE_NAME_PREFIX = "item_pool_"; static constexpr uint64_t MASK = (1L << 48) - 1; namespace tensorflow { namespace monolith_tf { // Carries the data through async process. // It will ref and unref |p_hash_table|. struct AsyncPack { AsyncPack(OpKernelContext* p_ctx, ItemPoolResource* p_pool, std::function p_done, int p_thread_num) : ctx(p_ctx), pool(p_pool), done(std::move(p_done)), thread_num(p_thread_num), finish_num(0), status(thread_num) { pool->Ref(); } ~AsyncPack() { pool->Unref(); } OpKernelContext* ctx; ItemPoolResource* pool; std::function done; const int thread_num; std::atomic_int finish_num; std::vector status; }; ItemPoolResource::ItemPoolResource(int max_item_num_per_channel, int start_num) : start_num_(start_num), max_item_num_per_channel_(max_item_num_per_channel), cache_(std::make_unique(max_item_num_per_channel, start_num)) {} Status ItemPoolResource::Add( uint64_t channel_id, uint64_t item_id, const std::shared_ptr& item) { absl::MutexLock l(&mu_); cache_->Push(channel_id, item_id, item, 1, 0); return Status::OK(); } std::shared_ptr ItemPoolResource::Sample( uint64_t channel_id, double* freq_factor, double* time_factor) { absl::MutexLock l(&mu_); return 
cache_->RandomSelectOne(channel_id, freq_factor, time_factor); } Status ItemPoolResource::Save(WritableFile* ostream, int shard_index, int shard_num) { absl::MutexLock l(&mu_); const absl::flat_hash_map& channel_cache_ = cache_->GetCache(); io::RecordWriter writer(ostream); Status write_status = Status::OK(); for (const auto& pair : channel_cache_) { if (pair.first % shard_num != shard_index) { continue; } ChannelCache channel_cache; channel_cache.set_channel_id(pair.first); pair.second.ToProto(&channel_cache); Status s = writer.WriteRecord(channel_cache.SerializeAsString()); if (TF_PREDICT_FALSE(!s.ok())) { write_status.Update(s); break; } } TF_RETURN_IF_ERROR(write_status); TF_RETURN_IF_ERROR(writer.Close()); return Status::OK(); } Status ItemPoolResource::Restore(RandomAccessFile* istream, int64 buffer_size) { absl::MutexLock l(&mu_); io::RecordReaderOptions opts; opts.buffer_size = buffer_size; io::SequentialRecordReader reader(istream, opts); Status restore_status = Status::OK(); while (true) { tstring s; ChannelCache channel_cache; // read record Status rs = reader.ReadRecord(&s); if (errors::IsOutOfRange(rs)) { LOG(INFO) << "EOF, read file done..."; break; } else { restore_status.Update(rs); } if (!channel_cache.ParseFromArray(s.data(), s.size())) { restore_status.Update(errors::FailedPrecondition( "Unable to parse data. 
Data might be corrupted")); break; } else { restore_status.Update(Status::OK()); } for (const auto& feature_data : channel_cache.feature_datas()) { auto item_feature_ptr = internal::MakeItemFeaturesFromProto(feature_data); cache_->Push(channel_cache.channel_id(), feature_data.gid(), item_feature_ptr, feature_data.origin_cnt(), feature_data.sample_cnt()); } LOG(INFO) << absl::StrFormat( "ItemPoolResource: after restore, channel %lld restore %llu items", channel_cache.channel_id(), channel_cache.feature_datas_size()); } TF_RETURN_IF_ERROR(restore_status); return Status::OK(); } bool ItemPoolResource::Equal(const ItemPoolResource& other) const { if (other.max_item_num_per_channel_ != max_item_num_per_channel_) { return false; } if (other.start_num_ != start_num_) { return false; } auto this_cache = cache_->GetCache(); auto other_cache = other.cache_->GetCache(); if (this_cache.size() != other_cache.size()) { return false; } else { for (const auto& it : this_cache) { if (other_cache.count(it.first) == 0) { return false; } else { auto this_channel = it.second; auto other_channel = other_cache.at(it.first); return this_channel.Equal(other_channel); } } } return true; } void ItemPoolResource::SampleChannelID(uint64_t* channel_id) { absl::MutexLock l(&mu_); cache_->SampleChannelID(channel_id); } void get_index_and_worker_num(int* index, int* worker_num) { const char* env_p = std::getenv("TF_CONFIG"); if (env_p == nullptr) { *index = 0; *worker_num = 1; } else { auto tf_config = json::parse(env_p); // assert TF_CONFIG only has ps + chief + worker for (const auto& conf_item : tf_config["cluster"].items()) { if (conf_item.key() != "ps" && conf_item.key() != "chief" && conf_item.key() != "worker") { LOG(ERROR) << "Unknown Cluster Type: " << conf_item.key(); } } auto chief = tf_config["cluster"]["chief"]; auto workers = tf_config["cluster"]["worker"]; *worker_num = chief.size() + workers.size(); if (tf_config["task"]["type"] == "worker") { *index = 
static_cast(tf_config["task"]["index"]) + 1; } else { *index = 0; } } } class ItemPoolCreateOp : public ResourceOpKernel { public: explicit ItemPoolCreateOp(OpKernelConstruction* ctx) : ResourceOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("start_num", &start_num_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("max_item_num_per_channel", &max_item_num_per_channel_)); } private: Status CreateResource(ItemPoolResource** wrapper) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { *wrapper = new ItemPoolResource(max_item_num_per_channel_, start_num_); return Status::OK(); } int start_num_, max_item_num_per_channel_; }; // for test only class ItemPoolRandomFillOp : public OpKernel { public: explicit ItemPoolRandomFillOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { ItemPoolResource* pool; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &pool)); core::ScopedUnref unref(pool); ctx->set_output(0, ctx->input(0)); for (int i = 0; i < 10; ++i) { for (int j = 0; j < 50; ++j) { std::shared_ptr item = std::make_shared(); GenItemFeatures(item.get()); pool->Add(i, j, item); } } } private: void GenNamedFeature(NamedFeature* nf) { int slot = std::rand() % 1024; nf->set_name(absl::StrCat("fc_", slot)); auto* fid_v2_list = nf->mutable_feature()->mutable_fid_v2_list(); int num_fids = std::abs(std::rand() % 20) + 1; for (int i = 0; i < num_fids; ++i) { uint64_t fid = ((uint64_t)slot << 48) | ((std::rand() % 100000) & MASK); fid_v2_list->add_value(fid); } } void GenItemFeatures(internal::ItemFeatures* item) { int num_feats = std::abs(std::rand() % 20) + 1; for (int i = 0; i < num_feats; ++i) { NamedFeature nf; GenNamedFeature(&nf); if (!item->example_features.contains(nf.name())) { item->example_features.insert({nf.name(), nf}); } } } }; // for test only class ItemPoolCheckOp : public OpKernel { public: explicit ItemPoolCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("model_path", &model_path_)); 
OP_REQUIRES_OK(ctx, ctx->GetAttr("nshards", &nshards_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); } void Compute(OpKernelContext* ctx) override { ItemPoolResource* pool; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &pool)); core::ScopedUnref unref(pool); ctx->set_output(0, ctx->input(0)); const Tensor& global_step_tensor = ctx->input(1); global_step_ = global_step_tensor.scalar()(); ItemPoolResource pool2(pool->max_item_num_per_channel(), pool->start_num()); for (int idx = 0; idx < nshards_; ++idx) { std::string filename = GetRestoreFileName(ctx, idx); if (!filename.empty()) { std::unique_ptr istream; OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &istream)); OP_REQUIRES_OK(ctx, pool2.Restore(istream.get(), buffer_size_)); } } if (!pool->Equal(pool2)) { LOG(INFO) << "resotre not equal~ ..."; } else { LOG(INFO) << "resotre equal~ ..."; } } private: std::string model_path_; int64 buffer_size_; int nshards_; int64 global_step_; std::string GetRestoreFileName(OpKernelContext* ctx, int shard_index) { int index, worker_num; get_index_and_worker_num(&index, &worker_num); std::vector files; Status s = ctx->env()->GetMatchingPaths( absl::StrCat(model_path_, "/model.ckpt-", global_step_, "_", FILE_NAME_PREFIX, "*"), &files); if (!s.ok()) { LOG(INFO) << "GetMatchingPaths Error: " << s; return ""; } int last_worker_num = 1; int64 mtime_nsec = 0; for (const auto& file : files) { FileStatistics stat; ctx->env()->Stat(file, &stat); if (mtime_nsec < stat.mtime_nsec) { std::vector items = absl::StrSplit(file, "_"); absl::SimpleAtoi(items.back(), &last_worker_num); mtime_nsec = stat.mtime_nsec; } } if (files.size() > 0) { // {model_path}/item_pool_{index}_{worker_num} return absl::StrCat(model_path_, "/model.ckpt-", global_step_, "_", FILE_NAME_PREFIX, index % last_worker_num, "_", shard_index, "_", last_worker_num); } else { return ""; } } }; class ItemPoolSaveOp : public AsyncOpKernel { public: explicit 
ItemPoolSaveOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("model_path", &model_path_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("nshards", &nshards_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("random_sleep_ms", &random_sleep_ms_)); } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { ItemPoolResource* pool; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &pool)); core::ScopedUnref unref(pool); const Tensor& global_step_tensor = ctx->input(1); global_step_ = global_step_tensor.scalar()(); ctx->set_output(0, ctx->input(0)); if (!ctx->env()->FileExists(model_path_).ok()) { OP_REQUIRES_OK_ASYNC(ctx, ctx->env()->RecursivelyCreateDir(model_path_), done); } // add multi-thread auto pack = new AsyncPack(ctx, pool, std::move(done), nshards_); for (int i = 0; i < nshards_; ++i) { ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, pack, i] { WorkerThread(i, pack); }); } } private: std::string model_path_; int64 global_step_; int nshards_; int64 random_sleep_ms_; std::string GetSaveFileName(int shard_index) { int index, worker_num; get_index_and_worker_num(&index, &worker_num); return absl::StrCat(model_path_, "/model.ckpt-", global_step_, "_", FILE_NAME_PREFIX, index, "_", shard_index, "_", worker_num); } void WorkerThread(int shard_index, AsyncPack* p) { absl::BitGen bitgen; p->status[shard_index] = SaveOneShard(shard_index, p); if (p->finish_num.fetch_add(1) == p->thread_num - 1) { Cleanup(p); } } Status SaveOneShard(int shard_index, AsyncPack* p) { std::string filename = GetSaveFileName(shard_index); std::string tmp_filename = absl::StrCat(filename, "_tmp"); std::unique_ptr ostream; TF_RETURN_IF_ERROR(p->ctx->env()->NewWritableFile(tmp_filename, &ostream)); TF_RETURN_IF_ERROR( p->pool->Save(ostream.get(), shard_index, p->thread_num)); TF_RETURN_IF_ERROR(ostream->Close()); if (p->ctx->env()->FileExists(filename).ok()) { TF_RETURN_IF_ERROR( p->ctx->env()->RenameFile(filename, 
absl::StrCat(filename, "_old"))); TF_RETURN_IF_ERROR(p->ctx->env()->RenameFile(tmp_filename, filename)); TF_RETURN_IF_ERROR( p->ctx->env()->DeleteFile(absl::StrCat(filename, "_old"))); } else { TF_RETURN_IF_ERROR(p->ctx->env()->RenameFile(tmp_filename, filename)); } return Status::OK(); } // Clean up when all shards are done. void Cleanup(AsyncPack* p) { auto done = [p]() { // We want to delete p first and then call done. auto done = std::move(p->done); delete p; done(); }; for (int i = 0; i < p->thread_num; ++i) { OP_REQUIRES_OK_ASYNC(p->ctx, p->status[i], done); } done(); } }; class ItemPoolRestoreOp : public AsyncOpKernel { public: explicit ItemPoolRestoreOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("model_path", &model_path_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("nshards", &nshards_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("random_sleep_ms", &random_sleep_ms_)); } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { ItemPoolResource* pool; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &pool)); core::ScopedUnref unref(pool); const Tensor& global_step_tensor = ctx->input(1); global_step_ = global_step_tensor.scalar()(); ctx->set_output(0, ctx->input(0)); auto pack = new AsyncPack(ctx, pool, std::move(done), nshards_); for (int i = 0; i < nshards_; i++) { ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, pack, i] { WorkerThread(i, pack); }); } } private: std::string model_path_; int64 global_step_; int64 buffer_size_; int nshards_; int64 random_sleep_ms_; void WorkerThread(int shard_index, AsyncPack* p) { absl::BitGen bitgen; p->status[shard_index] = RestoreOneShard(shard_index, p); if (p->finish_num.fetch_add(1) == p->thread_num - 1) { Cleanup(p); } } Status RestoreOneShard(int shard_index, AsyncPack* p) { std::string filename = GetRestoreFileName(p->ctx, shard_index); if (filename == "") { LOG(INFO) << 
"Cannot find file to restore, skip!"; } else if (p->ctx->env()->FileExists(filename).ok()) { LOG(INFO) << "Restoring file: " << filename; std::unique_ptr istream; TF_RETURN_IF_ERROR( p->ctx->env()->NewRandomAccessFile(filename, &istream)); TF_RETURN_IF_ERROR(p->pool->Restore(istream.get(), buffer_size_)); } else { LOG(INFO) << "File dose not exist: " << filename; } return Status::OK(); } // Clean up when all shards are done. void Cleanup(AsyncPack* p) { auto done = [p]() { // We want to delete p first and then call done. auto done = std::move(p->done); delete p; done(); }; for (int i = 0; i < p->thread_num; ++i) { OP_REQUIRES_OK_ASYNC(p->ctx, p->status[i], done); } done(); } int FindLastNumber(std::vector const &files, OpKernelContext* ctx) { // 支持 restore 时的 worker_num 可以和 save 时不同 int last_worker_num = 1; for (const auto& file : files) { if (absl::EndsWith(file, "tmp")) { LOG(INFO) << "Files vector contains file with tmp suffix."; continue; } std::vector items = absl::StrSplit(file, "_"); if (!items.empty() && absl::SimpleAtoi(items.back(), &last_worker_num)) { break; } } return last_worker_num; } int FindFuzzyCkptNumber(const std::vector& files) { int max_match_step = -1; for (const std::string& file : files) { LOG(INFO) << "match fuzzy ckpt:" << file; if (absl::EndsWith(file, "tmp")) { LOG(INFO) << "Files vector contains file with tmp suffix."; continue; } // file like "xxx/model.ckpt-25541095_item_pool_28_0_60" std::vector items = absl::StrSplit(file, absl::ByAnyChar("-_")); CHECK_GT(items.size(), 6) << absl::StrFormat( "item_pool ckpt's filepath is not correct: %s", file); int global_step = -1; CHECK(absl::SimpleAtoi(items.at(items.size() - 6), &global_step)); if (global_step > max_match_step) { max_match_step = global_step; } } return max_match_step; } std::string GetRestoreFileName(OpKernelContext* ctx, int shard_index) { int index, worker_num; get_index_and_worker_num(&index, &worker_num); std::vector files_new; std::vector files_old; // the global step of 
chief's item_pool ckpt is correct Status s_new = ctx->env()->GetMatchingPaths( absl::StrCat(model_path_, "/model.ckpt-", global_step_, "_", FILE_NAME_PREFIX, "*"), &files_new); if (s_new.ok() && !files_new.empty()) { int last_save_worker_num = FindLastNumber(files_new, ctx); LOG(INFO) << "last worker num is: " << last_save_worker_num; std::vector files_fuzzy; std::string fuzzy_matching_path = absl::StrCat(model_path_, "/model.ckpt-", "*", "_", FILE_NAME_PREFIX, index % last_save_worker_num, "_", shard_index, "_", last_save_worker_num); Status fuzzy_match = ctx->env()->GetMatchingPaths(fuzzy_matching_path, &files_fuzzy); if (fuzzy_match.ok() && !files_fuzzy.empty()) { int ckpt_num = FindFuzzyCkptNumber(files_fuzzy); if (ckpt_num <= global_step_) { return absl::StrCat(model_path_, "/model.ckpt-", ckpt_num, "_", FILE_NAME_PREFIX, index % last_save_worker_num, "_", shard_index, "_", last_save_worker_num); } else { LOG(INFO) << absl::StrFormat( "step not match: fuzzy match step is %d, target global step is " "%d", ckpt_num, global_step_); } } else { LOG(INFO) << absl::StrFormat("path not match: %s", fuzzy_matching_path); } } Status s_old = ctx->env()->GetMatchingPaths( absl::StrCat(model_path_, "/", FILE_NAME_PREFIX, "*"), &files_old); if (s_old.ok() && !files_old.empty()) { LOG(INFO) << "old version files > 0"; int last_worker_num = FindLastNumber(files_old, ctx); LOG(INFO) << "last worker num is: " << last_worker_num; return absl::StrCat(model_path_, "/", FILE_NAME_PREFIX, index % last_worker_num, "_", shard_index, "_", last_worker_num); } LOG(INFO) << "GetMatchingPaths Error: [new] " << s_new << " and [old] " << s_old; return ""; } }; namespace { REGISTER_KERNEL_BUILDER(Name("ItemPoolCreate").Device(DEVICE_CPU), ItemPoolCreateOp); // for test only REGISTER_KERNEL_BUILDER(Name("ItemPoolCheck").Device(DEVICE_CPU), ItemPoolCheckOp); // for test only REGISTER_KERNEL_BUILDER(Name("ItemPoolRandomFill").Device(DEVICE_CPU), ItemPoolRandomFillOp); 
REGISTER_KERNEL_BUILDER(Name("ItemPoolSave").Device(DEVICE_CPU), ItemPoolSaveOp); REGISTER_KERNEL_BUILDER(Name("ItemPoolRestore").Device(DEVICE_CPU), ItemPoolRestoreOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/item_pool_kernels.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_ITEM_POOL_KERNELS_H_ #define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_ITEM_POOL_KERNELS_H_ #include "monolith/native_training/data/kernels/internal/cache_mgr.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_op_kernel.h" namespace tensorflow { namespace monolith_tf { class ItemPoolResource : public ResourceBase { public: explicit ItemPoolResource(int max_item_num_per_channel, int start_num = 0); std::string DebugString() const override { return "ItemPoolResource"; } Status Add(uint64_t channel_id, uint64_t item_id, const std::shared_ptr& item); std::shared_ptr Sample(uint64_t channel_id, double* freq_factor, double* time_factor); Status Save(WritableFile* ostream, int shard_index, int shard_num); Status Restore(RandomAccessFile* istream, int64 buffer_size); inline int start_num() { return start_num_; } inline int max_item_num_per_channel() { return max_item_num_per_channel_; } bool Equal(const ItemPoolResource& other) const; void SampleChannelID(uint64_t* channel_id); private: absl::Mutex mu_; int start_num_, max_item_num_per_channel_; std::unique_ptr cache_ ABSL_GUARDED_BY(mu_); }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_DATA_KERNELS_ITEM_POOL_KERNELS_H_ ================================================ FILE: monolith/native_training/data/kernels/kafka_kernels.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "rdkafkacpp.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h" #include "monolith/native_training/data/training_instance/cc/data_reader.h" namespace tensorflow { namespace monolith_tf { namespace { using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using Instance = ::parser::proto::Instance; using ::tensorflow::monolith_tf::BaseStreamReader; using ::tensorflow::monolith_tf::DataFormatOptions; using ::tensorflow::monolith_tf::ExampleBatchIterator; using ::tensorflow::monolith_tf::ExampleToInstance; using ::tensorflow::monolith_tf::FeatureNameMapper; using ::tensorflow::monolith_tf::FeatureNameMapperTfBridge; using ::tensorflow::monolith_tf::FeaturePruningType; using 
::tensorflow::monolith_tf::FileStreamReader; using ::tensorflow::monolith_tf::InstanceToExample; using ::tensorflow::monolith_tf::PBIterator; using ::tensorflow::monolith_tf::StdinStreamReader; } // namespace class KafkaEventCb : public RdKafka::EventCb { public: KafkaEventCb() : run_(true) {} bool run() { return run_; } void event_cb(RdKafka::Event& event) { switch (event.type()) { case RdKafka::Event::EVENT_ERROR: LOG(ERROR) << "EVENT_ERROR: " << "(" << RdKafka::err2str(event.err()) << "): " << event.str(); { run_ = !event.fatal(); } break; case RdKafka::Event::EVENT_STATS: LOG(ERROR) << "EVENT_STATS: " << event.str(); break; case RdKafka::Event::EVENT_LOG: LOG(ERROR) << "EVENT_LOG: " << event.severity() << "-" << event.fac().c_str() << "-" << event.str().c_str(); break; case RdKafka::Event::EVENT_THROTTLE: LOG(ERROR) << "EVENT_THROTTLE: " << event.throttle_time() << "ms by " << event.broker_name() << " id " << static_cast(event.broker_id()); break; default: LOG(ERROR) << "EVENT: " << event.type() << " (" << RdKafka::err2str(event.err()) << "): " << event.str(); break; } } private: mutable mutex mu_; bool run_ TF_GUARDED_BY(mu_) = true; }; static int64 partition_count = 0; static int64 eof_count = 0; class KafkaRebalanceCb : public RdKafka::RebalanceCb { public: KafkaRebalanceCb() : run_(true) {} bool run() { return run_; } void rebalance_cb(RdKafka::KafkaConsumer* consumer, RdKafka::ErrorCode err, std::vector& partitions) { LOG(ERROR) << "REBALANCE: " << RdKafka::err2str(err); int timeout = 5000; // milliseconds LOG(ERROR) << "Retrieved committed offsets with status code: " << consumer->committed(partitions, timeout); for (int partition = 0; partition < partitions.size(); partition++) { // OFFSET MAPPINGS: // // RD_KAFKA_OFFSET_BEGINNING -2 // RD_KAFKA_OFFSET_END -1 // RD_KAFKA_OFFSET_STORED -1000 // RD_KAFKA_OFFSET_INVALID -1001 LOG(INFO) << "REBALANCE: " << partitions[partition]->topic() << "[" << partitions[partition]->partition() << "], " << "OFFSET: " << 
partitions[partition]->offset() << " " << "ERROR_CODE: " << partitions[partition]->err(); } if (err == RdKafka::ERR__ASSIGN_PARTITIONS) { // librdkafka does not actually look up the stored offsets before // calling your rebalance callback, the partition offsets are set to // RD_KAFKA_OFFSET_INVALID at this point to allow us to change it to use // some sort of external offset store. But calling assign() with offset // RD_KAFKA_OFFSET_INVALID will cause librdkafka to look up the stored // offset on the broker. // If there was no stored offset it will fall back to `auto.offset.reset` // configuration parameter. LOG(INFO) << "REBALANCE: Assigning partitions"; consumer->assign(partitions); partition_count = static_cast(partitions.size()); } else { LOG(INFO) << "REBALANCE: Unassigning partitions"; consumer->unassign(); partition_count = 0; } eof_count = 0; } private: mutable mutex mu_; bool run_ TF_GUARDED_BY(mu_) = true; }; class KafkaGroupReadableResource : public ResourceBase { public: explicit KafkaGroupReadableResource(Env* env) : env_(env) {} virtual ~KafkaGroupReadableResource() { if (consumer_.get()) { consumer_->unassign(); consumer_->close(); consumer_.reset(nullptr); } } virtual Status Init(const std::vector& topics, const std::vector& metadata, const DataFormatOptions& options, const std::string& input_pb_type, const std::string& output_pb_type) { mutex_lock l(mu_); std::unique_ptr conf( RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); std::unique_ptr conf_topic( RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); string errstr; RdKafka::Conf::ConfResult result = RdKafka::Conf::CONF_UNKNOWN; // The default kafka topic configurations are set first before // setting the global confs for (size_t i = 0; i < metadata.size(); i++) { if (metadata[i].find("conf.topic.") == 0) { std::vector parts = str_util::Split(metadata[i], "="); if (parts.size() != 2) { return errors::InvalidArgument("invalid topic configuration: ", metadata[i]); } result = 
conf_topic->set(parts[0].substr(11), parts[1], errstr); if (result != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to do topic configuration:", metadata[i], "error:", errstr); } LOG(INFO) << "Kafka configuration: " << metadata[i]; } } if ((result = conf->set("default_topic_conf", conf_topic.get(), errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to set default_topic_conf:", errstr); } // Once the `default_topic_conf` is set, the global confs can now be set // without any risk of being overwritten. // Setting the global confs before setting the `default_topic_conf` // results in erratic behaviour. for (size_t i = 0; i < metadata.size(); i++) { if (metadata[i] != "" && metadata[i].find("conf.") == string::npos) { std::vector parts = str_util::Split(metadata[i], "="); if (parts.size() != 2) { return errors::InvalidArgument("invalid topic configuration: ", metadata[i]); } if ((result = conf->set(parts[0], parts[1], errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to do global configuration: ", metadata[i], "error:", errstr); } LOG(INFO) << "Kafka configuration: " << metadata[i]; } } // default consumer.properties: // bootstrap.servers=localhost:9092 // group.id=test-consumer-group string bootstrap_servers; if ((result = conf->get("bootstrap.servers", bootstrap_servers)) != RdKafka::Conf::CONF_OK) { bootstrap_servers = "localhost:9092"; if ((result = conf->set("bootstrap.servers", bootstrap_servers, errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to set bootstrap.servers [", bootstrap_servers, "]:", errstr); } } string group_id; if ((result = conf->get("group.id", group_id)) != RdKafka::Conf::CONF_OK) { group_id = "test-consumer-group"; if ((result = conf->set("group.id", group_id, errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to set group.id [", group_id, "]:", errstr); } } // Always set enable.partition.eof=true if ((result = conf->set("enable.partition.eof", "true", 
errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("Failed to set enable.partition.eof=true :", errstr); } if ((result = conf->set("event_cb", &kafka_event_cb_, errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to set event_cb:", errstr); } if ((result = conf->set("rebalance_cb", &kafka_rebalance_cb_, errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to set rebalance_cb:", errstr); } // set max.poll.records configuration std::string batch_num_messages; if ((result = conf->get("batch.num.messages", batch_num_messages)) != RdKafka::Conf::CONF_OK) { batch_num_messages = "1024"; if ((result = conf->set("batch.num.messages", batch_num_messages, errstr)) != RdKafka::Conf::CONF_OK) { return errors::Internal("failed to set batch.num.messages [", batch_num_messages, "]:", errstr); } } sscanf(batch_num_messages.c_str(), "%d", &batch_num_messages_); LOG(INFO) << "max num of messages per batch: " << batch_num_messages_; LOG(INFO) << "Creating the kafka consumer"; consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); if (!consumer_.get()) { return errors::Internal("failed to create consumer:", errstr); } for (int i = 0; i < topics.size(); i++) { LOG(INFO) << "Subscribing to the kafka topic: " << topics[i]; } RdKafka::ErrorCode err = consumer_->subscribe(topics); if (err != RdKafka::ERR_NO_ERROR) { return errors::Internal("failed to subscribe to topics: ", RdKafka::err2str(err)); } if (input_pb_type == "" && output_pb_type == "") { version_ = 1; } else { input_pb_type_ = data_format::StringToDataFormat(input_pb_type); output_pb_type_ = data_format::StringToDataFormat(output_pb_type); if (input_pb_type_ == data_format::UNKNOW || output_pb_type_ == data_format::UNKNOW) { return errors::Internal("input_pb_type or output_pb_type err:", input_pb_type, output_pb_type); } version_ = 2; } options_ = options; return Status::OK(); } class CurPBIteratorHandler { public: struct CurOutput : public 
PBIteratorWithDataFormatTransBaseOutput { std::vector exa_pb_list; std::vector ins_pb_list; std::vector eb_pb_list; std::vector string_list; }; Status HandleReaderNextStauts(const Status& s, const tstring& result) { if (s != Status::OK()) { if (s.code() != error::OUT_OF_RANGE) { LOG(ERROR) << "pb parse error:" << s; } return s; } if (result.size() == 0) { LOG(ERROR) << "tstring size can not be 0"; return errors::FailedPrecondition("tstring size=0"); } return Status::OK(); } template Status HandleReaderNextStauts(const Status& s, const TResult& result) { if (s != Status::OK()) { if (s.code() != error::OUT_OF_RANGE) { LOG(ERROR) << "pb parse error:" << s; } return s; } if (result.ByteSize() == 0) { LOG(ERROR) << "pb struct size can not be 0"; return errors::FailedPrecondition("pb size=0"); } return Status::OK(); } template Status HandleResult(TResult&& result, CurOutput* output) { return errors::Unimplemented("not implement"); } Status HandleResult(tstring&& serialized, CurOutput* output) { output->string_list.emplace_back(std::move(serialized)); return Status::OK(); } virtual Status HandleResult(Example&& exa_pb, CurOutput* output) { output->exa_pb_list.emplace_back(std::move(exa_pb)); return Status::OK(); } virtual Status HandleResult(Instance&& ins_pb, CurOutput* output) { output->ins_pb_list.emplace_back(std::move(ins_pb)); return Status::OK(); } virtual Status HandleResult(ExampleBatch&& eb_pb, CurOutput* output) { output->eb_pb_list.emplace_back(std::move(eb_pb)); return Status::OK(); } }; Status Next(const int64 index, const int64 message_poll_timeout, const int64 stream_timeout, std::function allocate_func) { mutex_lock l(mu_); // Initialize necessary variables int64 num_messages = 0; max_stream_timeout_polls_ = stream_timeout / message_poll_timeout; // Allocate memory for message_value and key_value vectors std::vector message_value, key_value; message_value.reserve(batch_num_messages_); // key_value.reserve(batch_num_messages_); std::unique_ptr message; 
while (consumer_.get() != nullptr && num_messages < batch_num_messages_) { if (!kafka_event_cb_.run()) { return errors::Internal( "failed to consume messages due to broker issue"); } message.reset(consumer_->consume(message_poll_timeout)); if (message->err() == RdKafka::ERR_NO_ERROR) { // Produce the line as output. message_value.emplace_back(tstring( static_cast(message->payload()), message->len())); // key_value.emplace_back( // (message->key() != nullptr) ? tstring(*message->key()) : ""); num_messages++; // Once a message has been successfully retrieved, the // `stream_timeout_polls_` is reset to 0. This allows the dataset // to wait for the entire `stream_timeout` duration when a data // slump occurs in the future. stream_timeout_polls_ = 0; } else if (message->err() == RdKafka::ERR__TRANSPORT) { // Not returning an error here as the consumer will try to re-connect. LOG(ERROR) << "Broker transport failure: " << message->errstr(); } else if (message->err() == RdKafka::ERR__PARTITION_EOF) { if (++eof_count == partition_count) { LOG(INFO) << "EOF reached for all " << partition_count << " partition(s)"; break; } } else if (message->err() == RdKafka::ERR__TIMED_OUT) { LOG(ERROR) << message->errstr(); stream_timeout_polls_++; break; } else { LOG(ERROR) << "ERROR Code " << message->err() << ", errstr is " << message->errstr(); } } // Prepare the outputs PBIteratorWithDataFormatTrans cur_iter( input_pb_type_, output_pb_type_); CurPBIteratorHandler::CurOutput output; if (version_ == 1) { output.string_list.swap(message_value); } else { // std::ostringstream imploded; // std::copy(message_value.begin(), message_value.end(), // std::ostream_iterator(imploded, "")); // std::string msg; std::unique_ptr reader; for (auto& mesg : message_value) { auto stream_reader = std::make_unique >(options_, mesg); if (input_pb_type_ == data_format::INSTANCE || input_pb_type_ == data_format::EXAMPLE) { reader = absl::make_unique( std::move(stream_reader), 
FeaturePruningType::PRUNING_RAW_FEATURE); } else { reader = absl::make_unique( std::move(stream_reader), FeaturePruningType::PRUNING_RAW_FEATURE, &fake_mapper_); } uint64 offset_ = 0; while (true) { Status s = cur_iter.GetNext(reader.get(), &output, &offset_); if (!s.ok()) break; offset_ = reader->GetOffset(); } } } size_t all_size = 0; if (output_pb_type_ == data_format::EXAMPLE) { all_size = output.exa_pb_list.size(); } else if (output_pb_type_ == data_format::EXAMPLEBATCH) { all_size = output.eb_pb_list.size(); } else if (output_pb_type_ == data_format::INSTANCE) { all_size = output.ins_pb_list.size(); } else { all_size = output.string_list.size(); } if (all_size < message_value.size()) { LOG(ERROR) << "get not enough pb:" << all_size << "," << message_value.size(); } TensorShape shape({static_cast(all_size)}); Tensor* message_tensor; Tensor* key_tensor; Tensor* continue_fetch_tensor; TF_RETURN_IF_ERROR(allocate_func(shape, &message_tensor, &key_tensor, &continue_fetch_tensor)); for (int i = 0; i < all_size; ++i) { if (output_pb_type_ == data_format::EXAMPLE) { message_tensor->flat()(i) = std::move(output.exa_pb_list[i]); } else if (output_pb_type_ == data_format::INSTANCE) { message_tensor->flat()(i) = std::move(output.ins_pb_list[i]); } else if (output_pb_type_ == data_format::EXAMPLEBATCH) { message_tensor->flat()(i) = std::move(output.eb_pb_list[i]); } else { message_tensor->flat()(i) = std::move(output.string_list[i]); } } if (stream_timeout_polls_ < max_stream_timeout_polls_) { continue_fetch_tensor->scalar()() = 1; } else { continue_fetch_tensor->scalar()() = 0; } LOG_EVERY_N_SEC(INFO, 60) << "consumer pb:" << all_size << "," << message_value.size(); return Status::OK(); } string DebugString() const override { return "KafkaBaseResource"; } mutable mutex mu_; Env* env_ TF_GUARDED_BY(mu_); std::unique_ptr consumer_ TF_GUARDED_BY(mu_); KafkaEventCb kafka_event_cb_ = KafkaEventCb(); KafkaRebalanceCb kafka_rebalance_cb_ = KafkaRebalanceCb(); int64 
max_stream_timeout_polls_ = -1; int64 stream_timeout_polls_ = -1; int batch_num_messages_ = 1024; // std::unique_ptr reader_; // std::unique_ptr stream_reader_; data_format::DataFormat output_pb_type_; data_format::DataFormat input_pb_type_; DataFormatOptions options_; FeatureNameMapper fake_mapper_; int version_ = 1; }; class KafkaGroupReadableInitOp : public ResourceOpKernel { public: explicit KafkaGroupReadableInitOp(OpKernelConstruction* context) : ResourceOpKernel(context) { env_ = context->env(); OP_REQUIRES_OK(context, context->GetAttr("lagrangex_header", &options_.lagrangex_header)); OP_REQUIRES_OK(context, context->GetAttr("kafka_dump_prefix", &options_.kafka_dump_prefix)); OP_REQUIRES_OK(context, context->GetAttr("has_sort_id", &options_.has_sort_id)); OP_REQUIRES_OK(context, context->GetAttr("kafka_dump", &options_.kafka_dump)); std::string input_pb_type, output_pb_type; OP_REQUIRES_OK(context, context->GetAttr("input_pb_type", &input_pb_type_)); OP_REQUIRES_OK(context, context->GetAttr("output_pb_type", &output_pb_type_)); } private: void Compute(OpKernelContext* context) override { ResourceOpKernel::Compute(context); const Tensor* topics_tensor; OP_REQUIRES_OK(context, context->input("topics", &topics_tensor)); std::vector topics; for (int64 i = 0; i < topics_tensor->NumElements(); i++) { topics.push_back(topics_tensor->flat()(i)); } const Tensor* metadata_tensor; OP_REQUIRES_OK(context, context->input("metadata", &metadata_tensor)); std::vector metadata; for (int64 i = 0; i < metadata_tensor->NumElements(); i++) { metadata.push_back(metadata_tensor->flat()(i)); } OP_REQUIRES_OK(context, resource_->Init(topics, metadata, options_, input_pb_type_, output_pb_type_)); } Status CreateResource(KafkaGroupReadableResource** resource) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { *resource = new KafkaGroupReadableResource(env_); return Status::OK(); } private: mutable mutex mu_; Env* env_ TF_GUARDED_BY(mu_); DataFormatOptions options_; std::string 
output_pb_type_; std::string input_pb_type_; }; class KafkaGroupReadableNextOp : public OpKernel { public: explicit KafkaGroupReadableNextOp(OpKernelConstruction* context, int version = 1) : OpKernel(context), version_(version) { env_ = context->env(); } void Compute(OpKernelContext* context) override { KafkaGroupReadableResource* resource; OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource)); core::ScopedUnref unref(resource); const Tensor* index_tensor; OP_REQUIRES_OK(context, context->input("index", &index_tensor)); const int64 index = index_tensor->scalar()(); const Tensor* message_poll_timeout_tensor; OP_REQUIRES_OK(context, context->input("message_poll_timeout", &message_poll_timeout_tensor)); const int64 message_poll_timeout = message_poll_timeout_tensor->scalar()(); const Tensor* stream_timeout_tensor; OP_REQUIRES_OK(context, context->input("stream_timeout", &stream_timeout_tensor)); const int64 stream_timeout = stream_timeout_tensor->scalar()(); OP_REQUIRES_OK( context, resource->Next( index, message_poll_timeout, stream_timeout, [&](const TensorShape& shape, Tensor** message, Tensor** key, Tensor** continue_fetch) -> Status { TF_RETURN_IF_ERROR(context->allocate_output(0, shape, message)); if (version_ == 2) { TF_RETURN_IF_ERROR(context->allocate_output(1, TensorShape({}), continue_fetch)); } else { TF_RETURN_IF_ERROR(context->allocate_output(1, shape, key)); TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({}), continue_fetch)); } return Status::OK(); })); } private: int version_ = 1; mutable mutex mu_; Env* env_ TF_GUARDED_BY(mu_); }; class KafkaGroupReadableNextOpV2 : public KafkaGroupReadableNextOp { public: explicit KafkaGroupReadableNextOpV2(OpKernelConstruction* context) : KafkaGroupReadableNextOp(context, 2) {} }; namespace { REGISTER_KERNEL_BUILDER(Name("KafkaGroupReadableInit").Device(DEVICE_CPU), KafkaGroupReadableInitOp); REGISTER_KERNEL_BUILDER(Name("KafkaGroupReadableNext").Device(DEVICE_CPU), 
KafkaGroupReadableNextOp); REGISTER_KERNEL_BUILDER(Name("KafkaGroupReadableNextV2").Device(DEVICE_CPU), KafkaGroupReadableNextOpV2); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/label_normalization_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; // label_norm: class LabelNormalizationOp : public OpKernel { public: explicit LabelNormalizationOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("norm_methods", &norm_methods_)); OP_REQUIRES_OK(context, 
context->GetAttr("norm_values", &norm_values_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); if (norm_methods_.size() != norm_values_.size()) { LOG(FATAL) << "Invalid 'norm_methods_', and 'norm_values', the size " "should match each other.!"; } if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); bool is_instance = variant_type_ == "instance"; if (is_instance) { Instance instance; instance.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(instance); } else { Example example; example.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(example); } auto labels = GetLabel(output_tensor, is_instance); for (int i = 0; i < labels->size(); ++i) { for (int j = 0; j < norm_methods_.size(); ++j) { float label = labels->Get(i); const auto &norm_method = norm_methods_[j]; const auto &norm_value = norm_values_[j]; if (norm_method == "log") { label = std::max(label + norm_value, 0.0f); label = log(label); } else if (norm_method == "scale") { label /= norm_value; } else if (norm_method == "scale2int") { label = int32_t(label / norm_value); } else if (norm_method == "pow") { label = std::pow(label + 1, norm_value); } else if (norm_method == "repow") { if (label > 0) { label = std::pow(label, norm_value); } else { label = 0; } } else if (norm_method == "scalelog") { label /= norm_value; label = log(std::max(label + 1, 0.0f)); } else { assert(false && "illegal label norm params"); } labels->Set(i, label); } } } private: static ::google::protobuf::RepeatedField *GetLabel( Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() 
->mutable_label(); } else { return output_tensor->scalar()().get()->mutable_label(); } } std::vector norm_methods_; std::vector norm_values_; std::string variant_type_; }; namespace { REGISTER_KERNEL_BUILDER(Name("LabelNormalization").Device(DEVICE_CPU), LabelNormalizationOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/label_upper_bound_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; // label_upper_bound: class LabelUpperBoundOp : public OpKernel { public: explicit LabelUpperBoundOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK( context, context->GetAttr("label_upper_bounds", &label_upper_bounds_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); bool is_instance = variant_type_ == "instance"; if (is_instance) { Instance instance; instance.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(instance); } else { Example example; example.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(example); } auto labels = GetLabel(output_tensor, is_instance); if (labels->size() < label_upper_bounds_.size()) { LOG_EVERY_N_SEC(ERROR, 60) << absl::StrFormat( "Label size(=%ld) should be >= label_upper_bounds size(=%ld), please " 
"investigate!", labels->size(), label_upper_bounds_.size()); return; } else { for (size_t i = 0; i < label_upper_bounds_.size(); ++i) { if (labels->Get(i) > label_upper_bounds_[i]) { labels->Set(i, label_upper_bounds_[i]); } } } } private: static ::google::protobuf::RepeatedField *GetLabel( Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_label(); } else { return output_tensor->scalar()().get()->mutable_label(); } } std::vector label_upper_bounds_; std::string variant_type_; }; namespace { REGISTER_KERNEL_BUILDER(Name("LabelUpperBound").Device(DEVICE_CPU), LabelUpperBoundOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/map_id_kernels.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { namespace { template class MapIdOp : public OpKernel { public: explicit MapIdOp(OpKernelConstruction *context) : OpKernel(context) { std::vector from, to; OP_REQUIRES_OK(context, context->GetAttr("from_value", &from)); OP_REQUIRES_OK(context, context->GetAttr("to_value", &to)); OP_REQUIRES_OK(context, context->GetAttr("default_value", &default_value_)); CHECK_EQ(from.size(), to.size()); for (size_t i = 0; i < from.size(); ++i) { map_.insert({from[i], to[i]}); } } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto input_flat = input_tensor.flat(); auto output_flat = output_tensor->flat(); for (size_t i = 0; i < input_flat.size(); ++i) { const T &value = input_flat(i); auto iter = map_.find(value); if (iter == map_.end()) { output_flat(i) = default_value_; } else { output_flat(i) = iter->second; } } } private: std::unordered_map map_; T default_value_; }; REGISTER_KERNEL_BUILDER( Name("MapId").Device(DEVICE_CPU).TypeConstraint("T"), MapIdOp); REGISTER_KERNEL_BUILDER( Name("MapId").Device(DEVICE_CPU).TypeConstraint("T"), MapIdOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/merge_flow_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "absl/strings/str_cat.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/kernels/df_resource_kernel.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/resource_mgr.h" namespace tensorflow { namespace data { namespace monolith_tf { using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using Item = ::tensorflow::monolith_tf::Item; using QueueResource = ::tensorflow::monolith_tf::QueueResource; using VariantType = ::tensorflow::monolith_tf::VariantType; class MergeFlowDatasetOp : public DatasetOpKernel { public: static constexpr const char *const kDatasetType = "merge_dataset"; static constexpr const char *const kDataFlow = "data_flow"; static constexpr const char *const kMaxQueueSize = "max_queue_size"; static constexpr const char *const kVariantType = "variant_type"; explicit MergeFlowDatasetOp(OpKernelConstruction *ctx); protected: void MakeDataset(OpKernelContext *ctx, DatasetBase **output) override; private: class Dataset; std::vector data_flows_; int max_queue_size_; VariantType variant_type_; }; class MergeFlowDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext *ctx, const std::vector &inputs, const std::vector &data_flows, int max_queue_size, const VariantType &variant_type) : DatasetBase(DatasetContext(ctx)), inputs_(inputs), data_flows_(data_flows), max_queue_size_(max_queue_size), variant_type_(variant_type) { for (const auto 
input : inputs_) { input->Ref(); } } ~Dataset() override { for (const auto input : inputs_) { input->Unref(); } } std::unique_ptr MakeIteratorInternal( const string &prefix) const override { return absl::make_unique( Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)}); } const DataTypeVector &output_dtypes() const override { return inputs_[0]->output_dtypes(); } const std::vector &output_shapes() const override { return inputs_[0]->output_shapes(); } string DebugString() const override { return "This is the customized Dataset: DataFlowDataset"; } Status InputDatasets( std::vector *inputs) const override { for (const auto input : inputs_) { inputs->push_back(input); } return Status::OK(); } Status CheckExternalState() const override { for (const auto input : inputs_) { Status s = input->CheckExternalState(); if (!s.ok()) { return s; } } return Status::OK(); } void SetContainer(const std::string &container) { container_ = container; } std::string GetContainer() const { return container_; } protected: Status AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b, Node **output) const override { std::vector input_graph_nodes; input_graph_nodes.reserve(inputs_.size()); for (const auto &input : inputs_) { Node *input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &input_node)); input_graph_nodes.emplace_back(input_node); } AttrValue data_flows_node; b->BuildAttrValue(data_flows_, &data_flows_node); AttrValue max_queue_size_node; b->BuildAttrValue(max_queue_size_, &max_queue_size_node); AttrValue variant_type_node; if (variant_type_ == VariantType::PBInstance) { b->BuildAttrValue("instance", &variant_type_node); } else { b->BuildAttrValue("example", &variant_type_node); } TF_RETURN_IF_ERROR( b->AddDataset(this, // dataset {}, {std::make_pair(0, input_graph_nodes)}, // inputs {{kDataFlow, data_flows_node}, {kMaxQueueSize, max_queue_size_node}, {kVariantType, variant_type_node}}, // attrs output)); // Node** return Status::OK(); } 
private: class Iterator : public DatasetIterator { public: explicit Iterator(const Params ¶ms) : DatasetIterator(params), mu_(std::make_shared()), output_mu_(std::make_shared()) {} ~Iterator() override { CancelThreads(); if (deregister_fn_) deregister_fn_(); for (const std::string &name : dataset()->data_flows_) { auto iter = df_to_queue_.find(name); if (iter != df_to_queue_.end()) { if (iter->second != nullptr) { delete iter->second; } df_to_queue_.erase(iter); } } } void CancelThreads() TF_LOCKS_EXCLUDED(mu_) { cancellation_manager_->StartCancel(); mutex_lock l(*mu_); cancelled_ = true; } Status Initialize(IteratorContext *ctx) override { mutex_lock l(*mu_); cancellation_manager_ = absl::make_unique(); IteratorContext::Params params(ctx); params.cancellation_manager = cancellation_manager_.get(); TF_RETURN_IF_ERROR( ::tensorflow::monolith_tf::RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { CancelThreads(); }, &deregister_fn_)); Status s = Status::OK(); input_impls_.reserve(dataset()->inputs_.size()); for (const auto input : dataset()->inputs_) { std::unique_ptr input_impl; s.Update(input->MakeIterator(IteratorContext(params), this, prefix(), &input_impl)); input_impls_.push_back(input_impl.release()); } for (size_t i = 0; i < dataset()->data_flows_.size(); ++i) { std::string data_flows_name = dataset()->data_flows_[i]; QueueResource *queue = new QueueResource(dataset()->max_queue_size_); df_to_queue_.emplace(data_flows_name, queue); prefetch_thread_finished_.push_back(false); } return s; } Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) override { out_tensors->reserve(1); { mutex_lock l(*mu_); TF_RETURN_IF_ERROR(EnsureThreadStarted(ctx)); } { mutex_lock output_l(*output_mu_); do { for (size_t i = 0; i < dataset()->data_flows_.size(); ++i) { std::string name = dataset()->data_flows_[cur_]; const QueueResource *queue = df_to_queue_[name]; cur_ = (cur_ + 1) % dataset()->data_flows_.size(); if 
(queue->Empty()) { continue; } Item item = queue->Pop(); if (item.end_of_sequence) { out_tensors->clear(); *end_of_sequence = true; } else { for (const auto &tensor : item.out_tensors) { out_tensors->push_back(tensor); Instance *inst = out_tensors->at(0).scalar()().get(); inst->mutable_line_id()->set_data_source_name( absl::StrCat("data_source", inst->data_source_key())); } *end_of_sequence = item.end_of_sequence; } return Status::OK(); } bool finished = true; for (bool f : prefetch_thread_finished_) { finished = finished && f; } if (cancelled_ || finished) { out_tensors->clear(); *end_of_sequence = true; break; } } while (true); } return Status::OK(); } protected: std::shared_ptr CreateNode( IteratorContext *ctx, model::Node::Args args) const override { return model::MakeUnknownRatioNode(std::move(args)); } Status SaveInternal(SerializationContext *ctx, IteratorStateWriter *writer) override { return Status::OK(); } Status RestoreInternal(IteratorContext *ctx, IteratorStateReader *reader) override { return Status::OK(); } private: const std::shared_ptr mu_; const std::shared_ptr output_mu_; std::function deregister_fn_; std::unique_ptr cancellation_manager_; bool cancelled_ TF_GUARDED_BY(*mu_) = false; bool prefetch_thread_started_ TF_GUARDED_BY(*mu_) = false; std::vector prefetch_thread_finished_ TF_GUARDED_BY(*mu_); size_t cur_ = 0; std::vector input_impls_; std::vector prefetch_threads_; std::unordered_map df_to_queue_; Status EnsureThreadStarted(IteratorContext *ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!prefetch_thread_started_) { prefetch_thread_started_ = true; for (size_t i = 0; i < dataset()->data_flows_.size(); ++i) { std::string name = dataset()->data_flows_[i]; std::shared_ptr new_ctx = std::make_shared(*ctx); std::unique_ptr prefetch_thread_ = ctx->StartThread( name, [new_ctx, i, name, this]() { PrefetchThread(new_ctx, i, name); }); prefetch_threads_.push_back(prefetch_thread_.release()); } } return Status::OK(); } void PrefetchThread(const 
std::shared_ptr &ctx, size_t i, std::string name) { while (true) { { mutex_lock l(*mu_); if (cancelled_) { prefetch_thread_finished_[i] = true; break; } } if (!prefetch_thread_finished_[i]) { Item item; input_impls_[i]->GetNext(ctx.get(), &item.out_tensors, &item.end_of_sequence); if (item.end_of_sequence) { prefetch_thread_finished_[i] = true; break; } bool pushed = false; do { if (cancelled_ || prefetch_thread_finished_[i]) { break; } pushed = df_to_queue_[name]->TryPush(item); } while (!pushed); } } } }; const std::vector inputs_; std::vector data_flows_; int max_queue_size_; VariantType variant_type_; std::string container_; }; MergeFlowDatasetOp::MergeFlowDatasetOp(OpKernelConstruction *ctx) : DatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr(kDataFlow, &data_flows_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kMaxQueueSize, &max_queue_size_)); std::string variant_type; OP_REQUIRES_OK(ctx, ctx->GetAttr(kVariantType, &variant_type)); if (variant_type == "instance") { variant_type_ = VariantType::PBInstance; } else if (variant_type == "example") { variant_type_ = VariantType::PBExample; } else { LOG(ERROR) << "invalid variant_type: " << variant_type; ctx->SetStatus(Status(tensorflow::error::Code::INVALID_ARGUMENT, "invalid variant_type")); } } void MergeFlowDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase **output) { OpInputList iplist; OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &iplist)); std::vector inputs; for (const auto &ds : iplist) { DatasetBase *input; OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ds, &input)); inputs.push_back(input); } *output = new Dataset(ctx, inputs, data_flows_, max_queue_size_, variant_type_); std::string container; // OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "container", &container)); static_cast(*output)->SetContainer(""); } namespace { REGISTER_KERNEL_BUILDER(Name("MergeFlowDataset").Device(DEVICE_CPU), MergeFlowDatasetOp); } // namespace } // namespace monolith_tf } // namespace data } // namespace tensorflow 
================================================ FILE: monolith/native_training/data/kernels/multi_label_gen_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; using Action = google::protobuf::RepeatedField; using Label = google::protobuf::RepeatedField; class MultiLabelGenOp : public OpKernel { public: explicit MultiLabelGenOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("task_num", &task_num_)); OP_REQUIRES_OK(context, context->GetAttr("head_field", &head_field_)); OP_REQUIRES_OK(context, context->GetAttr("pos_actions", &pos_actions_)); OP_REQUIRES_OK(context, context->GetAttr("neg_actions", &neg_actions_)); OP_REQUIRES_OK(context, 
context->GetAttr("use_origin_label", &use_origin_label_)); OP_REQUIRES_OK(context, context->GetAttr("pos_label", &pos_label_)); OP_REQUIRES_OK(context, context->GetAttr("neg_label", &neg_label_)); std::string action_priority; OP_REQUIRES_OK(context, context->GetAttr("action_priority", &action_priority)); std::vector action_priority_items = absl::StrSplit(action_priority, ","); for (size_t i = 0; i < action_priority_items.size(); ++i) { int32 action; absl::SimpleAtoi(action_priority_items[i], &action); action_priority_.emplace(action, static_cast(i)); } std::string head_to_index; OP_REQUIRES_OK(context, context->GetAttr("head_to_index", &head_to_index)); for (absl::string_view split : absl::StrSplit(head_to_index, ",")) { std::pair head_and_index = absl::StrSplit(split, ":"); int index; absl::SimpleAtoi(head_and_index.second, &index); CHECK_LT(index, task_num_); head_to_index_.emplace(head_and_index.first, index); } OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } const ::google::protobuf::Descriptor *descriptor = ::idl::matrix::proto::LineId::GetDescriptor(); field = descriptor->FindFieldByName(head_field_); CHECK_EQ(field->is_repeated(), false); } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); bool is_instance = variant_type_ == "instance"; if (is_instance) { Instance instance; instance.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(instance); } else { Example example; example.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(example); } LineId *line_id = GetLineId(output_tensor, is_instance); auto label = GetLabel(output_tensor, is_instance); float 
label_value = internal::INVALID_LABEL; if (use_origin_label_) { if (!label->empty()) { label_value = label->Get(0); } else { LOG_EVERY_N_SEC(ERROR, 60) << "Invalid data: label is empty, please investigate and retry!"; } } else { int64_t action; if (FindMostPriorAction(line_id->actions(), &action)) { if (std::find(pos_actions_.begin(), pos_actions_.end(), action) != pos_actions_.end()) { label_value = pos_label_; } else if (std::find(neg_actions_.begin(), neg_actions_.end(), action) != neg_actions_.end()) { label_value = neg_label_; } } } label->Clear(); label->Resize(task_num_, internal::INVALID_LABEL); std::string head_flag = GetHeadFlag(*line_id); if (head_to_index_.count(head_flag)) { int idx = head_to_index_[head_flag]; label->Set(idx, label_value); } } private: static LineId *GetLineId(Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_line_id(); } else { return output_tensor->scalar()() .get() ->mutable_line_id(); } } static ::google::protobuf::RepeatedField *GetLabel( Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_label(); } else { return output_tensor->scalar()().get()->mutable_label(); } } bool FindMostPriorAction(const Action &actions, int64_t *action) { if (actions.size() != 0) { if (action_priority_.empty() || actions.size() == 1) { *action = actions[0]; } else { int64_t priority = std::numeric_limits::max(); for (auto &act : actions) { auto iter = action_priority_.find(act); if (iter != action_priority_.end() && iter->second < priority) { *action = iter->first; priority = iter->second; } } if (priority == std::numeric_limits::max()) *action = actions[0]; } return true; } return false; } std::string GetHeadFlag(const LineId &line_id) { std::stringstream ss; if (head_field_ == "chnid") { ss << line_id.chnid(); } else if (head_field_ == "cid") { ss << line_id.cid(); } else { switch (field->cpp_type()) { case 
::google::protobuf::FieldDescriptor::CPPTYPE_INT32: ss << reflection->GetInt32(line_id, field); case ::google::protobuf::FieldDescriptor::CPPTYPE_UINT32: ss << reflection->GetUInt32(line_id, field); case ::google::protobuf::FieldDescriptor::CPPTYPE_INT64: ss << reflection->GetInt64(line_id, field); case ::google::protobuf::FieldDescriptor::CPPTYPE_UINT64: ss << reflection->GetUInt64(line_id, field); case ::google::protobuf::FieldDescriptor::CPPTYPE_STRING: ss << reflection->GetString(line_id, field); default: ss << ""; } } return ss.str(); } int task_num_; std::string head_field_; std::map head_to_index_; std::unordered_map action_priority_; std::vector pos_actions_; std::vector neg_actions_; bool use_origin_label_; float pos_label_; float neg_label_; std::string variant_type_; const ::google::protobuf::FieldDescriptor *field; const ::google::protobuf::Reflection *reflection = ::idl::matrix::proto::LineId::GetReflection(); }; namespace { REGISTER_KERNEL_BUILDER(Name("MultiLabelGen").Device(DEVICE_CPU), MultiLabelGenOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/negative_gen_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/negative_gen_dataset_kernel.h" #include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/item_pool_kernels.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace data { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using LineId = ::idl::matrix::proto::LineId; using Action = google::protobuf::RepeatedField; using Label = google::protobuf::RepeatedField; using Example = ::monolith::io::proto::Example; using ::tensorflow::monolith_tf::FeatureNameMapper; using ::tensorflow::monolith_tf::FeatureNameMapperTfBridge; using EFeature = ::monolith::io::proto::NamedFeature; using ItemPoolResource = ::tensorflow::monolith_tf::ItemPoolResource; using ItemFeatures = ::tensorflow::monolith_tf::internal::ItemFeatures; static const int32 INVALID_NEGATIVE_ACTION = -99999; constexpr char kInputImplEmpty[] = "input_impl_empty"; static constexpr const char *const kDatasetType = "negtive_gen_dataset"; static constexpr const char *const kNegNum = 
"neg_num"; static constexpr const char *const kPerChannel = "per_channel"; static constexpr const char *const kChannelFeature = "channel_feature"; static constexpr const char *const kItemFeature = "item_features"; static constexpr const char *const kLabelIndex = "label_index"; static constexpr const char *const kPositiveLabel = "positive_label"; static constexpr const char *const kNegativeLabel = "negative_label"; static constexpr const char *const kNegativeAction = "negative_action"; static constexpr const char *const kActionPriority = "action_priority"; static constexpr const char *const kPositiveActions = "positive_actions"; static constexpr const char *const kCacheOnlyPos = "cache_only_pos"; static constexpr const char *const kCacheNegativeActions = "cache_negative_actions"; static constexpr const char *const kIndexFeature = "index_feature"; static constexpr const char *const kThrowOrigin = "throw_origin"; static constexpr const char *const kThrowOriginNeg = "throw_origin_neg"; static constexpr const char *const kRealNegInstanceWeight = "real_neg_instance_weight"; static constexpr const char *const kSampledNegInstanceWeight = "sampled_neg_instance_weight"; static constexpr const char *const kUnbiasSampledNeg = "unbias_sampled_neg"; static constexpr const char *const kOriginNegInPoolProba = "origin_neg_in_pool_proba"; static constexpr const char *const kNegSampleDeclayFactor = "neg_sample_declay_factor"; static constexpr const char *const kHardEasyRatio = "easy_hard_ratio"; static constexpr const char *const kVariantType = "variant_type"; class InnerIterator { public: InnerIterator(IteratorBase *input_impl, ItemPoolResource *resource, int32 neg_num, bool per_channel, const std::string &channel_feature, const std::vector &item_features, int32 label_index, int32 positive_label, int32 negative_label, int32 negative_action, const std::string &action_priority, const std::vector &positive_actions, const std::string &index_feature, bool throw_origin, bool 
throw_origin_neg, bool cache_only_pos, const std::vector &cache_negative_actions, float real_neg_instance_weight, float sampled_neg_instance_weight, bool unbias_sampled_neg, float origin_neg_in_pool_proba, float neg_sample_declay_factor, float easy_hard_ratio, const std::string &variant_type) : resource_(resource), index_(0), need_new_ins_(true), input_real_negative_instance_num_(0), input_instance_num_(0), output_instance_num_(0), generate_instance_num_(0), hard_sample_num_(0), easy_sample_num_(0), neg_num_(neg_num), per_channel_(per_channel), channel_feature_(channel_feature), item_features_(item_features.begin(), item_features.end()), label_index_(label_index), positive_label_(positive_label), negative_label_(negative_label), negative_action_(negative_action), positive_actions_(positive_actions.begin(), positive_actions.end()), index_feature_(index_feature), throw_origin_(throw_origin), throw_origin_neg_(throw_origin_neg), cache_only_pos_(cache_only_pos), cache_negative_actions_(cache_negative_actions.begin(), cache_negative_actions.end()), real_neg_instance_weight_(real_neg_instance_weight), sampled_neg_instance_weight_(sampled_neg_instance_weight), unbias_sampled_neg_(unbias_sampled_neg), origin_neg_in_pool_proba_(origin_neg_in_pool_proba), neg_sample_declay_factor_(neg_sample_declay_factor), easy_hard_ratio_(easy_hard_ratio) { input_impl_ = input_impl; tensors_ = new std::vector(); tensors_->reserve(1); std::vector action_priority_items = absl::StrSplit(action_priority, ","); for (size_t i = 0; i < action_priority_items.size(); ++i) { int32 action; if (action_priority_items[i].empty()) { continue; } CHECK(absl::SimpleAtoi(action_priority_items[i], &action)); action_priority_.insert({action, static_cast(i)}); } if (variant_type == "instance") { variant_type_ = VariantType::PBInstance; if (index_feature_.empty()) { has_index_feature_ = false; index_slot_ = 0; } else { has_index_feature_ = true; CHECK(absl::SimpleAtoi(index_feature_, &index_slot_)); } if 
(channel_feature_.empty()) { channel_slot_ = 3; } else { CHECK(absl::SimpleAtoi(channel_feature_, &channel_slot_)); } for (const auto &fname : item_features_) { int32 slot; CHECK(absl::SimpleAtoi(fname, &slot)); item_slots_.insert(slot); } } else { variant_type_ = VariantType::PBExample; index_slot_ = 0; has_index_feature_ = !index_feature_.empty(); channel_slot_ = 3; } } ~InnerIterator() { delete tensors_; } Status GetNext(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) { if (end_of_sequence_) { *end_of_sequence = end_of_sequence_; out_tensors->clear(); } while (!end_of_sequence_) { Status s = MaybeGetNextRealInstance(ctx); if (!s.ok()) { return s; } if (end_of_sequence_) { *end_of_sequence = end_of_sequence_; out_tensors->clear(); break; } bool is_positive = IsPositive(); if (index_ == 0 && Cacheable(is_positive)) { SaveToCache(is_positive); } if (index_ == 0 && !is_positive) { input_real_negative_instance_num_++; } if (is_positive && index_ < neg_num_) { Tensor tensor; if (BuildNegativeTensor(ctx, &tensor)) { *end_of_sequence = end_of_sequence_; out_tensors->push_back(std::move(tensor)); index_++; generate_instance_num_++; output_instance_num_++; break; } } need_new_ins_ = true; if (Emitable(is_positive)) { *end_of_sequence = end_of_sequence_; if (is_positive) { SetInstanceWeight(&tensors_->back(), 1.0); } else { float instance_weight = real_neg_instance_weight_ > 0.00001 ? 
real_neg_instance_weight_ : 1.0; SetInstanceWeight(&tensors_->back(), instance_weight); } out_tensors->push_back(tensors_->back()); output_instance_num_++; break; } } LOG_EVERY_N_SEC(INFO, 180) << "input_instance_num: " << input_instance_num_; LOG_EVERY_N_SEC(INFO, 180) << "input_real_negative_instance_num: " << input_real_negative_instance_num_; LOG_EVERY_N_SEC(INFO, 180) << "output_instance_num: " << output_instance_num_; LOG_EVERY_N_SEC(INFO, 180) << "generate_instance_num: " << generate_instance_num_; LOG_EVERY_N_SEC(INFO, 180) << "hard_sample_num: " << hard_sample_num_; LOG_EVERY_N_SEC(INFO, 180) << "easy_sample_num: " << easy_sample_num_; return Status::OK(); } private: Status MaybeGetNextRealInstance(IteratorContext *ctx) { if (need_new_ins_ && !end_of_sequence_) { tensors_->clear(); TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, tensors_, &end_of_sequence_)); if (end_of_sequence_) { input_impl_ = nullptr; return Status::OK(); } need_new_ins_ = false; gcids_ = GetGidAndChannelId(); ++input_instance_num_; index_ = 0; // Got one new instance, reset the neg count } return Status::OK(); } template inline T *GetCurrent() { Variant *variant = &tensors_->back().scalar()(); return variant->get(); } inline const LineId *GetLineId() { if (variant_type_ == VariantType::PBInstance) { Instance *instance = GetCurrent(); return &instance->line_id(); } else if (variant_type_ == VariantType::PBExample) { Example *example = GetCurrent(); return &example->line_id(); } else { return nullptr; } } inline const Label *GetLabel() { if (variant_type_ == VariantType::PBInstance) { Instance *instance = GetCurrent(); return &instance->label(); } else if (variant_type_ == VariantType::PBExample) { Example *example = GetCurrent(); return &example->label(); } else { return nullptr; } } bool IsPositive() { bool is_pos = false; const LineId *line_id = GetLineId(); const Label *label = GetLabel(); if (!positive_actions_.empty() && line_id != nullptr) { if (!line_id->actions().empty()) { 
int64_t action; FindMostPriorAction(line_id->actions(), &action); auto iter = positive_actions_.find(action); is_pos = iter != positive_actions_.end(); } } else if (label != nullptr) { if (label_index_ < label->size()) { is_pos = label->at(label_index_) == positive_label_; } else { LOG_EVERY_N_SEC(ERROR, 60) << absl::StrFormat( "label_index_ should be less than label_size, while got %d vs %d", label_index_, label->size()); } } return is_pos; } inline bool Cacheable(bool is_positive) { if (!cache_only_pos_ || is_positive) { return true; } else if (cache_negative_actions_.empty()) { return false; } else { const LineId *line_id = GetLineId(); for (auto &action : line_id->actions()) { if (cache_negative_actions_.count(action)) { return true; } } return false; } } inline bool Emitable(bool is_positive) { return (!throw_origin_ && (!throw_origin_neg_ || is_positive)); } std::pair GetGidAndChannelId() { uint64_t gid = 0, cid = 3; if (variant_type_ == VariantType::PBInstance) { const Instance *instance = GetCurrent(); gid = instance->line_id().item_id(); if (per_channel_ || has_index_feature_) { for (const auto fid : instance->fid()) { int32 slot = slot_id_v1(fid); if (per_channel_ && channel_slot_ == slot) { cid = fid; } if (has_index_feature_ && slot == index_slot_) { gid = fid; } } } } else if (variant_type_ == VariantType::PBExample) { const Example *example = GetCurrent(); gid = example->line_id().item_id(); if (per_channel_ || has_index_feature_) { for (const auto &named_feature : example->named_feature()) { const std::string &feature_name = named_feature.name(); auto &feature_value = named_feature.feature(); if (per_channel_ && channel_feature_ == feature_name) { if (feature_value.type_case() == ::monolith::io::proto::Feature::kFidV1List && feature_value.fid_v1_list().value_size() > 0) { cid = feature_value.fid_v1_list().value(0); LOG_EVERY_N_SEC(INFO, 180) << "Use Fidv1."; } else if (feature_value.type_case() == ::monolith::io::proto::Feature::kFidV2List && 
feature_value.fid_v2_list().value_size() > 0) { cid = feature_value.fid_v2_list().value(0); LOG_EVERY_N_SEC(INFO, 180) << "Use Fidv2."; } else { LOG_EVERY_N_SEC(INFO, 180) << "Use Default cid."; } } if (has_index_feature_ && index_feature_ == feature_name) { if (feature_value.type_case() == ::monolith::io::proto::Feature::kFidV1List && feature_value.fid_v1_list().value_size() > 0) { gid = feature_value.fid_v1_list().value(0); } else if (feature_value.type_case() == ::monolith::io::proto::Feature::kFidV2List && feature_value.fid_v2_list().value_size() > 0) { gid = feature_value.fid_v2_list().value(0); } } } } } return std::pair(gid, cid); } void SetInstanceWeight(Tensor *tensor, float instance_weight) { if (variant_type_ == VariantType::PBInstance) { auto *instance = tensor->scalar()().get(); instance->set_instance_weight(instance_weight); } else { auto *example = tensor->scalar()().get(); example->set_instance_weight(instance_weight); } } void SaveToCache(bool is_positive) { std::shared_ptr item_features = std::make_shared(); uint64_t item_id = gcids_.first; uint64_t channel_id = gcids_.second; if (channel_id == 0) { return; } if (!is_positive && origin_neg_in_pool_proba_ >= 0 && origin_neg_in_pool_proba_ < 1) { float proba = (std::rand() % 100) / 100.0; if (proba > origin_neg_in_pool_proba_) { return; } } if (variant_type_ == VariantType::PBExample) { Example *example = GetCurrent(); if (is_positive) { named_feature_list_.Clear(); } for (auto &named_feature : example->named_feature()) { const std::string &feature_name = named_feature.name(); if (item_features_.count(feature_name) != 0) { item_features->example_features[feature_name] = named_feature; } else if (is_positive) { named_feature_list_.Add(named_feature); } } } else { Instance *instance = GetCurrent(); if (is_positive) { fid_list_.Clear(); } for (auto fid : instance->fid()) { int32 slot = slot_id_v1(fid); if (item_slots_.count(slot) != 0) { // only cache group slots item_features->fids.emplace_back(fid); 
} else if (is_positive) { fid_list_.Add(fid); } } } item_features->item_id = item_id; resource_->Add(channel_id, item_id, item_features); } template void SetLabelAndLineId(T *neg, uint64_t item_id) { if (label_index_ < neg->label_size()) { neg->set_label(label_index_, negative_label_); } else { LOG_EVERY_N_SEC(ERROR, 60) << absl::StrFormat( "label_index_ should be less than label_size, while got %d vs %d", label_index_, neg->label_size()); } neg->mutable_line_id()->set_item_id(item_id); if (negative_action_ != INVALID_NEGATIVE_ACTION) { neg->mutable_line_id()->clear_actions(); neg->mutable_line_id()->add_actions(negative_action_); } } bool BuildNegativeTensor(IteratorContext *ctx, Tensor *res) { // hard_easy neg when per_channel enabled uint64_t channel_id = gcids_.second; if (per_channel_ && NeedEasyNeg(easy_hard_ratio_)) { resource_->SampleChannelID(&channel_id); easy_sample_num_++; } else { hard_sample_num_++; } if (channel_id == 0) { return false; } double freq_factor, time_factor; std::shared_ptr cached_item = resource_->Sample(channel_id, &freq_factor, &time_factor); if (!cached_item) { return false; } uint64_t item_id = cached_item->item_id; Tensor tensor(ctx->allocator({}), DT_VARIANT, TensorShape({})); float instance_weight; if (sampled_neg_instance_weight_ > 0.00001) { instance_weight = sampled_neg_instance_weight_; } else if (unbias_sampled_neg_) { instance_weight = 1.0 + neg_num_ * std::pow(time_factor, neg_sample_declay_factor_) * freq_factor; } else { instance_weight = 1.0; } if (variant_type_ == VariantType::PBExample) { Example *example = GetCurrent(); Example new_example; new_example.mutable_line_id()->CopyFrom(example->line_id()); for (const auto &label : example->label()) { new_example.add_label(label); } const auto &cached_example_features = cached_item->example_features; auto *mutable_named_feature = new_example.mutable_named_feature(); for (const auto &nf : named_feature_list_) { mutable_named_feature->Add()->CopyFrom(nf); } for (const auto 
&nf : cached_example_features) { mutable_named_feature->Add()->CopyFrom(nf.second); } SetLabelAndLineId(&new_example, item_id); new_example.set_instance_weight(instance_weight); tensor.scalar()() = std::move(new_example); } else { Instance *instance = GetCurrent(); Instance new_instance; new_instance.CopyFrom(*instance); const auto &cached_fid_list = cached_item->fids; google::protobuf::RepeatedField<::google::protobuf::uint64> fid_list = fid_list_; // copy for (auto fid : cached_fid_list) { fid_list.Add(fid); } new_instance.mutable_fid()->Swap(&fid_list); SetLabelAndLineId(&new_instance, item_id); new_instance.set_instance_weight(instance_weight); tensor.scalar()() = std::move(new_instance); } *res = std::move(tensor); return true; } bool FindMostPriorAction(const Action &actions, int64_t *action) { if (actions.size() != 0) { if (action_priority_.empty() || actions.size() == 1) { *action = actions[0]; } else { int64_t priority = std::numeric_limits::max(); for (auto &act : actions) { auto iter = action_priority_.find(act); if (iter != action_priority_.end() && iter->second < priority) { *action = iter->first; priority = iter->second; } } if (priority == std::numeric_limits::max()) *action = actions[0]; } return true; } return false; } bool NeedEasyNeg(float easy_hard_ratio) { return static_cast(std::rand()) / RAND_MAX < easy_hard_ratio; } ItemPoolResource *resource_ = nullptr; bool end_of_sequence_ = false; std::vector *tensors_ = nullptr; IteratorBase *input_impl_ = nullptr; int index_ = 0; bool need_new_ins_ = true; // stats variables int64 input_real_negative_instance_num_ = 0; int64 input_instance_num_ = 0; int64 output_instance_num_ = 0; int64 generate_instance_num_ = 0; // hard & easy stats int64 hard_sample_num_ = 0; int64 easy_sample_num_ = 0; int32 neg_num_; bool per_channel_; std::string channel_feature_; int32 channel_slot_; std::unordered_set item_features_; std::unordered_set item_slots_; int32 label_index_; int32 positive_label_; int32 
negative_label_; int32 negative_action_; std::unordered_set positive_actions_; std::unordered_map action_priority_; std::string index_feature_; int32 index_slot_; bool has_index_feature_; bool throw_origin_; bool throw_origin_neg_; bool cache_only_pos_; std::unordered_set cache_negative_actions_; float real_neg_instance_weight_; float sampled_neg_instance_weight_; bool unbias_sampled_neg_; float origin_neg_in_pool_proba_; float neg_sample_declay_factor_; float easy_hard_ratio_; VariantType variant_type_; std::pair gcids_; google::protobuf::RepeatedField<::google::protobuf::uint64> fid_list_; google::protobuf::RepeatedField<::monolith::io::proto::NamedFeature> named_feature_list_; }; class InstanceNegativeGenDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext *ctx, const DatasetBase *input, int32 neg_num, bool per_channel, const std::string &channel_feature, const std::vector &item_features, int32 label_index, int32 positive_label, int32 negative_label, int32 negative_action, std::string action_priority, const std::vector &positive_actions, const std::string &index_feature, bool throw_origin, bool throw_origin_neg, bool cache_only_pos, const std::vector &cache_negative_actions, float real_neg_instance_weight, float sampled_neg_instance_weight, bool unbias_sampled_neg, float origin_neg_in_pool_proba, float neg_sample_declay_factor, float easy_hard_ratio, const std::string &variant_type, FeatureNameMapper *mapper) : DatasetBase(DatasetContext(ctx)), input_(input), neg_num_(neg_num), per_channel_(per_channel), channel_feature_(channel_feature), item_features_(item_features), label_index_(label_index), positive_label_(positive_label), negative_label_(negative_label), negative_action_(negative_action), action_priority_(action_priority), positive_actions_(positive_actions), index_feature_(index_feature), throw_origin_(throw_origin), throw_origin_neg_(throw_origin_neg), cache_only_pos_(cache_only_pos), cache_negative_actions_(cache_negative_actions), 
real_neg_instance_weight_(real_neg_instance_weight), sampled_neg_instance_weight_(sampled_neg_instance_weight), unbias_sampled_neg_(unbias_sampled_neg), origin_neg_in_pool_proba_(origin_neg_in_pool_proba), neg_sample_declay_factor_(neg_sample_declay_factor), easy_hard_ratio_(easy_hard_ratio), variant_type_(variant_type), mapper_(mapper) { input_->Ref(); const Tensor *pool_tensor_; OP_REQUIRES_OK(ctx, ctx->input("pool", &pool_tensor_)); handle_ = pool_tensor_->scalar()(); OP_REQUIRES_OK(ctx, LookupResource(ctx, handle_, &resource_)); if (variant_type_ == "example") { std::vector valid_feature_names = item_features_; if (!channel_feature_.empty()) { valid_feature_names.push_back(channel_feature_); } if (!index_feature_.empty()) { valid_feature_names.push_back(index_feature_); } mapper_->RegisterValidNames(valid_feature_names); } } ~Dataset() override { input_->Unref(); core::ScopedUnref unref(resource_); } std::unique_ptr MakeIteratorInternal( const string &prefix) const override { return absl::make_unique( Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)}); } const DataTypeVector &output_dtypes() const override { return input_->output_dtypes(); } const std::vector &output_shapes() const override { return input_->output_shapes(); } string DebugString() const override { return "This is the customized Dataset: NegativeGenV2"; } Status InputDatasets( std::vector *inputs) const override { inputs->push_back(input_); return Status::OK(); } Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: Status AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b, Node **output) const override { Node *input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Tensor handle(DT_RESOURCE, TensorShape({})); handle.scalar()() = handle_; Node *pool_node; TF_RETURN_IF_ERROR(b->AddTensor(handle, &pool_node)); AttrValue neg_num_node; b->BuildAttrValue(neg_num_, &neg_num_node); AttrValue 
per_channel_node; b->BuildAttrValue(per_channel_, &per_channel_node); AttrValue channel_feature_node; b->BuildAttrValue(channel_feature_, &channel_feature_node); AttrValue item_features_node; b->BuildAttrValue(item_features_, &item_features_node); AttrValue label_index_node; b->BuildAttrValue(label_index_, &label_index_node); AttrValue positive_label_node; b->BuildAttrValue(positive_label_, &positive_label_node); AttrValue negative_label_node; b->BuildAttrValue(negative_label_, &negative_label_node); AttrValue negative_action_node; b->BuildAttrValue(negative_action_, &negative_action_node); AttrValue action_priority_node; b->BuildAttrValue(action_priority_, &action_priority_node); AttrValue positive_actions_node; b->BuildAttrValue(positive_actions_, &positive_actions_node); AttrValue index_feature_node; b->BuildAttrValue(index_feature_, &index_feature_node); AttrValue throw_origin_node; b->BuildAttrValue(throw_origin_, &throw_origin_node); AttrValue throw_origin_neg_node; b->BuildAttrValue(throw_origin_neg_, &throw_origin_neg_node); AttrValue cache_only_pos_node; b->BuildAttrValue(cache_only_pos_, &cache_only_pos_node); AttrValue cache_negative_actions_node; b->BuildAttrValue(cache_negative_actions_, &cache_negative_actions_node); AttrValue real_neg_instance_weight_node; b->BuildAttrValue(real_neg_instance_weight_, &real_neg_instance_weight_node); AttrValue sampled_neg_instance_weight_node; b->BuildAttrValue(sampled_neg_instance_weight_, &sampled_neg_instance_weight_node); AttrValue unbias_sampled_neg_node; b->BuildAttrValue(unbias_sampled_neg_, &unbias_sampled_neg_node); AttrValue origin_neg_in_pool_proba_node; b->BuildAttrValue(origin_neg_in_pool_proba_, &origin_neg_in_pool_proba_node); AttrValue neg_sample_declay_factor_node; b->BuildAttrValue(neg_sample_declay_factor_, &neg_sample_declay_factor_node); AttrValue easy_hard_ratio_node; b->BuildAttrValue(easy_hard_ratio_, &easy_hard_ratio_node); AttrValue variant_type_node; b->BuildAttrValue(variant_type_, 
&variant_type_node); TF_RETURN_IF_ERROR(b->AddDataset( this, // dataset {input_graph_node, pool_node}, // inputs {{kNegNum, neg_num_node}, {kPerChannel, per_channel_node}, {kChannelFeature, channel_feature_node}, {kItemFeature, item_features_node}, {kLabelIndex, label_index_node}, {kPositiveLabel, positive_label_node}, {kNegativeLabel, negative_label_node}, {kNegativeAction, negative_action_node}, {kActionPriority, action_priority_node}, {kPositiveActions, positive_actions_node}, {kIndexFeature, index_feature_node}, {kThrowOrigin, throw_origin_node}, {kThrowOriginNeg, throw_origin_neg_node}, {kCacheOnlyPos, cache_only_pos_node}, {kCacheNegativeActions, cache_negative_actions_node}, {kRealNegInstanceWeight, real_neg_instance_weight_node}, {kSampledNegInstanceWeight, sampled_neg_instance_weight_node}, {kUnbiasSampledNeg, unbias_sampled_neg_node}, {kOriginNegInPoolProba, origin_neg_in_pool_proba_node}, {kNegSampleDeclayFactor, neg_sample_declay_factor_node}, {kHardEasyRatio, easy_hard_ratio_node}, {kVariantType, variant_type_node}}, output)); // Node** return Status::OK(); } private: class Iterator : public DatasetIterator { public: explicit Iterator(const Params ¶ms) : DatasetIterator(params) {} ~Iterator() override { mutex_lock l(mu_); if (input_impl_ != nullptr) { input_impl_.reset(); } } Status Initialize(IteratorContext *ctx) override { Status s = dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); iter_ = std::make_unique( input_impl_.get(), dataset()->resource_, dataset()->neg_num_, dataset()->per_channel_, dataset()->channel_feature_, dataset()->item_features_, dataset()->label_index_, dataset()->positive_label_, dataset()->negative_label_, dataset()->negative_action_, dataset()->action_priority_, dataset()->positive_actions_, dataset()->index_feature_, dataset()->throw_origin_, dataset()->throw_origin_neg_, dataset()->cache_only_pos_, dataset()->cache_negative_actions_, dataset()->real_neg_instance_weight_, 
dataset()->sampled_neg_instance_weight_, dataset()->unbias_sampled_neg_, dataset()->origin_neg_in_pool_proba_, dataset()->neg_sample_declay_factor_, dataset()->easy_hard_ratio_, dataset()->variant_type_); return s; } Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) override { mutex_lock l(mu_); out_tensors->reserve(1); TF_RETURN_IF_ERROR(iter_->GetNext(ctx, out_tensors, end_of_sequence)); if (*end_of_sequence) { input_impl_.reset(); } return Status::OK(); } protected: std::shared_ptr CreateNode( IteratorContext *ctx, model::Node::Args args) const override { return model::MakeUnknownRatioNode(std::move(args)); } Status SaveInternal(SerializationContext *ctx, IteratorStateWriter *writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), "")); } else { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); } return Status::OK(); } Status RestoreInternal(IteratorContext *ctx, IteratorStateReader *reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name(kInputImplEmpty))) { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } else { input_impl_.reset(); } return Status::OK(); } private: mutex mu_; std::unique_ptr input_impl_ TF_GUARDED_BY(mu_); std::unique_ptr iter_; }; const DatasetBase *const input_; int32 neg_num_; bool per_channel_; std::string channel_feature_; std::vector item_features_; int32 label_index_; int32 positive_label_; int32 negative_label_; int32 negative_action_; std::string action_priority_; std::vector positive_actions_; std::string index_feature_; bool throw_origin_; bool throw_origin_neg_; bool cache_only_pos_; std::vector cache_negative_actions_; float real_neg_instance_weight_; float sampled_neg_instance_weight_; bool unbias_sampled_neg_; float origin_neg_in_pool_proba_; float neg_sample_declay_factor_; float easy_hard_ratio_; std::string variant_type_; ResourceHandle handle_; ItemPoolResource *resource_; 
FeatureNameMapper *mapper_ = nullptr; }; InstanceNegativeGenDatasetOp::InstanceNegativeGenDatasetOp( OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr(kNegNum, &neg_num_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kPerChannel, &per_channel_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kChannelFeature, &channel_feature_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kItemFeature, &item_features_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kLabelIndex, &label_index_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kPositiveLabel, &positive_label_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kNegativeLabel, &negative_label_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kNegativeAction, &negative_action_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kPositiveActions, &positive_actions_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kActionPriority, &action_priority_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kIndexFeature, &index_feature_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kThrowOrigin, &throw_origin_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kThrowOriginNeg, &throw_origin_neg_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kCacheOnlyPos, &cache_only_pos_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kCacheNegativeActions, &cache_negative_actions_)); OP_REQUIRES_OK( ctx, ctx->GetAttr(kRealNegInstanceWeight, &real_neg_instance_weight_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kSampledNegInstanceWeight, &sampled_neg_instance_weight_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kUnbiasSampledNeg, &unbias_sampled_neg_)); OP_REQUIRES_OK( ctx, ctx->GetAttr(kOriginNegInPoolProba, &origin_neg_in_pool_proba_)); OP_REQUIRES_OK( ctx, ctx->GetAttr(kNegSampleDeclayFactor, &neg_sample_declay_factor_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kHardEasyRatio, &easy_hard_ratio_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kVariantType, &variant_type_)); auto creator = [this](FeatureNameMapperTfBridge **out_mapper) { TF_RETURN_IF_ERROR(FeatureNameMapperTfBridge::New(out_mapper)); return Status::OK(); }; ResourceMgr *resource_mgr = ctx->resource_manager(); OP_REQUIRES_OK(ctx, 
resource_mgr->LookupOrCreate( resource_mgr->default_container(), FeatureNameMapperTfBridge::kName, &mapper_, creator)); } void InstanceNegativeGenDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output) { *output = new Dataset( ctx, input, neg_num_, per_channel_, channel_feature_, item_features_, label_index_, positive_label_, negative_label_, negative_action_, action_priority_, positive_actions_, index_feature_, throw_origin_, throw_origin_neg_, cache_only_pos_, cache_negative_actions_, real_neg_instance_weight_, sampled_neg_instance_weight_, unbias_sampled_neg_, origin_neg_in_pool_proba_, neg_sample_declay_factor_, easy_hard_ratio_, variant_type_, mapper_->GetFeatureNameMapper()); } namespace { REGISTER_KERNEL_BUILDER(Name("InstanceNegativeGenDataset").Device(DEVICE_CPU), InstanceNegativeGenDatasetOp); } // namespace } // namespace monolith_tf } // namespace data } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/negative_gen_dataset_kernel.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_NEGATIVE_GEN_DATASET_KERNEL_H_ #define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_NEGATIVE_GEN_DATASET_KERNEL_H_ #include "monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h" #include "tensorflow/core/framework/dataset.h" namespace tensorflow { namespace data { namespace monolith_tf { enum class VariantType { PBInstance, PBExample }; class InstanceNegativeGenDatasetOp : public UnaryDatasetOpKernel { public: explicit InstanceNegativeGenDatasetOp(OpKernelConstruction* ctx); ~InstanceNegativeGenDatasetOp() override { mapper_->Unref(); } protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override; private: class Dataset; int32 neg_num_; bool per_channel_; std::string channel_feature_; std::vector item_features_; int32 label_index_; int32 positive_label_; int32 negative_label_; int32 negative_action_; std::string action_priority_; std::vector positive_actions_; std::string index_feature_; bool throw_origin_; bool throw_origin_neg_; bool cache_only_pos_; std::vector cache_negative_actions_; float real_neg_instance_weight_ = 1.0; float sampled_neg_instance_weight_ = -1; bool unbias_sampled_neg_; float origin_neg_in_pool_proba_; float neg_sample_declay_factor_; float easy_hard_ratio_; std::string variant_type_; tensorflow::monolith_tf::FeatureNameMapperTfBridge* mapper_ = nullptr; }; } // namespace monolith_tf } // namespace data } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_DATA_KERNELS_NEGATIVE_GEN_DATASET_KERNEL_H_ ================================================ FILE: monolith/native_training/data/kernels/parquet_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/internal/parquet_example_reader.h" #include "parquet/api/reader.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "third_party/nlohmann/json.hpp" namespace tensorflow { namespace data { namespace monolith_tf { using monolith::io::proto::Example; using monolith::io::proto::NamedFeature; using monolith::io::proto::NamedFeatureList; class ParquetDatasetOp : public DatasetOpKernel { public: static const char* const kDatasetType; static const char* const kFileName; static const char* const kOutputPbType; static const char* const kBatchSize; static const char* const kDropRemainder; static const char* const kSelectColumns; static const char* const kSelectColumnsType; explicit ParquetDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { // select_columns OP_REQUIRES_OK(ctx, ctx->GetAttr(kBatchSize, &batch_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kDropRemainder, &drop_remainder_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kSelectColumns, &select_columns_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kSelectColumnsType, &select_columns_type_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { tstring file_name; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kFileName, &file_name)); tstring output_pb_type; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, kOutputPbType, &output_pb_type)); *output = new 
Dataset(ctx, file_name, output_pb_type, batch_size_, drop_remainder_, select_columns_, select_columns_type_); // config log nlohmann::json j; j[kFileName] = file_name; j[kOutputPbType] = output_pb_type; j[kBatchSize] = batch_size_; j[kSelectColumns] = select_columns_.size(); j[kSelectColumnsType] = select_columns_type_.size(); LOG(INFO) << j.dump(); } private: class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, tstring file_name, tstring output_pb_type, int32_t batch_size, bool drop_remainder, std::vector select_columns, std::vector select_columns_type) : DatasetBase(DatasetContext(ctx)), file_name_(std::move(file_name)), output_pb_type_(std::move(output_pb_type)), batch_size_(batch_size), drop_remainder_(drop_remainder), select_columns_(std::move(select_columns)), select_columns_type_(std::move(select_columns_type)) {} std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return absl::make_unique( Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)}); } const DataTypeVector& output_dtypes() const override { static DataTypeVector* dtypes = nullptr; if (!dtypes) { if (output_pb_type_ == "example" || output_pb_type_ == "examplebatch") { dtypes = new DataTypeVector({DT_VARIANT}); } else { dtypes = new DataTypeVector({DT_STRING}); } } return *dtypes; } const std::vector& output_shapes() const override { static auto* shapes = new std::vector{TensorShape({})}; return *shapes; } string DebugString() const override { return "ParquetDatasetOp::Dataset"; } Status CheckExternalState() const override { return Status::OK(); } protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* file_name = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(file_name_, &file_name)); Node* output_pb_type = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(output_pb_type_, &output_pb_type)); AttrValue batch_size; b->BuildAttrValue(batch_size_, &batch_size); AttrValue 
drop_remainder; b->BuildAttrValue(drop_remainder_, &drop_remainder); AttrValue select_columns; b->BuildAttrValue(select_columns_, &select_columns); AttrValue select_columns_type; b->BuildAttrValue(select_columns_type_, &select_columns_type); TF_RETURN_IF_ERROR( b->AddDataset(this, {file_name, output_pb_type}, {{kBatchSize, batch_size}, {kDropRemainder, drop_remainder}, {kSelectColumns, select_columns}, {kSelectColumnsType, select_columns_type}}, output)); return Status::OK(); } private: class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) : DatasetIterator(params) {} Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); out_tensors->clear(); out_tensors->reserve(1); if (!parquet_reader_) { parquet_reader_.reset( new tensorflow::data::ParquetExampleReader(ctx->env())); std::vector select_col_str(dataset()->select_columns_.begin(), dataset()->select_columns_.end()); std::vector select_col_type_str( dataset()->select_columns_type_.begin(), dataset()->select_columns_type_.end()); TF_RETURN_IF_ERROR(parquet_reader_->Init( dataset()->file_name_, select_col_str, select_col_type_str)); } if (dataset()->output_pb_type_ == "example") { Example example; TF_RETURN_IF_ERROR(GetNextExample(example, end_of_sequence)); Tensor record_tensor(ctx->allocator({}), DT_VARIANT, {}); record_tensor.scalar()() = std::move(example); out_tensors->emplace_back(std::move(record_tensor)); if (*end_of_sequence) { LOG(INFO) << "end_of_sequence of " << dataset()->file_name_; } else { counter_++; LOG_EVERY_N_SEC(INFO, 60) << "consume " << counter_ << " examples from " << dataset()->file_name_; } } else if (dataset()->output_pb_type_ == "examplebatch") { ExampleBatch example_batch; TF_RETURN_IF_ERROR( GetNextExampleBatch(example_batch, end_of_sequence)); if (!(*end_of_sequence)) { if (dataset()->drop_remainder_ && example_batch.batch_size() < dataset()->batch_size_) { LOG(INFO) << "last example 
batch size=" << example_batch.batch_size() << " dropped"; *end_of_sequence = true; } else { Tensor record_tensor(ctx->allocator({}), DT_VARIANT, {}); record_tensor.scalar()() = std::move(example_batch); out_tensors->emplace_back(std::move(record_tensor)); counter_++; if (counter_ % 100 == 0) { LOG(INFO) << "consume " << counter_ << "example_batch from " << dataset()->file_name_; } } } } else if (dataset()->output_pb_type_ == "plaintext") { // only for debug use, generate examplebatch pb string ExampleBatch example_batch; TF_RETURN_IF_ERROR( GetNextExampleBatch(example_batch, end_of_sequence)); Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); std::string out; example_batch.SerializeToString(&out); record_tensor.scalar()() = out; out_tensors->emplace_back(std::move(record_tensor)); if (*end_of_sequence) { LOG(INFO) << "end_of_sequence of " << dataset()->file_name_; } } else { return errors::InvalidArgument( "output_pb_type is ", dataset()->output_pb_type_, ",should be example or examplebatch or plaintext"); } return Status::OK(); } Status GetNextExample(Example& example, bool* end_of_sequence) { if (parquet_reader_->IsEOF()) { *end_of_sequence = true; } else { *end_of_sequence = false; example.Clear(); parquet_reader_->GetNextExample(example); } return Status::OK(); } Status GetNextExampleBatch(ExampleBatch& example_batch, bool* end_of_sequence) { profiler::TraceMe activity( []() { return "ParquetDatasetOp::GetNextExampleBatch"; }); if (parquet_reader_->IsEOF()) { *end_of_sequence = true; } else { *end_of_sequence = false; example_batch.Clear(); parquet_reader_->GetNextExampleBatch(example_batch, dataset()->batch_size_); } return Status::OK(); } NamedFeatureList* AddNamedFeatureList(ExampleBatch& example_batch, const std::string& name, int32_t id) { NamedFeatureList* named_feature_list = example_batch.add_named_feature_list(); named_feature_list->set_id(id); named_feature_list->set_name(name); return named_feature_list; } protected: Status 
SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { mutex_lock l(mu_); // do nothing LOG(INFO) << "Save function is not supported yet."; return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); // do nothing LOG(INFO) << "Restore function is not supported yet."; return Status::OK(); } private: mutex mu_; int64_t counter_ = 0; std::unique_ptr parquet_reader_; }; // original inputs/attrs tstring file_name_; tstring output_pb_type_; int32_t batch_size_; bool drop_remainder_; std::vector select_columns_; std::vector select_columns_type_; }; Dataset* output_ = nullptr; int32_t batch_size_; bool drop_remainder_; std::vector select_columns_; std::vector select_columns_type_; }; const char* const ParquetDatasetOp::kDatasetType = "ParquetDataset"; const char* const ParquetDatasetOp::kFileName = "file_name"; const char* const ParquetDatasetOp::kOutputPbType = "output_pb_type"; const char* const ParquetDatasetOp::kBatchSize = "batch_size"; const char* const ParquetDatasetOp::kSelectColumns = "select_columns"; const char* const ParquetDatasetOp::kSelectColumnsType = "select_columns_type"; const char* const ParquetDatasetOp::kDropRemainder = "drop_remainder"; namespace { REGISTER_KERNEL_BUILDER(Name("ParquetDataset").Device(DEVICE_CPU), ParquetDatasetOp); } } // namespace monolith_tf } // namespace data } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/parse_example_lib.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/kernels/parse_example_lib.h" #include #include #include #include "absl/strings/match.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { namespace monolith_tf { using Instance = ::parser::proto::Instance; using LineId = ::idl::matrix::proto::LineId; using EFeature = ::monolith::io::proto::Feature; using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using FeatureListType = ::monolith::io::proto::FeatureListType; using FieldDescriptor = ::google::protobuf::FieldDescriptor; using FeatureConfigs = ::monolith::io::proto::FeatureConfigs; BaseParser::BaseParser(const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector extra_names, DataType input_dtype) : input_dtype_(input_dtype), extra_names_(extra_names) { for (size_t i = 0; i < names.size(); ++i) { name2info_.emplace(names[i], std::make_tuple(i, shapes[i], dtypes[i])); idx2info_.emplace(i, std::make_tuple(names[i], shapes[i], dtypes[i])); if (shapes[i] == -1 && dtypes[i] == DataType::DT_INT64) { ragged_names_.insert(names[i]); idx2info_.emplace(i + names.size(), std::make_tuple(names[i], shapes[i], dtypes[i])); } } } void BaseParser::AllocateFeatures(OpKernelContext *ctx, std::vector *out_tensors, OpOutputList *out_list, int batch_size) { profiler::TraceMe activity([]() { return "AllocateFeatures"; }); std::string name; int shape; DataType dtype; for (size_t i = 0; i < name2info_.size(); ++i) { 
std::tie(name, shape, dtype) = idx2info_[i]; if (shape == -1) { OP_REQUIRES( ctx, dtype == DataType::DT_INT64, errors::InvalidArgument("If shape is -1, then dtype must be int64")); OP_REQUIRES_OK( ctx, out_list->allocate(i, {batch_size + 1}, &out_tensors->at(i))); } else { OP_REQUIRES_OK( ctx, out_list->allocate(i, {batch_size, shape}, &out_tensors->at(i))); } std::memset(out_tensors->at(i)->data(), 0, out_tensors->at(i)->TotalBytes()); } } void BaseParser::AllocateRaggedValues(OpKernelContext *ctx, std::vector *out_tensors, OpOutputList *out_list, int batch_size) { profiler::TraceMe activity([]() { return "AllocateRaggedValues"; }); int idx, shape; DataType dtype; for (const std::string &name : ragged_names_) { std::tie(idx, shape, dtype) = name2info_[name]; Tensor *tensor = out_tensors->at(idx); shape = static_cast(tensor->flat()(batch_size)); idx += name2info_.size(); OP_REQUIRES_OK(ctx, out_list->allocate(idx, {shape}, &out_tensors->at(idx))); if (shape > 0) { std::memset(out_tensors->at(idx)->data(), 0, out_tensors->at(idx)->TotalBytes()); } } } // TODO: This function can be optimized further if needed: // // 1. Instead flat tensor inside, flat it outside (Reduce 2/3 running time) // 2. 
Using switch instead of if void BaseParser::FillFeature(OpKernelContext *ctx, const EFeature &feature, Tensor *tensor, const std::string &name, const int shape, const int offset) { if (feature.has_fid_v1_list()) { auto flat = tensor->flat(); flat(offset + 1) = flat(offset) + feature.fid_v1_list().value_size(); } else if (feature.has_fid_v2_list()) { auto flat = tensor->flat(); flat(offset + 1) = flat(offset) + feature.fid_v2_list().value_size(); } else { if (shape == -1) { auto flat = tensor->flat(); flat(offset + 1) = flat(offset); } } if (feature.has_float_list()) { if (shape == 1) { if (feature.float_list().value_size() > 0) { tensor->flat()(offset) = feature.float_list().value(0); } else { tensor->flat()(offset) = std::numeric_limits::min(); LOG_EVERY_N(INFO, 10000) << "float feature " << name << " has missing data!"; } } else { auto matrix = tensor->matrix(); for (int j = 0; j < std::min(shape, feature.float_list().value_size()); ++j) { matrix(offset, j) = feature.float_list().value(j); } } } else if (feature.has_double_list()) { if (shape == 1) { if (feature.double_list().value_size() > 0) { tensor->flat()(offset) = feature.double_list().value(0); } else { tensor->flat()(offset) = std::numeric_limits::min(); LOG_EVERY_N(INFO, 10000) << "double feature " << name << " has missing data!"; } } else { auto matrix = tensor->matrix(); for (int j = 0; j < std::min(shape, feature.double_list().value_size()); ++j) { matrix(offset, j) = feature.double_list().value(j); } } } else if (feature.has_int64_list()) { if (shape == 1) { CHECK_GT(feature.int64_list().value_size(), 0); tensor->flat()(offset) = feature.int64_list().value(0); } else { auto matrix = tensor->matrix(); for (int j = 0; j < std::min(shape, feature.int64_list().value_size()); ++j) { matrix(offset, j) = feature.int64_list().value(j); } } } else if (feature.has_bytes_list()) { OP_REQUIRES(ctx, shape == 1, errors::InvalidArgument("shape must be 1 for bytes list!")); 
CHECK_GT(feature.bytes_list().value_size(), 0); tensor->flat()(offset) = feature.bytes_list().value(0); } else { if (feature.has_fid_v2_lists() || feature.has_float_lists() || feature.has_double_lists() || feature.has_int64_lists() || feature.has_bytes_lists()) { LOG(ERROR) << "list of list is not support yet!"; } } } void BaseParser::FillFromLineId(OpKernelContext *ctx, const LineId &line_id, std::vector *out_tensors, const int offset) { int idx, shape; DataType dtype; for (const std::string &name : extra_names_) { std::tie(idx, shape, dtype) = name2info_[name]; Tensor *tensor = out_tensors->at(idx); if (name == "req_time") { tensor->flat()(offset) = line_id.req_time(); } if (name == "user_id") { tensor->flat()(offset) = line_id.user_id(); } else if (name == "uid") { tensor->flat()(offset) = line_id.uid(); } else if (name == "actions") { if (shape > line_id.actions_size()) { LOG_EVERY_N(ERROR, 100) << absl::StrFormat("Expected actions' shape=%d while got %d", shape, line_id.actions_size()); } if (shape == 1) { if (line_id.actions_size()) { tensor->flat()(offset) = line_id.actions(0); } } else { auto matrix = tensor->matrix(); for (int i = 0; i < std::min(shape, line_id.actions_size()); ++i) { matrix(offset, i) = line_id.actions(i); } } } else if (name == "sample_rate") { tensor->flat()(offset) = line_id.sample_rate(); } else if (name == "chnid") { tensor->flat()(offset) = line_id.chnid(); } else { const auto *field = descriptor->FindFieldByName(name); OP_REQUIRES_OK(ctx, FillFromLineIdByreflection(line_id, field, tensor, shape, offset)); } } } Status BaseParser::FillFromLineIdByreflection(const LineId &line_id, const FieldDescriptor *field, Tensor *tensor, int shape, int offset) { if (field->is_repeated()) { const int field_size = reflection->FieldSize(line_id, field); switch (field->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: { if (shape == 1) { tensor->flat()(offset) = reflection->GetRepeatedInt32(line_id, field, 0); } else { auto matrix = 
tensor->matrix(); for (int i = 0; i < std::min(shape, field_size); ++i) { matrix(offset, i) = reflection->GetRepeatedInt32(line_id, field, i); } } break; } case FieldDescriptor::CPPTYPE_UINT32: { if (shape == 1) { tensor->flat()(offset) = reflection->GetRepeatedUInt32(line_id, field, 0); } else { auto matrix = tensor->matrix(); for (int i = 0; i < std::min(shape, field_size); ++i) { matrix(offset, i) = reflection->GetRepeatedUInt32(line_id, field, i); } } break; } case FieldDescriptor::CPPTYPE_INT64: { if (shape == 1) { tensor->flat()(offset) = reflection->GetRepeatedInt64(line_id, field, 0); } else { auto matrix = tensor->matrix(); for (int i = 0; i < std::min(shape, field_size); ++i) { matrix(offset, i) = reflection->GetRepeatedInt64(line_id, field, i); } } break; } case FieldDescriptor::CPPTYPE_UINT64: { if (shape == 1) { tensor->flat()(offset) = reflection->GetRepeatedUInt64(line_id, field, 0); } else { auto matrix = tensor->matrix(); for (int i = 0; i < std::min(shape, field_size); ++i) { matrix(offset, i) = reflection->GetRepeatedUInt64(line_id, field, i); } } break; } case FieldDescriptor::CPPTYPE_FLOAT: { if (shape == 1) { tensor->flat()(offset) = reflection->GetRepeatedFloat(line_id, field, 0); } else { auto matrix = tensor->matrix(); for (int i = 0; i < std::min(shape, field_size); ++i) { matrix(offset, i) = reflection->GetRepeatedFloat(line_id, field, i); } } break; } case FieldDescriptor::CPPTYPE_DOUBLE: { if (shape == 1) { tensor->flat()(offset) = reflection->GetRepeatedDouble(line_id, field, 0); } else { auto matrix = tensor->matrix(); for (int i = 0; i < std::min(shape, field_size); ++i) { matrix(offset, i) = reflection->GetRepeatedDouble(line_id, field, i); } } break; } case FieldDescriptor::CPPTYPE_STRING: { if (shape == 1) { tensor->flat()(offset) = reflection->GetRepeatedString(line_id, field, 0); } else { auto matrix = tensor->matrix(); for (int i = 0; i < std::min(shape, field_size); ++i) { matrix(offset, i) = 
reflection->GetRepeatedString(line_id, field, i); } } break; } default: return errors::InvalidArgument(field->name(), " Data type not match, only " "string/int32/int64/float32 " "supported."); } } else { switch (field->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: { auto flat = tensor->flat(); flat(offset) = reflection->GetInt32(line_id, field); break; } case FieldDescriptor::CPPTYPE_UINT32: { auto flat = tensor->flat(); flat(offset) = reflection->GetUInt32(line_id, field); break; } case FieldDescriptor::CPPTYPE_INT64: { auto flat = tensor->flat(); flat(offset) = reflection->GetInt64(line_id, field); break; } case FieldDescriptor::CPPTYPE_UINT64: { auto flat = tensor->flat(); flat(offset) = reflection->GetUInt64(line_id, field); break; } case FieldDescriptor::CPPTYPE_FLOAT: { auto flat = tensor->flat(); flat(offset) = reflection->GetFloat(line_id, field); break; } case FieldDescriptor::CPPTYPE_DOUBLE: { auto flat = tensor->flat(); flat(offset) = reflection->GetDouble(line_id, field); break; } case FieldDescriptor::CPPTYPE_STRING: { auto flat = tensor->flat(); flat(offset) = reflection->GetString(line_id, field); break; } default: return errors::InvalidArgument(field->name(), " Data type not match, only " "int32/int64/float32/string " "supported."); } } return Status::OK(); } ExampleParser::ExampleParser(const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector extra_names, DataType input_dtype, FeatureNameMapper *mapper) : BaseParser(names, shapes, dtypes, extra_names, input_dtype), mapper_(mapper) { std::unordered_set extra_name_set(extra_names.begin(), extra_names.end()); std::vector sparse_feature_names; for (size_t i = 0; i < names.size(); ++i) { if (!extra_name_set.count(names[i]) && shapes[i] == -1) { sparse_feature_names.push_back(names[i]); } } CHECK(mapper_->RegisterValidNames(sparse_feature_names)); } void ExampleParser::Parse(OpKernelContext *ctx, const std::vector &examples, OpOutputList *out_list) { int 
batch_size = examples.size(); std::vector out_tensors; out_tensors.resize(idx2info_.size()); int idx, shape; DataType dtype; // 1) allocate output tensors for ragged splits and other non-ragged AllocateFeatures(ctx, &out_tensors, out_list, batch_size); // 2) fill all tensors expect ragged values int offset = 0; { profiler::TraceMe activity( []() { return "FillAllTensorsExceptRaggedValues"; }); for (const Example *example : examples) { std::unordered_set appeared; appeared.reserve(example->named_feature_size()); bool has_fill_label = false, has_fill_instance_weight = false; for (const auto &named_feature : example->named_feature()) { // FeatureNameMapper const std::string &name = named_feature.name(); auto it = name2info_.find(name); if (it == name2info_.end()) continue; std::tie(idx, shape, dtype) = it->second; FillFeature(ctx, named_feature.feature(), out_tensors[idx], name, shape, offset); appeared.insert(name); if (name == "label") { has_fill_label = true; } else if (name == "instance_weight") { has_fill_instance_weight = true; } } for (const auto &ragged : ragged_names_) { if (appeared.find(ragged) == appeared.end()) { std::tie(idx, shape, dtype) = name2info_[ragged]; auto flat = out_tensors[idx]->flat(); flat(offset + 1) += flat(offset); } } // for label if (!has_fill_label) { auto it = name2info_.find("label"); if (it != name2info_.end()) { std::tie(idx, shape, dtype) = it->second; Tensor *tensor = out_tensors[idx]; if (example->label_size() > 0) { if (shape == 1) { tensor->flat()(offset) = example->label(0); } else { auto matrix = tensor->matrix(); for (int j = 0; j < shape; ++j) { if (j < example->label_size()) { matrix(offset, j) = example->label(j); } else { matrix(offset, j) = internal::INVALID_LABEL; } } } } else { if (shape == 1) { tensor->flat()(offset) = internal::INVALID_LABEL; } else { auto matrix = tensor->matrix(); for (int j = 0; j < shape; ++j) { matrix(offset, j) = internal::INVALID_LABEL; } } } } } // for instance_weight if 
(!has_fill_instance_weight) { auto it = name2info_.find("instance_weight"); if (it != name2info_.end()) { std::tie(idx, shape, dtype) = it->second; Tensor *tensor = out_tensors[idx]; float instance_weight = example->instance_weight(); tensor->flat()(offset) = instance_weight > 0 ? instance_weight : 1.0; } } // for extra fields in line_id if (!extra_names_.empty()) { const LineId &line_id = example->line_id(); FillFromLineId(ctx, line_id, &out_tensors, offset); } offset++; } } // 3) allocate output tensors for ragged values AllocateRaggedValues(ctx, &out_tensors, out_list, batch_size); // 4) fill ragged values if (ragged_names_.size()) { profiler::TraceMe activity([]() { return "FillRaggedValues"; }); offset = 0; for (const Example *example : examples) { for (const auto &named_feature : example->named_feature()) { const auto &name = named_feature.name(); auto it = ragged_names_.find(name); if (it != ragged_names_.end()) { std::tie(idx, shape, dtype) = name2info_[name]; auto splits = out_tensors[idx]->flat(); auto values = out_tensors[idx + name2info_.size()]->flat(); int start = static_cast(splits(offset)); const auto &feature = named_feature.feature(); if (feature.has_fid_v1_list()) { for (int i = 0; i < feature.fid_v1_list().value_size(); ++i) { values(start + i) = convert_fid_v1_to_v2(feature.fid_v1_list().value(i)); } } else if (feature.has_fid_v2_list()) { for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { values(start + i) = feature.fid_v2_list().value(i); } } } } offset++; } } } ExampleBatchParser::ExampleBatchParser( const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector extra_names, DataType input_dtype) : BaseParser(names, shapes, dtypes, extra_names, input_dtype) {} void ExampleBatchParser::Parse(OpKernelContext *ctx, const ExampleBatch &example_batch, OpOutputList *out_list) { int batch_size = example_batch.batch_size(); std::vector out_tensors; out_tensors.resize(idx2info_.size()); std::string name; 
int idx, shape; DataType dtype; // 1) allocate output tensors for ragged splits and other non-ragged AllocateFeatures(ctx, &out_tensors, out_list, batch_size); // 2) fill all tensors expect ragged values for (const auto &named_feature_list : example_batch.named_feature_list()) { name = named_feature_list.name(); if (name == "__LINE_ID__") { // for extra fields in line_id if (extra_names_.size() > 0) { int offset = 0; for (const auto &feature : named_feature_list.feature()) { LineId line_id; CHECK_GT(feature.bytes_list().value_size(), 0); const auto serialized = feature.bytes_list().value(0); OP_REQUIRES( ctx, line_id.ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the LineId.")); FillFromLineId(ctx, line_id, &out_tensors, offset); offset++; } } } else if (name == "__LABEL__") { // for label auto it = name2info_.find("label"); if (it != name2info_.end()) { std::tie(idx, shape, dtype) = it->second; Tensor *tensor = out_tensors[idx]; int offset = 0; for (const auto &feature : named_feature_list.feature()) { if (shape == 1) { CHECK_GT(feature.float_list().value_size(), 0); tensor->flat()(offset) = feature.float_list().value(0); } else { auto matrix = tensor->matrix(); for (int j = 0; j < std::min(shape, feature.float_list().value_size()); ++j) { matrix(offset, j) = feature.float_list().value(j); } } offset++; } } } else if (name == "instance_weight") { auto it = name2info_.find("instance_weight"); if (it != name2info_.end()) { std::tie(idx, shape, dtype) = it->second; Tensor *tensor = out_tensors[idx]; int offset = 0; for (const auto &feature : named_feature_list.feature()) { CHECK_GT(feature.float_list().value_size(), 0); float instance_weight = feature.float_list().value(0); tensor->flat()(offset) = instance_weight > 0 ? 
instance_weight : 1.0; offset++; } } } else { auto it = name2info_.find(name); if (it == name2info_.end()) continue; std::tie(idx, shape, dtype) = name2info_[name]; Tensor *tensor = out_tensors[idx]; if (named_feature_list.type() == FeatureListType::SHARED) { CHECK_GT(named_feature_list.feature_size(), 0); const auto &feature = named_feature_list.feature(0); for (int offset = 0; offset < batch_size; ++offset) { FillFeature(ctx, feature, tensor, name, shape, offset); } } else { int offset = 0; for (const auto &feature : named_feature_list.feature()) { FillFeature(ctx, feature, tensor, name, shape, offset); offset++; } } } } // 3) allocate output tensors for ragged values AllocateRaggedValues(ctx, &out_tensors, out_list, batch_size); // 4) fill ragged values if (ragged_names_.size()) { for (const auto &named_feature_list : example_batch.named_feature_list()) { name = named_feature_list.name(); auto it = ragged_names_.find(name); if (it != ragged_names_.end()) { std::tie(idx, shape, dtype) = name2info_[name]; auto splits = out_tensors[idx]->flat(); auto values = out_tensors[idx + name2info_.size()]->flat(); if (named_feature_list.type() == FeatureListType::SHARED) { const auto &feature = named_feature_list.feature(0); for (int offset = 0; offset < batch_size; ++offset) { int start = static_cast(splits(offset)); if (feature.has_fid_v1_list()) { for (int i = 0; i < feature.fid_v1_list().value_size(); ++i) { values(start + i) = convert_fid_v1_to_v2(feature.fid_v1_list().value(i)); } } else if (feature.has_fid_v2_list()) { for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { values(start + i) = feature.fid_v2_list().value(i); } } } } else { int offset = 0; for (const auto &feature : named_feature_list.feature()) { int start = static_cast(splits(offset)); if (feature.has_fid_v1_list()) { for (int i = 0; i < feature.fid_v1_list().value_size(); ++i) { values(start + i) = convert_fid_v1_to_v2(feature.fid_v1_list().value(i)); } } else if (feature.has_fid_v2_list()) { 
for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { values(start + i) = feature.fid_v2_list().value(i); } } offset++; } } } } } } ExampleBatchListParser::ExampleBatchListParser( const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector &extra_names, DataType input_dtype) : BaseParser(names, shapes, dtypes, extra_names, input_dtype) {} void ExampleBatchListParser::Parse( OpKernelContext *ctx, const ExampleBatch &example_batch, const std::vector &label_config_, float positive_label, float negative_label, OpOutputList *out_list) { int batch_size = example_batch.batch_size(); std::vector out_tensors; out_tensors.resize(idx2info_.size()); std::string name; int idx, shape; DataType dtype; // 1) allocate output tensors for ragged splits and other non-ragged AllocateFeatures(ctx, &out_tensors, out_list, batch_size); // 2) fill all tensors expect ragged values for (const auto &named_feature_list : example_batch.named_feature_list()) { name = named_feature_list.name(); if (name == "__LINE_ID__") { auto it = name2info_.find("label"); if (it != name2info_.end()) { std::tie(idx, shape, dtype) = it->second; } // for extra fields in line_id if (extra_names_.size() > 0) { int offset = 0; for (const auto &feature : named_feature_list.feature()) { LineId line_id; CHECK_GT(feature.bytes_list().value_size(), 0); const auto serialized = feature.bytes_list().value(0); OP_REQUIRES( ctx, line_id.ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the LineId.")); FillFromLineId(ctx, line_id, &out_tensors, offset); if (it != name2info_.end()) { FillLabelFromLineId(ctx, line_id, label_config_, positive_label, negative_label, out_tensors[idx], offset); } offset++; } } } else if (name == "instance_weight") { auto it = name2info_.find("instance_weight"); if (it != name2info_.end()) { std::tie(idx, shape, dtype) = it->second; Tensor *tensor = out_tensors[idx]; int offset = 0; for (const auto &feature 
: named_feature_list.feature()) { CHECK_GT(feature.float_list().value_size(), 0); float instance_weight = feature.float_list().value(0); tensor->flat()(offset++) = instance_weight > 0 ? instance_weight : 1.0; } } } else { auto it = name2info_.find(name); if (it == name2info_.end()) continue; std::tie(idx, shape, dtype) = it->second; Tensor *tensor = out_tensors[idx]; if (named_feature_list.type() == FeatureListType::SHARED) { const auto &feature = named_feature_list.feature(0); for (int offset = 0; offset < batch_size; ++offset) { FillFeature(ctx, feature, tensor, name, shape, offset); } } else { int offset = 0; for (const auto &feature : named_feature_list.feature()) { FillFeature(ctx, feature, tensor, name, shape, offset); offset++; } } } } // 3) allocate output tensors for ragged values AllocateRaggedValues(ctx, &out_tensors, out_list, batch_size); // 4) fill ragged values for (const auto &named_feature_list : example_batch.named_feature_list()) { name = named_feature_list.name(); auto it = ragged_names_.find(name); if (it != ragged_names_.end()) { int slot = named_feature_list.id(); std::tie(idx, shape, dtype) = name2info_[name]; auto splits = out_tensors[idx]->flat(); auto values = out_tensors[idx + name2info_.size()]->flat(); if (named_feature_list.type() == FeatureListType::SHARED) { const auto &feature = named_feature_list.feature(0); for (int offset = 0; offset < batch_size; ++offset) { int start = static_cast(splits(offset)); if (feature.has_fid_v1_list()) { for (int i = 0; i < feature.fid_v1_list().value_size(); ++i) { values(start + i) = GetFidV2(slot, feature.fid_v1_list().value(i)); } } else if (feature.has_fid_v2_list()) { for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { values(start + i) = GetFidV2(slot, feature.fid_v2_list().value(i)); } } } } else { int offset = 0; for (const auto &feature : named_feature_list.feature()) { int start = static_cast(splits(offset)); if (feature.has_fid_v1_list()) { for (int i = 0; i < 
feature.fid_v1_list().value_size(); ++i) { values(start + i) = GetFidV2(slot, feature.fid_v1_list().value(i)); } } else if (feature.has_fid_v2_list()) { for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { values(start + i) = GetFidV2(slot, feature.fid_v2_list().value(i)); } } offset++; } } } } } void ExampleBatchListParser::FillLabelFromLineId( OpKernelContext *ctx, const ::idl::matrix::proto::LineId &line_id, const std::vector &label_config_, float positive_label, float negative_label, Tensor *out_tensor, const int offset) { std::set actions(line_id.actions().begin(), line_id.actions().end()); int label_idx = 0; auto matrix = out_tensor->matrix(); for (const auto &task_conf : label_config_) { if (internal::HasIntersection(task_conf.pos_actions, actions)) { matrix(offset, label_idx) = positive_label; } else { if (task_conf.neg_actions.empty()) { matrix(offset, label_idx) = negative_label; } else { if (internal::HasIntersection(task_conf.neg_actions, actions)) { matrix(offset, label_idx) = negative_label; } else { matrix(offset, label_idx) = internal::INVALID_LABEL; } } } label_idx++; } } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/parse_example_lib.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_PARSE_EXAMPLE_LIB_H_ #define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_PARSE_EXAMPLE_LIB_H_ #include #include #include "google/protobuf/descriptor.h" #include "idl/matrix/proto/example.pb.h" #include "idl/matrix/proto/proto_parser.pb.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { namespace monolith_tf { class BaseParser { public: explicit BaseParser(const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector extra_names, DataType input_dtype); protected: void AllocateFeatures(OpKernelContext *ctx, std::vector *out_tensors, OpOutputList *out_list, int batch_size); void AllocateRaggedValues(OpKernelContext *ctx, std::vector *out_tensors, OpOutputList *out_list, int batch_size); void FillFeature(OpKernelContext *ctx, const ::monolith::io::proto::Feature &feature, Tensor *tensor, const std::string &name, int shape, int offset); void FillFromLineId(OpKernelContext *ctx, const ::idl::matrix::proto::LineId &line_id, std::vector *out_tensors, const int offset); Status FillFromLineIdByreflection( const ::idl::matrix::proto::LineId &line_id, const ::google::protobuf::FieldDescriptor *field, Tensor *tensor, int shape, int offset); std::unordered_map> name2info_; std::unordered_map> idx2info_; std::unordered_set ragged_names_; std::vector extra_names_; DataType input_dtype_; const ::google::protobuf::Descriptor *descriptor = ::idl::matrix::proto::LineId::GetDescriptor(); const ::google::protobuf::Reflection *reflection = ::idl::matrix::proto::LineId::GetReflection(); }; class ExampleParser : public BaseParser { public: explicit ExampleParser(const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector extra_names, DataType input_dtype, 
FeatureNameMapper *mapper); void Parse( OpKernelContext *ctx, const std::vector &examples, OpOutputList *out_list); private: FeatureNameMapper *mapper_ = nullptr; }; class ExampleBatchParser : public BaseParser { public: explicit ExampleBatchParser(const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector extra_names, DataType input_dtype); void Parse(OpKernelContext *ctx, const ::monolith::io::proto::ExampleBatch &example_batch, OpOutputList *out_list); }; class ExampleBatchListParser : public BaseParser { public: explicit ExampleBatchListParser(const std::vector &names, const std::vector &shapes, const std::vector &dtypes, const std::vector &extra_names, DataType input_dtype); void Parse(OpKernelContext *ctx, const ::monolith::io::proto::ExampleBatch &example_batchs, const std::vector &label_config_, float positive_label, float negative_label, OpOutputList *out_list); private: uint64 mask_ = (1 << 48) - 1; void FillLabelFromLineId( OpKernelContext *ctx, const ::idl::matrix::proto::LineId &line_id, const std::vector &label_config_, float positive_label, float negative_label, Tensor *out_tensor, const int offset); }; } // namespace monolith_tf } // namespace tensorflow #endif MONOLITH_NATIVE_TRAINING_DATA_KERNELS_PARSE_EXAMPLE_LIB_H_ ================================================ FILE: monolith/native_training/data/kernels/parse_input_data_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "google/protobuf/descriptor.h" #include "idl/matrix/proto/example.pb.h" #include "idl/matrix/proto/proto_parser.pb.h" #include "monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/kernels/parse_example_lib.h" #include "monolith/native_training/data/training_instance/cc/data_reader.h" #include "monolith/native_training/data/training_instance/cc/parse_instance_lib.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "monolith/native_training/runtime/common/metrics.h" namespace tensorflow { namespace monolith_tf { namespace { using Instance = ::parser::proto::Instance; using LineId = ::idl::matrix::proto::LineId; using EFeature = ::monolith::io::proto::Feature; using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using FieldDescriptor = ::google::protobuf::FieldDescriptor; using ExampleParser = ::tensorflow::monolith_tf::ExampleParser; using ExampleBatchParser = ::tensorflow::monolith_tf::ExampleBatchParser; using ExampleBatchListParser = ::tensorflow::monolith_tf::ExampleBatchListParser; using NamedFeatureList = ::monolith::io::proto::NamedFeatureList; using FeatureConfigs = 
::monolith::io::proto::FeatureConfigs; class DataCounter { public: explicit DataCounter(std::string op, bool emit_mini_batch, int64_t emit_every_n_batch = 2000) : op_(std::move(op)), emit_mini_batch_(emit_mini_batch), mini_batch_num_(0), emit_every_n_batch_(emit_every_n_batch), last_batch_size_(0) { CHECK_GT(emit_every_n_batch_, 0); } ~DataCounter() { LOG(INFO) << absl::StrFormat( "Finally metrics_emit(counter) [data_consume_num] op=%s, " "batch_size=%d, total_mini_batch_num=%llu", op_, last_batch_size_, mini_batch_num_); int64_t remainder = mini_batch_num_ % emit_every_n_batch_; if (remainder) { monolith::GetMetrics()->emit_counter("data_consume_num", last_batch_size_ * remainder, absl::StrFormat("op=%s", op_)); if (emit_mini_batch_) { monolith::GetMetrics()->emit_counter("mini_batch_num", remainder, absl::StrFormat("op=%s", op_)); } } } void EmitDataConsumeNumCounter(int batch_size) { mini_batch_num_ += 1; last_batch_size_ = batch_size; LOG_EVERY_N_SEC(INFO, 300) << absl::StrFormat( "metrics_emit(counter) [data_consume_num] op=%s, " "batch_size=%d, total_mini_batch_num=%llu", op_, batch_size, mini_batch_num_); if (mini_batch_num_ % emit_every_n_batch_ == 0) { monolith::GetMetrics()->emit_counter("data_consume_num", batch_size * emit_every_n_batch_, absl::StrFormat("op=%s", op_)); if (emit_mini_batch_) { monolith::GetMetrics()->emit_counter("mini_batch_num", emit_every_n_batch_, absl::StrFormat("op=%s", op_)); } } } private: std::string op_; bool emit_mini_batch_; int64_t mini_batch_num_; int64_t emit_every_n_batch_; int64_t last_batch_size_; }; Status GetParserConfig(OpKernelConstruction *ctx, InstanceParserConfig *c, std::vector *index) { TF_RETURN_IF_ERROR(ctx, ctx->GetAttr("fidv1_features", &(c->fidv1_features))); TF_RETURN_IF_ERROR(ctx, ctx->GetAttr("fidv2_features", &(c->fidv2_features))); std::vector names; std::vector shapes; std::vector dtypes; std::vector extra_names; std::unordered_set misc({"label", "instance_weight"}); TF_RETURN_IF_ERROR(ctx, 
ctx->GetAttr("names", &names)); TF_RETURN_IF_ERROR(ctx, ctx->GetAttr("shapes", &shapes)); TF_RETURN_IF_ERROR(ctx, ctx->GetAttr("dtypes", &dtypes)); TF_RETURN_IF_ERROR(ctx, ctx->GetAttr("extra_names", &extra_names)); int ragged_size = c->fidv1_features.size() + c->fidv2_features.size(); if (names.size() != shapes.size() || shapes.size() + ragged_size != dtypes.size()) { return errors::InvalidArgument( "Num of names, shapes and dtypes do not match"); } for (size_t i = 0; i < names.size(); ++i) { if (i < ragged_size) { continue; // skip fidv1/fidv2 } std::string name = names[i]; int dim = shapes[i]; DataType dtype = dtypes[i]; auto eit = std::find(extra_names.begin(), extra_names.end(), name); if (eit != extra_names.end() || misc.find(name) != misc.end()) { // extra switch (dtype) { case DataType::DT_INT64: c->misc_int64_features.push_back(name); c->misc_int64_dims.push_back(dim); break; case DataType::DT_FLOAT: c->misc_float_features.push_back(name); c->misc_float_dims.push_back(dim); break; case DataType::DT_STRING: c->misc_string_features.push_back(name); c->misc_string_dims.push_back(dim); break; default: return errors::InvalidArgument("Unsupported data type!"); } } else { // dense switch (dtype) { case DataType::DT_INT64: c->int64_features.push_back(name); c->int64_feature_dims.push_back(dim); break; case DataType::DT_FLOAT: c->float_features.push_back(name); c->float_feature_dims.push_back(dim); break; case DataType::DT_STRING: c->string_features.push_back(name); c->string_feature_dims.push_back(dim); break; default: return errors::InvalidArgument("Unsupported data type!"); } } } std::vector new_names; new_names.reserve(names.size()); new_names.insert(new_names.end(), names.begin(), names.begin() + ragged_size); new_names.insert(new_names.end(), c->float_features.begin(), c->float_features.end()); new_names.insert(new_names.end(), c->int64_features.begin(), c->int64_features.end()); new_names.insert(new_names.end(), c->string_features.begin(), 
c->string_features.end()); new_names.insert(new_names.end(), c->misc_float_features.begin(), c->misc_float_features.end()); new_names.insert(new_names.end(), c->misc_int64_features.begin(), c->misc_int64_features.end()); new_names.insert(new_names.end(), c->misc_string_features.begin(), c->misc_string_features.end()); index->reserve(dtypes.size()); std::unordered_map name_to_idx; for (size_t i = 0; i < names.size(); ++i) { name_to_idx.emplace(names[i], i); } for (size_t i = 0; i < new_names.size(); ++i) { int idx = name_to_idx[new_names[i]]; if (i < ragged_size) { (*index)[i] = idx; (*index)[i + ragged_size] = idx + names.size(); } else { (*index)[i + ragged_size] = idx; } } return Status::OK(); } class ParseStringInstancesOp : public OpKernel { public: explicit ParseStringInstancesOp(OpKernelConstruction *ctx) : OpKernel(ctx) { InstanceParserConfig config; OP_REQUIRES_OK(ctx, GetParserConfig(ctx, &config, &index_)); config.collapse_batch_dim = false; parser_ = std::make_unique(config); OP_REQUIRES_OK(ctx, parser_->Init()); counter_ = std::make_unique("ParseStringInstancesOp", true); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &serialized_flat = pb_input->flat(); const int batch_size = serialized_flat.size(); std::vector instances(batch_size); // has alocated memory { profiler::TraceMe activity([]() { return "Deserialize"; }); auto deserialize_fn = [&](int64 begin, int64 end) { for (int64 i = begin; i < end; ++i) { const auto &serialized = serialized_flat(i); OP_REQUIRES( ctx, instances[i].ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the Instance.")); } }; auto workers = ctx->device()->tensorflow_cpu_worker_threads()->workers; workers->ParallelFor(batch_size, tensorflow::thread::ThreadPool::SchedulingParams( tensorflow::thread::ThreadPool:: SchedulingStrategy::kFixedBlockSize, absl::nullopt, 1), 
deserialize_fn); } InstanceParser::Output output; { profiler::TraceMe activity([]() { return "Parse"; }); OP_REQUIRES_OK(ctx, parser_->Parse(ctx, instances, &output)); } OpOutputList out_list; { profiler::TraceMe activity([]() { return "PrepareOutput"; }); OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); OP_REQUIRES( ctx, output.tensors.size() == out_list.size(), errors::FailedPrecondition("output tensor size doesn't match")); for (size_t i = 0; i < output.tensors.size(); ++i) { out_list.set(index_[i], output.tensors[i]); } } counter_->EmitDataConsumeNumCounter(batch_size); } protected: const std::vector &GetIndex() const { return index_; } InstanceParser *GetParse() const { return parser_.get(); } DataCounter *GetCounter() const { return counter_.get(); } private: std::vector index_; std::unique_ptr parser_; std::unique_ptr counter_; }; class ParseStringInstancesV2Op : public ParseStringInstancesOp { public: explicit ParseStringInstancesV2Op(OpKernelConstruction *ctx) : ParseStringInstancesOp(ctx) {} void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &serialized_flat = pb_input->flat(); int batch_size = serialized_flat.size(); std::vector instances(batch_size); // has alocated memory for (int i = 0; i < batch_size; ++i) { const auto &serialized = serialized_flat(i); OP_REQUIRES( ctx, instances[i].ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the Instance.")); } InstanceParser::Output output; OP_REQUIRES_OK(ctx, ParseStringInstancesOp::GetParse()->Parse( ctx, instances, &output)); OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); OP_REQUIRES(ctx, output.tensors.size() == out_list.size(), errors::FailedPrecondition("output tensor size doesn't match")); for (size_t i = 0; i < output.tensors.size(); ++i) { out_list.set(ParseStringInstancesOp::GetIndex()[i], 
output.tensors[i]); } Tensor *instance_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output("sparse_features", TensorShape({ batch_size, }), &instance_tensor)); for (size_t i = 0; i < batch_size; ++i) { instance_tensor->flat()(i) = std::move(instances[i]); } ParseStringInstancesOp::GetCounter()->EmitDataConsumeNumCounter(batch_size); } }; class ParseVariantInstancesOp : public OpKernel { public: explicit ParseVariantInstancesOp(OpKernelConstruction *ctx) : OpKernel(ctx) { InstanceParserConfig config; OP_REQUIRES_OK(ctx, GetParserConfig(ctx, &config, &index_)); config.collapse_batch_dim = false; parser_ = std::make_unique(config); OP_REQUIRES_OK(ctx, parser_->Init()); counter_ = std::make_unique("ParseVariantInstancesOp", true); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); TTypes::ConstVec pb_variant_tensor = pb_input->vec(); const int batch_size = pb_variant_tensor.dimension(0); std::vector instances; // not allocated memory instances.reserve(batch_size); for (int i = 0; i < batch_size; ++i) { instances.push_back(*pb_variant_tensor(i).get()); } InstanceParser::Output output; OP_REQUIRES_OK(ctx, parser_->Parse(ctx, instances, &output)); OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); for (size_t i = 0; i < output.tensors.size(); ++i) { out_list.set(index_[i], output.tensors[i]); } counter_->EmitDataConsumeNumCounter(batch_size); } private: std::vector index_; std::unique_ptr parser_; std::unique_ptr counter_; }; class ParseVariantInstancesV2Op : public ParseVariantInstancesOp { public: explicit ParseVariantInstancesV2Op(OpKernelConstruction *ctx) : ParseVariantInstancesOp(ctx) {} void Compute(OpKernelContext *ctx) override { ParseVariantInstancesOp::Compute(ctx); OP_REQUIRES_OK(ctx, ctx->set_output("sparse_features", ctx->input(0))); } }; class ParseStringExamplesOp : public OpKernel { public: explicit 
ParseStringExamplesOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::vector names; std::vector shapes; std::vector dtypes; std::vector extra_names; OP_REQUIRES_OK(ctx, ctx->GetAttr("names", &names)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("extra_names", &extra_names)); auto creator = [this](FeatureNameMapperTfBridge **out_mapper) { TF_RETURN_IF_ERROR(FeatureNameMapperTfBridge::New(out_mapper)); return Status::OK(); }; ResourceMgr *resource_mgr = ctx->resource_manager(); OP_REQUIRES_OK(ctx, resource_mgr->LookupOrCreate( resource_mgr->default_container(), FeatureNameMapperTfBridge::kName, &mapper_, creator)); parser_ = std::make_unique(names, shapes, dtypes, extra_names, DataType::DT_STRING, mapper_->GetFeatureNameMapper()); counter_ = std::make_unique("ParseStringExamplesOp", true); } ~ParseStringExamplesOp() override { mapper_->Unref(); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &serialized_flat = pb_input->flat(); int batch_size = serialized_flat.size(); std::vector examples(batch_size); for (size_t i = 0; i < batch_size; ++i) { const auto &serialized = serialized_flat(i); OP_REQUIRES( ctx, examples[i].ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the Example.")); ExtendExample(&examples[i]); } OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); std::vector example_ptrs; example_ptrs.reserve(examples.size()); for (const auto &example : examples) { example_ptrs.push_back(&example); } parser_->Parse(ctx, example_ptrs, &out_list); counter_->EmitDataConsumeNumCounter(batch_size); } private: FeatureNameMapperTfBridge *mapper_ = nullptr; protected: ExampleParser *GetParse() const { return parser_.get(); } DataCounter *GetCounter() const { return counter_.get(); } 
FeatureNameMapper *GetFeatureNameMapper() const { return mapper_->GetFeatureNameMapper(); } private: std::unique_ptr parser_; std::unique_ptr counter_; }; class ParseStringExamplesV2Op : public ParseStringExamplesOp { public: explicit ParseStringExamplesV2Op(OpKernelConstruction *ctx) : ParseStringExamplesOp(ctx) {} void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &serialized_flat = pb_input->flat(); int batch_size = serialized_flat.size(); Tensor *example_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output("sparse_features", TensorShape({ batch_size, }), &example_tensor)); google::protobuf::Arena arena; std::vector example_ptrs; example_ptrs.reserve(batch_size); for (size_t i = 0; i < batch_size; ++i) { const auto &serialized = serialized_flat(i); auto *example_ptr = google::protobuf::Arena::CreateMessage(&arena); example_tensor->flat()(i) = std::move(*example_ptr); auto example = example_tensor->flat()(i).get(); OP_REQUIRES(ctx, example->ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the Example.")); ExtendExample(example); example_ptrs.push_back(example); } OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); ParseStringExamplesOp::GetParse()->Parse(ctx, example_ptrs, &out_list); ParseStringExamplesOp::GetCounter()->EmitDataConsumeNumCounter(batch_size); } }; class ParseVariantExamplesOp : public OpKernel { public: explicit ParseVariantExamplesOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::vector names; std::vector shapes; std::vector dtypes; std::vector extra_names; OP_REQUIRES_OK(ctx, ctx->GetAttr("names", &names)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("extra_names", &extra_names)); auto creator = [this](FeatureNameMapperTfBridge **out_mapper) { 
TF_RETURN_IF_ERROR(FeatureNameMapperTfBridge::New(out_mapper)); return Status::OK(); }; ResourceMgr *resource_mgr = ctx->resource_manager(); OP_REQUIRES_OK(ctx, resource_mgr->LookupOrCreate( resource_mgr->default_container(), FeatureNameMapperTfBridge::kName, &mapper_, creator)); parser_ = std::make_unique(names, shapes, dtypes, extra_names, DataType::DT_VARIANT, mapper_->GetFeatureNameMapper()); counter_ = std::make_unique("ParseVariantExamplesOp", true); } ~ParseVariantExamplesOp() override { mapper_->Unref(); } void Compute(OpKernelContext *ctx) override { const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &pb_variant_tensor = pb_input->vec(); int batch_size = pb_variant_tensor.dimension(0); std::vector examples; examples.reserve(batch_size); for (int i = 0; i < batch_size; ++i) { const auto *example = pb_variant_tensor(i).get(); CHECK_NOTNULL(example); examples.push_back(example); } OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); { profiler::TraceMe activity([]() { return "Parse"; }); parser_->Parse(ctx, examples, &out_list); } { profiler::TraceMe activity([]() { return "EmitDataConsumeNumCounter"; }); counter_->EmitDataConsumeNumCounter(batch_size); } } private: FeatureNameMapperTfBridge *mapper_ = nullptr; std::unique_ptr parser_; std::unique_ptr counter_; }; class ParseVariantExamplesV2Op : public ParseVariantExamplesOp { public: explicit ParseVariantExamplesV2Op(OpKernelConstruction *ctx) : ParseVariantExamplesOp(ctx) {} void Compute(OpKernelContext *ctx) override { ParseVariantExamplesOp::Compute(ctx); OP_REQUIRES_OK(ctx, ctx->set_output("sparse_features", ctx->input(0))); } }; class ParseStringExampleBatchOp : public OpKernel { public: explicit ParseStringExampleBatchOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::vector names; std::vector shapes; std::vector dtypes; std::vector extra_names; OP_REQUIRES_OK(ctx, ctx->GetAttr("names", &names)); OP_REQUIRES_OK(ctx, 
ctx->GetAttr("shapes", &shapes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("extra_names", &extra_names)); parser_ = std::make_unique( names, shapes, dtypes, extra_names, DataType::DT_STRING); counter_ = std::make_unique("ParseStringExampleBatchOp", false); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &serialized = pb_input->flat()(0); google::protobuf::Arena arena; auto *example_batch = google::protobuf::Arena::CreateMessage(&arena); OP_REQUIRES( ctx, example_batch->ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the Instance.")); OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); parser_->Parse(ctx, *example_batch, &out_list); counter_->EmitDataConsumeNumCounter(example_batch->batch_size()); } protected: ExampleBatchParser *GetParse() const { return parser_.get(); } DataCounter *GetCounter() const { return counter_.get(); } private: std::unique_ptr parser_; std::unique_ptr counter_; }; class ParseStringExampleBatchV2Op : public ParseStringExampleBatchOp { public: explicit ParseStringExampleBatchV2Op(OpKernelConstruction *ctx) : ParseStringExampleBatchOp(ctx) {} void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &serialized = pb_input->flat()(0); google::protobuf::Arena arena; auto *example_batch_ptr = google::protobuf::Arena::CreateMessage(&arena); Tensor *example_batch_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output("sparse_features", TensorShape({ 1, }), &example_batch_tensor)); example_batch_tensor->scalar()() = std::move(*example_batch_ptr); auto example_batch = example_batch_tensor->scalar()().get(); OP_REQUIRES( ctx, example_batch->ParseFromArray(serialized.data(), serialized.size()), 
errors::FailedPrecondition("Failed to parse the Instance.")); OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); ParseStringExampleBatchOp::GetParse()->Parse(ctx, *example_batch, &out_list); ParseStringExampleBatchOp::GetCounter()->EmitDataConsumeNumCounter( example_batch->batch_size()); } }; class ParseVariantExampleBatchOp : public OpKernel { public: explicit ParseVariantExampleBatchOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::vector names; std::vector shapes; std::vector dtypes; std::vector extra_names; OP_REQUIRES_OK(ctx, ctx->GetAttr("names", &names)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("extra_names", &extra_names)); parser_ = std::make_unique( names, shapes, dtypes, extra_names, DataType::DT_VARIANT); counter_ = std::make_unique("ParseVariantExampleBatchOp", false); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); const auto &variant = pb_input->flat()(0); const ExampleBatch *example_batch = variant.get(); OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); parser_->Parse(ctx, *example_batch, &out_list); counter_->EmitDataConsumeNumCounter(example_batch->batch_size()); } private: std::unique_ptr parser_; std::unique_ptr counter_; }; class ParseVariantExampleBatchV2Op : public ParseVariantExampleBatchOp { public: explicit ParseVariantExampleBatchV2Op(OpKernelConstruction *ctx) : ParseVariantExampleBatchOp(ctx) {} void Compute(OpKernelContext *ctx) override { ParseVariantExampleBatchOp::Compute(ctx); OP_REQUIRES_OK(ctx, ctx->set_output("sparse_features", ctx->input(0))); } }; class ParseVariantExampleBatchListOp : public OpKernel { public: explicit ParseVariantExampleBatchListOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::string label_config; std::vector names; std::vector 
shapes; std::vector dtypes; std::vector extra_names; OP_REQUIRES_OK(ctx, ctx->GetAttr("label_config", &label_config)); internal::ParseTaskConfig(label_config, &label_config_); OP_REQUIRES_OK(ctx, ctx->GetAttr("names", &names)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes)); OP_REQUIRES_OK(ctx, ctx->GetAttr("extra_names", &extra_names)); OP_REQUIRES_OK(ctx, ctx->GetAttr("positive_label", &positive_label_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("negative_label", &negative_label_)); parser_ = std::make_unique( names, shapes, dtypes, extra_names, DataType::DT_VARIANT); counter_ = std::make_unique("ParseVariantExampleBatchListOp", false); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); ExampleBatch example_batch; int batch_size = 0; for (auto iter = inputs.begin(); iter != inputs.end(); ++iter) { const ExampleBatch *sub_eb = iter->scalar()().get(); batch_size += sub_eb->batch_size(); example_batch.MergeFrom(*sub_eb); } OpOutputList out_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &out_list)); parser_->Parse(ctx, example_batch, label_config_, positive_label_, negative_label_, &out_list); counter_->EmitDataConsumeNumCounter(batch_size); } private: std::unique_ptr parser_; std::unique_ptr counter_; std::vector label_config_; float positive_label_ = 1.0f, negative_label_ = 0.0f; }; REGISTER_KERNEL_BUILDER( Name("ParseInstances").Device(DEVICE_CPU).TypeConstraint("T"), ParseStringInstancesOp); REGISTER_KERNEL_BUILDER( Name("ParseInstances").Device(DEVICE_CPU).TypeConstraint("T"), ParseVariantInstancesOp); REGISTER_KERNEL_BUILDER( Name("ParseInstancesV2").Device(DEVICE_CPU).TypeConstraint("T"), ParseStringInstancesV2Op); REGISTER_KERNEL_BUILDER( Name("ParseInstancesV2").Device(DEVICE_CPU).TypeConstraint("T"), ParseVariantInstancesV2Op); REGISTER_KERNEL_BUILDER( 
Name("ParseExamples").Device(DEVICE_CPU).TypeConstraint("T"), ParseStringExamplesOp); REGISTER_KERNEL_BUILDER( Name("ParseExamples").Device(DEVICE_CPU).TypeConstraint("T"), ParseVariantExamplesOp); REGISTER_KERNEL_BUILDER( Name("ParseExamplesV2").Device(DEVICE_CPU).TypeConstraint("T"), ParseStringExamplesV2Op); REGISTER_KERNEL_BUILDER( Name("ParseExamplesV2").Device(DEVICE_CPU).TypeConstraint("T"), ParseVariantExamplesV2Op); REGISTER_KERNEL_BUILDER( Name("ParseExampleBatch").Device(DEVICE_CPU).TypeConstraint("T"), ParseStringExampleBatchOp); REGISTER_KERNEL_BUILDER( Name("ParseExampleBatch").Device(DEVICE_CPU).TypeConstraint("T"), ParseVariantExampleBatchOp); REGISTER_KERNEL_BUILDER( Name("ParseExampleBatchV2").Device(DEVICE_CPU).TypeConstraint("T"), ParseStringExampleBatchV2Op); REGISTER_KERNEL_BUILDER( Name("ParseExampleBatchV2").Device(DEVICE_CPU).TypeConstraint("T"), ParseVariantExampleBatchV2Op); REGISTER_KERNEL_BUILDER(Name("ParseExampleBatchList").Device(DEVICE_CPU), ParseVariantExampleBatchListOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/parse_sparse_feature.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/parse_sparse_feature.h"

// NOTE(review): the targets of the three bare #include directives below and
// every template argument list in this file are missing from the text under
// review. They are kept exactly as found; restore them from the upstream file.
#include
#include
#include

#include "monolith/native_training/data/kernels/parse_sparse_feature.h"
#include "monolith/native_training/data/training_instance/cc/pb_variant.h"
#include "monolith/native_training/runtime/common/metrics.h"

namespace tensorflow {
namespace monolith_tf {

// Reads the op attributes, registers the configured feature names with the
// shared FeatureNameMapper and precomputes the per-table / per-feature
// indices:
//   * tables and features are each sorted by name and numbered in that order;
//   * every feature records its table's index, its position inside the table
//     and the table's feature count;
//   * output_pre_index is assigned per table for version 2, and as
//     table_index * ps_num_ otherwise;
//   * when the mapper is available, feature_index_conf_ is filled so that it
//     can be indexed by the mapper's sorted id.
ShardingSparseFidsOp::ShardingSparseFidsOp(OpKernelConstruction *ctx,
                                           int version /* = 1*/)
    : OpKernel(ctx), version_(version) {
  std::string feature_cfgs_str;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("ps_num", &ps_num_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_cfgs", &feature_cfgs_str));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("unique", &unique_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("parallel_flag", &parallel_flag_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("single_thread_feature_watermark",
                                   &single_thread_feature_watermark_));
  std::string input_type;
  LOG(INFO) << "[ShardingSparseFidsOp] version: " << version;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("input_type", &input_type));
  if (input_type == "example") {
    input_type_ = 0;
  } else if (input_type == "examplebatch") {
    input_type_ = 1;
  } else if (input_type == "instance") {
    input_type_ = 2;
  } else {
    // The message lists every value accepted above.
    OP_REQUIRES(ctx, false,
                errors::FailedPrecondition(
                    "input_type only supports example/examplebatch/instance."));
  }

  ::monolith::io::proto::FeatureConfigs feature_cfgs;
  OP_REQUIRES(
      ctx, feature_cfgs.ParseFromString(feature_cfgs_str),
      errors::FailedPrecondition("Failed to parse the FeatureConfigs."));

  enable_parallel_ =
      (parallel_flag_ > 0);  // parallel_flag_ == 0 default not parallel

  auto creator = [this](FeatureNameMapperTfBridge **out_mapper) {
    TF_RETURN_IF_ERROR(FeatureNameMapperTfBridge::New(out_mapper));
    return Status::OK();
  };
  ResourceMgr *resource_mgr = ctx->resource_manager();
  OP_REQUIRES_OK(ctx, resource_mgr->LookupOrCreate(
                          resource_mgr->default_container(),
                          FeatureNameMapperTfBridge::kName, &mapper_, creator));

  std::vector feature_names;
  feature_names.reserve(feature_cfgs.feature_configs_size());
  for (const auto &pair : feature_cfgs.feature_configs()) {
    feature_names.push_back(pair.first);
  }
  mapper_raw_ptr_ = mapper_->GetFeatureNameMapper();
  CHECK(mapper_raw_ptr_->RegisterValidNames(feature_names));
  feature_index_conf_.reserve(feature_cfgs.feature_configs_size() * 2);

  // Feature names of the form "<prefix><number>" also get a slot-id entry.
  static std::vector slot_id_feature_prefix({"fc_slot_", "slot_"});
  for (auto &iter : feature_cfgs.feature_configs()) {
    auto &feature_cfg = feature_conf_[iter.first];
    feature_cfg.table_name = iter.second.table();
    feature_cfg.feature_name = iter.first;
    feature_cfg.version = version_;
    int dims_sum = 0;
    for (size_t slice_idx = 0; slice_idx < iter.second.slice_dims_size();
         slice_idx++) {
      dims_sum += iter.second.slice_dims(slice_idx);
    }
    feature_cfg.dims_sum = dims_sum;
    auto &table_cfg = table_conf_[iter.second.table()];
    table_cfg.table_name = iter.second.table();
    for (auto &feature_prfix : slot_id_feature_prefix) {
      if (absl::StartsWith(iter.first, feature_prfix)) {
        std::string sub_str(iter.first.substr(feature_prfix.size()));
        try {
          int slot_id = std::stoi(sub_str);
          slot_id_to_feature_name_[slot_id] = iter.first;
        } catch (std::exception const &ex) {
          // A non-numeric suffix is logged and otherwise ignored.
          LOG(ERROR) << "slot_id_to_feature_name_ err:" << ex.what() << ":"
                     << iter.first << "," << sub_str;
          continue;
        }
      }
    }
  }

  // Number the tables in name order.
  for (auto &iter : table_conf_) {
    table_cfg_list_.push_back(&iter.second);
  }
  std::sort(
      table_cfg_list_.begin(), table_cfg_list_.end(),
      [](TableInfo *a, TableInfo *b) { return a->table_name < b->table_name; });
  for (uint i = 0; i < table_cfg_list_.size(); ++i) {
    auto &conf = *(table_cfg_list_[i]);
    conf.table_index = i;
  }

  // Number the features in name order and attach them to their table.
  for (auto &iter : feature_conf_) {
    feature_cfg_list_.push_back(&iter.second);
  }
  std::sort(feature_cfg_list_.begin(), feature_cfg_list_.end(),
            [](FeatureInfo *a, FeatureInfo *b) {
              return a->feature_name < b->feature_name;
            });
  for (uint i = 0; i < feature_cfg_list_.size(); ++i) {
    auto &conf = *(feature_cfg_list_[i]);
    conf.feature_index = i;
    auto &table_cfg = table_conf_[conf.table_name];
    conf.table_index = table_cfg.table_index;
    conf.feature_in_table_index = table_cfg.feature_count++;
    table_cfg.feature_index_list.push_back(i);
  }
  // Second pass: the final per-table feature count is only known now.
  for (uint i = 0; i < feature_cfg_list_.size(); ++i) {
    auto &conf = *(feature_cfg_list_[i]);
    auto &table_cfg = table_conf_[conf.table_name];
    conf.table_feature_count = table_cfg.feature_count;
  }

  if (version_ == 2) {
    int output_index = 0;
    for (uint i = 0; i < table_cfg_list_.size(); ++i) {
      auto &table_cfg = *(table_cfg_list_[i]);
      for (auto feature_index : table_cfg.feature_index_list) {
        auto &feature_cfg = *(feature_cfg_list_[feature_index]);
        feature_cfg.output_pre_index = output_index;
      }
      output_index +=
          std::max(table_cfg.feature_index_list.size(), 1UL) * ps_num_;
    }
  } else {
    for (auto feature_cfg_ptr : feature_cfg_list_) {
      feature_cfg_ptr->output_pre_index =
          feature_cfg_ptr->table_index * ps_num_;
    }
  }

  if (mapper_raw_ptr_->IsAvailable()) {
    int32_t max_sorted_id = -1;
    LOG_FIRST_N(INFO, 1) << mapper_raw_ptr_->DebugString();
    absl::flat_hash_map feature_index_conf_tmp;
    feature_index_conf_tmp.reserve(feature_conf_.size() * 2);
    for (auto &iter : feature_conf_) {
      int32_t id = -1;
      int32_t sorted_id = -1;
      bool found = mapper_raw_ptr_->GetIdByName(iter.first, &id, &sorted_id);
      if (found && !feature_index_conf_tmp.contains(sorted_id)) {
        feature_index_conf_tmp[sorted_id] = &iter.second;
        max_sorted_id = std::max(max_sorted_id, sorted_id);
      } else {
        // A missing or duplicated id discards the whole table; the vector
        // below is then left unfilled.
        feature_index_conf_tmp.clear();
        LOG(ERROR) << "mapper_raw_ptr_ not find:" << iter.first;
        break;
      }
    }
    if (feature_index_conf_tmp.size() > 0) {
      feature_index_conf_.resize(max_sorted_id + 1, nullptr);
      for (auto &iter : feature_index_conf_tmp) {
        feature_index_conf_[iter.first] = iter.second;
      }
    }
  } else {
    LOG(WARNING) << "mapper_raw_ptr_ not Available()";
  }
}

// Dispatches to the de-duplicating overload when unique_ is set and to the
// appending overload otherwise.
void ShardingSparseFidsOp::FillFidList(
    uint64_t value, std::vector> &shard_vec,
    MultiShardUniqHashTable &shard_uniq_hashtable,
    tensorflow::TTypes::Flat fid_offset_flat, int feature_output_index,
    int *offset) {
  if (unique_) {
    FillFidList(value, shard_uniq_hashtable, fid_offset_flat,
                feature_output_index, offset);
  } else {
    FillFidList(value, shard_vec, fid_offset_flat, feature_output_index,
                offset);
  }
}

// Appends `value` to shard `value % ps_num_` and writes a packed entry at
// fid_offset_flat(*offset): the feature's output index for that shard in the
// upper 32 bits, the position inside the shard (scaled by dims_sum for
// versions 3/4/5) in the lower bits. Advances *offset.
void ShardingSparseFidsOp::FillFidList(
    uint64_t value, std::vector> &shard_vec,
    tensorflow::TTypes::Flat fid_offset_flat, int feature_output_index,
    int *offset) {
  auto mod = value % ps_num_;
  int output_offset =
      feature_cfg_list_[feature_output_index]->GetFidOutputIndex(mod);
  int feature_offset = shard_vec[mod].size();
  if (version_ == 3 || version_ == 4 || version_ == 5) {
    feature_offset *= feature_cfg_list_[feature_output_index]->dims_sum;
  }
  fid_offset_flat(*offset) = ((uint64_t(output_offset) << 32) | feature_offset);
  shard_vec[mod].push_back(value);
  ++(*offset);
}

// Same packed entry as the appending overload, but the position inside the
// shard comes from shard_uniq_hashtable.uniq_fid(value, mod).
void ShardingSparseFidsOp::FillFidList(
    uint64_t value, MultiShardUniqHashTable &shard_uniq_hashtable,
    tensorflow::TTypes::Flat fid_offset_flat, int feature_output_index,
    int *offset) {
  auto mod = value % ps_num_;
  int output_offset =
      feature_cfg_list_[feature_output_index]->GetFidOutputIndex(mod);
  auto feature_offset = shard_uniq_hashtable.uniq_fid(value, mod);
  // auto fid_find_iter = shard_vec[mod].find(value);
  // int feature_offset1 = -1;
  // if (fid_find_iter == shard_vec[mod].end()) {
  //   feature_offset1 = shard_vec[mod].size();
  //   shard_vec[mod].emplace(value, feature_offset1);
  // } else {
  //   feature_offset1 = fid_find_iter->second;
  // }
  // CHECK_EQ(feature_offset1, feature_offset);
  if (version_ == 3 || version_ == 4 || version_ == 5) {
    feature_offset *= feature_cfg_list_[feature_output_index]->dims_sum;
  }
  fid_offset_flat(*offset) = (uint64_t(output_offset) << 32) | feature_offset;
  ++(*offset);
}

// Copies the whole shard into the tensor slice, starting `offset` elements
// (of int64_size_ bytes each) into it.
void ShardingSparseFidsOp::CopyFidList(
    const std::vector &shard_ptr, int offset,
    TensorSliceAccessor *cur_tensor) {
  void *data_ptr = cur_tensor->ptr;
  std::memcpy(reinterpret_cast(data_ptr) + int64_size_ * offset,
              shard_ptr.data(), shard_ptr.size() * int64_size_);
}

// Writes every key of the map at position `mapped value + offset` of the
// tensor slice.
void ShardingSparseFidsOp::CopyFidList(
    const absl::flat_hash_map &shard_ptr, int offset,
    TensorSliceAccessor *cur_tensor) {
  // auto cur_tensor_flat = cur_tensor->template flat();
  for (auto &fid : shard_ptr) {
    (*cur_tensor)(fid.second + offset) = fid.first;
  }
}

// NOTE(review): this function continues beyond the reviewed section; the part
// below is reproduced unchanged.
Status ShardingSparseFidsOp::CreateOffsetTensor(
    OpKernelContext *ctx, const std::vector> &all_feature_counter,
    int all_feature_counter_size, Tensor **nfl_offset_tensor,
    Tensor **feature_offset_tensor, Tensor **fid_offset_tensor,
    OpOutputList *fid_list_row_splits_out_list,
    OpOutputList *fid_list_row_splits_size_out_list,
    std::vector &tmp_tensor_list,
    std::vector> &fid_list_row_splits_flat_list,
    std::vector *nfl_fid_offset,
    const std::unordered_set *shared_feature) {
  // feature_size
  TF_RETURN_IF_ERROR(
      ctx, ctx->allocate_output("nfl_offset",
                                TensorShape({
                                    feature_conf_.size() + 1,
                                }),
                                nfl_offset_tensor));
  auto nfl_offset_flat = (*nfl_offset_tensor)->flat();
  // At most feature_size * batch_size, but a shared feature is counted only
  // once.
  TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("feature_offset",
                                               TensorShape({
                                                   all_feature_counter_size + 1,
                                               }),
                                               feature_offset_tensor));
  auto feature_offset_flat = (*feature_offset_tensor)->flat();

  if (version_ == 5) {
    Tensor *nfl_size_tensor;
    TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("nfl_size",
                                                 TensorShape({
                                                     1,
                                                 }),
                                                 &nfl_size_tensor));
    nfl_size_tensor->flat()(0) = feature_conf_.size() + 1;
    Tensor *feature_size_tensor;
    TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("feature_size",
                                                 TensorShape({
                                                     1,
                                                 }),
                                                 &feature_size_tensor));
    feature_size_tensor->flat()(0) = all_feature_counter_size + 1;
  }

  int all_fid_size = 0;
  int feature_offset_index = 0;
  for (uint i = 0; i < all_feature_counter.size(); ++i) {
    auto &feature_counter = all_feature_counter[i];
    nfl_offset_flat(i) = feature_offset_index;
    if (shared_feature->size() > 0 && shared_feature->count(i) > 0) {
      nfl_offset_flat(i) |= SHARED_FLAG;
    }
    (*nfl_fid_offset)[i] = all_fid_size;
    for (uint j = 0; j < feature_counter.size(); ++j) {
      feature_offset_flat(feature_offset_index) = all_fid_size;
      all_fid_size += abs(feature_counter[j]);
      ++feature_offset_index;
    }
  }
  CHECK_EQ(feature_offset_index, all_feature_counter_size);
  feature_offset_flat(feature_offset_index) = all_fid_size;
  nfl_offset_flat(all_feature_counter.size()) = feature_offset_index + 1;

  // fid_size
  TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("fid_offset",
                                               TensorShape({
                                                   all_fid_size,
                                               }),
                                               fid_offset_tensor));
  if (version_ == 5) {
    Tensor *fid_size_tensor;
    TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("fid_size",
                                                 TensorShape({
                                                     1,
                                                 }),
                                                 &fid_size_tensor));
    fid_size_tensor->flat()(0) = all_fid_size;
  }

  if (version_ == 4) {
    Tensor *fid_list_row_lengths_tensor;
    int all_feature_count = 0;
    for (uint i = 0; i < table_cfg_list_.size(); ++i) {
      all_feature_count += table_cfg_list_[i]->feature_count + 1;
    }
    all_feature_count *= ps_num_;
    TF_RETURN_IF_ERROR(
        ctx, ctx->allocate_output("fid_list_row_splits",
                                  TensorShape({
                                      all_feature_count,
                                  }),
                                  &fid_list_row_lengths_tensor));
    auto cur_tensor_flat = fid_list_row_lengths_tensor->flat();
    cur_tensor_flat.setZero();
    fid_list_row_splits_flat_list.resize(table_cfg_list_.size() * ps_num_);
    int pre_feature_count = 0;
    for (uint ps_num_i = 0; ps_num_i < ps_num_; ++ps_num_i) {
      for (uint table_index = 0; table_index < table_cfg_list_.size();
           ++table_index) {
        /*
        LOG(ERROR) << "xxxx all_feature_count:" << all_feature_count
                   << ",pre_feature_count:" << pre_feature_count
                   << ",feature_count:"
                   << table_cfg_list_[table_index]->feature_count;
        auto cur_tensor = fid_list_row_lengths_tensor->Slice(
            pre_feature_count,
            pre_feature_count + table_cfg_list_[table_index]->feature_count);
            */
        int cur_count = table_cfg_list_[table_index]->feature_count + 1;
        fid_list_row_splits_flat_list[table_index * ps_num_ + ps_num_i] =
            TensorSliceAccessor(
                {static_cast(fid_list_row_lengths_tensor->data()) +
                     pre_feature_count,
                 cur_count});
        pre_feature_count += cur_count;
      }
    }
  } else {
    tmp_tensor_list.resize(table_cfg_list_.size() * ps_num_);
    int fid_list_row_splits_flat_list_index = -1;
    for (uint i = 0; i < table_cfg_list_.size(); ++i) {
      for (uint j = 0; j < ps_num_; ++j) {
        Tensor
*cur_tensor; ++fid_list_row_splits_flat_list_index; if (version_ == 2 || version_ == 3) { TF_RETURN_IF_ERROR(ctx, fid_list_row_splits_out_list->allocate( fid_list_row_splits_flat_list_index, tensorflow::TensorShape{ table_cfg_list_[i]->feature_count + 1}, &cur_tensor)); } else if (version_ == 5) { TF_RETURN_IF_ERROR(ctx, fid_list_row_splits_out_list->allocate( fid_list_row_splits_flat_list_index, tensorflow::TensorShape{ table_cfg_list_[i]->feature_count + 1}, &cur_tensor)); Tensor *size_tensor; TF_RETURN_IF_ERROR(ctx, fid_list_row_splits_size_out_list->allocate( fid_list_row_splits_flat_list_index, tensorflow::TensorShape{ 1, }, &size_tensor)); size_tensor->flat()(0) = table_cfg_list_[i]->feature_count + 1; } else { cur_tensor = &(tmp_tensor_list[i * ps_num_ + j]); TF_RETURN_IF_ERROR( ctx, ctx->allocate_temp(DT_INT64, tensorflow::TensorShape{ table_cfg_list_[i]->feature_count + 1}, cur_tensor)); } auto cur_tensor_flat = cur_tensor->flat(); cur_tensor_flat.setZero(); fid_list_row_splits_flat_list.emplace_back(TensorSliceAccessor( {static_cast(cur_tensor->data()), cur_tensor->NumElements()})); } } } return Status::OK(); } void ShardingSparseFidsOp::ParallelRun( OpKernelContext *ctx, int task_count, const std::function &fn) { if (enable_parallel_ && task_count > 1) { auto workers = ctx->device()->tensorflow_cpu_worker_threads()->workers; workers->ParallelFor( task_count, thread::ThreadPool::SchedulingParams( thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, absl::nullopt, 1), fn); } else { for (int i = 0; i < task_count; ++i) { fn(i, i + 1); } } } void ShardingSparseFidsOp::InitInstanceWrapper( ShardingSparseFidsOp::InstanceWrapper *instance_wrapper) { auto &instances = instance_wrapper->instances; if (slot_id_to_feature_name_.size() == 0) { return; } instance_wrapper->fid_v1.reserve(slot_id_to_feature_name_.size()); for (auto &slot_name_iter : slot_id_to_feature_name_) { auto &part = instance_wrapper->fid_v1[slot_name_iter.second]; 
part.resize(instances.size()); for (int i = 0; i < part.size(); ++i) { part[i].reserve(instances[i]->fid_size()); } } for (int i = 0; i < instances.size(); ++i) { for (auto fid : instances[i]->fid()) { int slot_id = slot_id_v1(fid); auto find_iter = slot_id_to_feature_name_.find(slot_id); if (find_iter == slot_id_to_feature_name_.end()) { continue; } instance_wrapper->fid_v1[find_iter->second][i].push_back(fid); } } } void ShardingSparseFidsOp::Compute(OpKernelContext *ctx) { auto shard_sparse_op_latency_start = std::chrono::system_clock::now(); const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); OpOutputList out_list; if (version_ != 4) { OP_REQUIRES_OK(ctx, ctx->output_list("fid_list", &out_list)); } OpOutputList fid_list_row_splits_out_list; OpOutputList fid_list_row_splits_size_out_list; if (version_ == 2 || version_ == 3) { OP_REQUIRES_OK( ctx, ctx->output_list("fid_list_row_splits", &fid_list_row_splits_out_list)); } if (version_ == 5) { OP_REQUIRES_OK(ctx, ctx->output_list("fid_list_row_splits", &fid_list_row_splits_out_list)); OP_REQUIRES_OK(ctx, ctx->output_list("fid_list_row_splits_size", &fid_list_row_splits_size_out_list)); } int batch_size = 0; Status st; if (input_type_ == 0) { const auto &pb_variant_tensor = pb_input->vec(); batch_size = pb_variant_tensor.dimension(0); std::vector examples; examples.reserve(batch_size); for (int i = 0; i < batch_size; ++i) { const auto *example = pb_variant_tensor(i).get(); CHECK_NOTNULL(example); examples.push_back(example); } st = FeatureParallelParse(ctx, examples, &out_list, &fid_list_row_splits_out_list, &fid_list_row_splits_size_out_list); } else if (input_type_ == 1) { const auto &example_batch = *(pb_input->scalar()().get()); batch_size = example_batch.batch_size(); st = FeatureParallelParse(ctx, example_batch, &out_list, &fid_list_row_splits_out_list, &fid_list_row_splits_size_out_list); } else if (input_type_ == 2) { const auto &pb_variant_tensor = pb_input->vec(); batch_size = 
pb_variant_tensor.dimension(0); InstanceWrapper instance_wapper; instance_wapper.instances.reserve(batch_size); for (int i = 0; i < batch_size; ++i) { const auto *instance = pb_variant_tensor(i).get(); CHECK_NOTNULL(instance); instance_wapper.instances.push_back(instance); } InitInstanceWrapper(&instance_wapper); st = FeatureParallelParse(ctx, instance_wapper, &out_list, &fid_list_row_splits_out_list, &fid_list_row_splits_size_out_list); } OP_REQUIRES_OK(ctx, st); Tensor *batch_size_tensor; if (version_ == 5) { OP_REQUIRES_OK(ctx, ctx->allocate_output("batch_size", TensorShape({1, }), &batch_size_tensor)); batch_size_tensor->flat()(0) = batch_size; } else { OP_REQUIRES_OK(ctx, ctx->allocate_output("batch_size", TensorShape({}), &batch_size_tensor)); batch_size_tensor->scalar()() = batch_size; } auto shard_sparse_op_latency_end = std::chrono::system_clock::now(); std::chrono::duration shard_sparse_op_latency_diff = std::chrono::duration_cast( shard_sparse_op_latency_end - shard_sparse_op_latency_start); monolith::GetMetrics()->emit_timer("sharding_sparse_fids_op_latency", shard_sparse_op_latency_diff.count()); } REGISTER_KERNEL_BUILDER(Name("ShardingSparseFids").Device(DEVICE_CPU), ShardingSparseFidsOp); class ShardingSparseFidsOpV2 : public ShardingSparseFidsOp { public: explicit ShardingSparseFidsOpV2(OpKernelConstruction *ctx) : ShardingSparseFidsOp(ctx, 2) {} }; REGISTER_KERNEL_BUILDER(Name("ShardingSparseFidsV2").Device(DEVICE_CPU), ShardingSparseFidsOpV2); class ShardingSparseFidsOpV3 : public ShardingSparseFidsOp { public: explicit ShardingSparseFidsOpV3(OpKernelConstruction *ctx) : ShardingSparseFidsOp(ctx, 3) {} }; REGISTER_KERNEL_BUILDER(Name("ShardingSparseFidsV3").Device(DEVICE_CPU), ShardingSparseFidsOpV3); class ShardingSparseFidsOpV4 : public ShardingSparseFidsOp { public: explicit ShardingSparseFidsOpV4(OpKernelConstruction *ctx) : ShardingSparseFidsOp(ctx, 4) {} }; REGISTER_KERNEL_BUILDER(Name("ShardingSparseFidsV4").Device(DEVICE_CPU), 
ShardingSparseFidsOpV4);

// Same kernel constructed with version = 5.
class ShardingSparseFidsOpV5 : public ShardingSparseFidsOp {
 public:
  explicit ShardingSparseFidsOpV5(OpKernelConstruction *ctx)
      : ShardingSparseFidsOp(ctx, 5) {}
};

REGISTER_KERNEL_BUILDER(Name("ShardingSparseFidsV5").Device(DEVICE_CPU),
                        ShardingSparseFidsOpV5);

// This op family has 6 outputs (the first registration below does not
// declare fid_list_row_splits; it appears in the V2 registration):
// fid_list: shape(table_count*ps_num, some number of fids),
//   all fids from the samples of the different features are aggregated,
//   sharded by ps_num, and filled in grouped by the feature->table mapping
// fid_list_row_splits: shape(table_count*ps_num, features in table + 1),
//   forms a ragged tensor together with fid_list; its main purpose is to
//   separate the different features that live in the same table
// fid_offset: shape(feature count * (1 if shared feature else batch_size)
//                   * fid count),
//   a 1-D flattening of the samples; it stores not the fid itself but its
//   offset inside fid_list, used to address into fid_list
//     high 32 bits: combination of the first dimension of fid_list and the
//                   feature's index inside its table
//     low 32 bits:  position of the fid inside the current feature's fid_list
// feature_offset: shape(feature count * (1 if shared feature else
//                       batch_size)),
//   splits the flattened fid_offset; marks the boundary of every sample
// nfl_offset: shape(feature count),
//   splits the flattened fid_offset/feature_offset; marks the boundary of
//   every feature
// batch_size
REGISTER_OP("ShardingSparseFids")
    .Input("pb_input: variant")
    .Output("fid_list: N * int64")
    .Output("fid_offset: uint64")
    .Output("feature_offset: int32")
    .Output("nfl_offset: uint32")
    .Output("batch_size: int32")
    .Attr("ps_num: int")
    .Attr("feature_cfgs: string")
    .Attr("N: int")
    .Attr("unique: bool")
    .Attr("input_type: string")
    .Attr("parallel_flag: int")
    .Attr("single_thread_feature_watermark: int = 320000")
    .SetDoNotOptimize()  // Source dataset ops must disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext *ctx) { // fid_list int N = 0; TF_RETURN_IF_ERROR(ctx->GetAttr("N", &N)); for (int i = 0; i < N; ++i) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); // fid_list } ctx->set_output(N, ctx->Vector(ctx->UnknownDim())); // fid_offset ctx->set_output(N + 1, ctx->Vector(ctx->UnknownDim())); // feature_offset ctx->set_output(N + 2, ctx->Vector(ctx->UnknownDim())); // nfl_offset ctx->set_output(N + 3, ctx->Scalar()); // batch_size return Status::OK(); }); REGISTER_OP("ShardingSparseFidsV2") .Input("pb_input: variant") .Output("fid_list: N * int64") .Output("fid_list_row_splits: N * int64") .Output("fid_offset: uint64") .Output("feature_offset: int32") .Output("nfl_offset: uint32") .Output("batch_size: int32") .Attr("ps_num: int") .Attr("feature_cfgs: string") .Attr("N: int") .Attr("unique: bool") .Attr("input_type: string") .Attr("parallel_flag: int") .Attr("single_thread_feature_watermark: int = 320000") .SetDoNotOptimize() // Source dataset ops must disable constant folding. 
.SetShapeFn([](shape_inference::InferenceContext *ctx) { // fid_list int N = 0; TF_RETURN_IF_ERROR(ctx->GetAttr("N", &N)); for (int i = 0; i < N; ++i) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); // fid_list ctx->set_output(N + i, ctx->Vector(ctx->UnknownDim())); // fid_list_row_splits } N *= 2; ctx->set_output(N, ctx->Vector(ctx->UnknownDim())); // fid_offset ctx->set_output(N + 1, ctx->Vector(ctx->UnknownDim())); // feature_offset ctx->set_output(N + 2, ctx->Vector(ctx->UnknownDim())); // nfl_offset ctx->set_output(N + 3, ctx->Scalar()); // batch_size return Status::OK(); }); // 与v2的区别在于fid_offset包含偏移都乘以了feature的dim REGISTER_OP("ShardingSparseFidsV3") .Input("pb_input: variant") .Output("fid_list: N * int64") .Output("fid_list_row_splits: N * int64") .Output("fid_offset: uint64") .Output("feature_offset: int32") .Output("nfl_offset: uint32") .Output("batch_size: int32") .Attr("ps_num: int") .Attr("feature_cfgs: string") .Attr("N: int") .Attr("unique: bool") .Attr("input_type: string") .Attr("parallel_flag: int") .Attr("single_thread_feature_watermark: int = 320000") .SetDoNotOptimize() // Source dataset ops must disable constant folding. 
.SetShapeFn([](shape_inference::InferenceContext *ctx) { // fid_list int N = 0; TF_RETURN_IF_ERROR(ctx->GetAttr("N", &N)); for (int i = 0; i < N; ++i) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); // fid_list ctx->set_output(N + i, ctx->Vector(ctx->UnknownDim())); // fid_list_row_splits } N *= 2; ctx->set_output(N, ctx->Vector(ctx->UnknownDim())); // fid_offset ctx->set_output(N + 1, ctx->Vector(ctx->UnknownDim())); // feature_offset ctx->set_output(N + 2, ctx->Vector(ctx->UnknownDim())); // nfl_offset ctx->set_output(N + 3, ctx->Scalar()); // batch_size int ps_num = 0; TF_RETURN_IF_ERROR(ctx->GetAttr("ps_num", &ps_num)); std::string feature_cfgs_str; TF_RETURN_IF_ERROR(ctx->GetAttr("feature_cfgs", &feature_cfgs_str)); ::monolith::io::proto::FeatureConfigs feature_cfgs; CHECK(feature_cfgs.ParseFromString(feature_cfgs_str)); std::unordered_set table_name; for (auto &iter : feature_cfgs.feature_configs()) { table_name.insert(iter.second.table()); } return Status::OK(); }); // 为适配gpu emb,将多个tenor合并优化性能,并设配输入输出结构 // fid_list_row_splits [ps_shard, table, feature] // fid_list_table_row_length [ps_shard, table] // [ps_shard] // fid_list_emb_row_lenth [ps_shard, table] 对应emb在ps_shard的切分 REGISTER_OP("ShardingSparseFidsV4") .Input("pb_input: variant") .Output("fid_list: int64") .Output("fid_list_row_splits: int64") .Output("fid_list_table_row_length: int32") .Output("fid_list_shard_row_lenth: int32") .Output("fid_list_emb_row_lenth: int32") .Output("fid_offset: uint64") .Output("feature_offset: int32") .Output("nfl_offset: uint32") .Output("batch_size: int32") .Attr("ps_num: int") .Attr("feature_cfgs: string") .Attr("unique: bool") .Attr("input_type: string") .Attr("parallel_flag: int") .Attr("single_thread_feature_watermark: int = 320000") .SetDoNotOptimize() // Source dataset ops must disable constant folding. 
.SetShapeFn([](shape_inference::InferenceContext *ctx) { int ps_num = 0; TF_RETURN_IF_ERROR(ctx->GetAttr("ps_num", &ps_num)); std::string feature_cfgs_str; TF_RETURN_IF_ERROR(ctx->GetAttr("feature_cfgs", &feature_cfgs_str)); ::monolith::io::proto::FeatureConfigs feature_cfgs; CHECK(feature_cfgs.ParseFromString(feature_cfgs_str)); std::unordered_set table_name; for (auto &iter : feature_cfgs.feature_configs()) { table_name.insert(iter.second.table()); } // fid_list int N = 0; ctx->set_output(0, ctx->Vector(ctx->UnknownDim())); // fid_list ctx->set_output(1, ctx->Vector(ctx->UnknownDim())); // fid_list_row_splits ctx->set_output( 2, ctx->Vector(ps_num * table_name.size())); // fid_list_table_row_length ctx->set_output(3, ctx->Vector(ps_num)); // fid_list_shard_row_lenth ctx->set_output( 4, ctx->Vector(ps_num * table_name.size())); // fid_list_emb_row_lenth N = 5; ctx->set_output(N, ctx->Vector(ctx->UnknownDim())); // fid_offset ctx->set_output(N + 1, ctx->Vector(ctx->UnknownDim())); // feature_offset ctx->set_output(N + 2, ctx->Vector(ctx->UnknownDim())); // nfl_offset ctx->set_output(N + 3, ctx->Scalar()); // batch_size return Status::OK(); }); REGISTER_OP("ShardingSparseFidsV5") .Input("pb_input: variant") .Output("fid_list: N * int64") .Output("fid_list_row_splits: N * int64") .Output("fid_list_row_splits_size: N * int32") .Output("fid_offset: uint64") .Output("feature_offset: int32") .Output("nfl_offset: uint32") .Output("batch_size: int32") .Output("nfl_size: int32") .Output("feature_size: int32") .Output("fid_size: int32") .Output("emb_size: int32") .Attr("ps_num: int") .Attr("feature_cfgs: string") .Attr("N: int") .Attr("unique: bool") .Attr("input_type: string") .Attr("parallel_flag: int") .Attr("single_thread_feature_watermark: int = 320000") .SetDoNotOptimize() // Source dataset ops must disable constant folding. 
.SetShapeFn([](shape_inference::InferenceContext *ctx) { // fid_list int N = 0; TF_RETURN_IF_ERROR(ctx->GetAttr("N", &N)); for (int i = 0; i < N; ++i) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); // fid_list ctx->set_output(N + i, ctx->Vector(ctx->UnknownDim())); // fid_list_row_splits ctx->set_output( N * 2 + i, ctx->Vector(ctx->UnknownDim())); // fid_list_row_splits_size } N *= 3; ctx->set_output(N, ctx->Vector(ctx->UnknownDim())); // fid_offset ctx->set_output(N + 1, ctx->Vector(ctx->UnknownDim())); // feature_offset ctx->set_output(N + 2, ctx->Vector(ctx->UnknownDim())); // nfl_offset ctx->set_output(N + 3, ctx->Vector(ctx->UnknownDim())); // batch_size ctx->set_output(N + 4, ctx->Vector(ctx->UnknownDim())); // nfl_size ctx->set_output(N + 5, ctx->Vector(ctx->UnknownDim())); // feature_size ctx->set_output(N + 6, ctx->Vector(ctx->UnknownDim())); // fid_size ctx->set_output(N + 7, ctx->Vector(ctx->UnknownDim())); // emb_size // ctx->set_output(N + 3, ctx->Scalar()); // batch_size int ps_num = 0; TF_RETURN_IF_ERROR(ctx->GetAttr("ps_num", &ps_num)); std::string feature_cfgs_str; TF_RETURN_IF_ERROR(ctx->GetAttr("feature_cfgs", &feature_cfgs_str)); ::monolith::io::proto::FeatureConfigs feature_cfgs; CHECK(feature_cfgs.ParseFromString(feature_cfgs_str)); std::unordered_set table_name; for (auto &iter : feature_cfgs.feature_configs()) { table_name.insert(iter.second.table()); } return Status::OK(); }); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/parse_sparse_feature.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_DATA_KERNELS_PARSE_SPARSE_FEATURE_LIB_H_ #define MONOLITH_NATIVE_TRAINING_DATA_KERNELS_PARSE_SPARSE_FEATURE_LIB_H_ #include #include #include #include "google/protobuf/descriptor.h" #include "idl/matrix/proto/example.pb.h" #include "idl/matrix/proto/proto_parser.pb.h" #include "monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/kernels/internal/uniq_hashtable.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "monolith/native_training/runtime/common/metrics.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/env.h" #include "absl/strings/match.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { namespace monolith_tf { class ShardingSparseFidsOp : public OpKernel { public: using FeatureListType = ::monolith::io::proto::FeatureListType; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using FeatureConfigs = ::monolith::io::proto::FeatureConfigs; explicit ShardingSparseFidsOp(OpKernelConstruction *ctx, int version = 1); ~ShardingSparseFidsOp() override { mapper_->Unref(); } void Compute(OpKernelContext *ctx) override; private: struct InstanceWrapper { struct FeaturePtr { FeaturePtr(const std::vector *fid_v1_, const idl::matrix::proto::Feature *fid_v2_) : 
fid_v1(fid_v1_), fid_v2(fid_v2_) {} const std::vector *fid_v1 = nullptr; const idl::matrix::proto::Feature *fid_v2 = nullptr; }; std::vector instances; absl::flat_hash_map>> fid_v1; }; void InitInstanceWrapper(InstanceWrapper *instance_wrapper); template Status FeatureParallelParse(OpKernelContext *ctx, const TInput &input, OpOutputList *fid_list_out_list, OpOutputList *fid_list_row_splits_out_list, OpOutputList *fid_list_row_splits_size_out_list); int GetBatchSize(const InstanceWrapper &instance_wrapper) { return instance_wrapper.instances.size(); } int GetBatchSize(const ::monolith::io::proto::ExampleBatch &example_batch) { return example_batch.batch_size(); } template int GetBatchSize(const std::vector &inputs) { return inputs.size(); } void ParallelRun(OpKernelContext *ctx, int task_count, const std::function &fn); template struct TensorSliceAccessor { TData *ptr = nullptr; int64_t size = 0; TData &operator()(int64_t index) { // CHECK(index >= 0 && index < size); return *(ptr + index); } }; void FillFidList(uint64_t value, std::vector> &shard_vec, MultiShardUniqHashTable &shard_uniq_hashtable, tensorflow::TTypes::Flat fid_offset_flat, int feature_output_index, int *offset); void FillFidList(uint64_t value, std::vector> &shard_vec, tensorflow::TTypes::Flat fid_offset_flat, int feature_output_index, int *offset); void FillFidList(uint64_t value, MultiShardUniqHashTable &shard_uniq_hashtable, tensorflow::TTypes::Flat fid_offset_flat, int feature_output_index, int *offset); void CopyFidList(const std::vector &shard_ptr, int offset, TensorSliceAccessor *to); void CopyFidList(const absl::flat_hash_map &shard_ptr, int offset, TensorSliceAccessor *to); Status CreateOffsetTensor( OpKernelContext *ctx, const std::vector> &all_feature_counter, int all_feature_counter_size, Tensor **nfl_offset_tensor, Tensor **feature_offset_tensor, Tensor **fid_offset_tensor, OpOutputList *fid_list_row_splits_out_list, OpOutputList *fid_list_row_splits_size_out_list, std::vector 
&tmp_tensor_list, std::vector> &fid_list_row_splits_flat_list, std::vector *nfl_fid_offset, const std::unordered_set *shared_feature); struct FeatureInfo { std::string feature_name; std::string table_name; int feature_index = -1; int table_index = -1; int feature_in_table_index = -1; int table_feature_count; int output_pre_index; int dims_sum; int version = 1; int GetFidOutputIndex(int ps_i) { if (version == 2) { return output_pre_index + ps_i * table_feature_count + feature_in_table_index; } else { return output_pre_index + ps_i; } } int GetPsShard(int fid_offset) { if (version == 2) { return (fid_offset - output_pre_index - feature_in_table_index) / table_feature_count; // no use, only version==1 or version==3 // will call this func } else { return fid_offset - output_pre_index; } } }; absl::flat_hash_map feature_conf_; std::vector feature_index_conf_; std::vector feature_cfg_list_; struct TableInfo { std::string table_name; int table_index = -1; int feature_count = 0; std::vector feature_index_list; }; absl::flat_hash_map table_conf_; std::vector table_cfg_list_; absl::flat_hash_map slot_id_to_feature_name_; // instance fid_v1 slot 特征 映射 int ps_num_ = 0; int single_thread_feature_watermark_ = 80000 * 4; int single_thread_assign_watermark_ = 100000 * 4; int int64_size_ = sizeof(int64_t); static constexpr uint32_t SHARED_FLAG = (1L << 31); bool unique_ = false; int parallel_flag_ = 0; int input_type_ = 0; bool enable_parallel_ = true; int version_ = 1; FeatureNameMapperTfBridge *mapper_ = nullptr; FeatureNameMapper *mapper_raw_ptr_ = nullptr; template void SplitTask(const std::vector &context_list, int limit, std::vector> *out); #define DFeatureParallelMakeUpTask1Context(INPUT_TYPE) \ template \ void FeatureParallelMakeUpTask1Context( \ const INPUT_TYPE &input, int batch_size, \ std::vector *task_context_list, \ absl::flat_hash_map \ *feature_shard_count_map, \ absl::flat_hash_map> \ *table_feature_map, \ std::vector> *all_feature_counter, \ std::unordered_set 
*shared_feature, int *all_feature_counter_size) DFeatureParallelMakeUpTask1Context(::monolith::io::proto::ExampleBatch); DFeatureParallelMakeUpTask1Context( std::vector); DFeatureParallelMakeUpTask1Context(InstanceWrapper); #undef DFeatureParallelMakeUpTask1Context #define DFeatureParallelDoTask1(INPUT_TYPE) \ template \ void FeatureParallelDoTask1( \ const INPUT_TYPE &input, TFeatureParallelTask1Context *task_context, \ std::vector &nfl_fid_offset, \ tensorflow::TTypes::Flat fid_offset_flat, \ tensorflow::TTypes::Flat feature_offset_flat) DFeatureParallelDoTask1(::monolith::io::proto::ExampleBatch); DFeatureParallelDoTask1(std::vector); DFeatureParallelDoTask1(InstanceWrapper); #undef DFeatureParallelDoTask1 }; template void ShardingSparseFidsOp::SplitTask(const std::vector &context_list, int limit, std::vector> *out) { std::vector full_index; int pre_count = 0; for (unsigned int i = 0; i < context_list.size(); ++i) { const auto &context = context_list[i]; if (context.size == 0) { continue; } else if (context.size >= limit) { full_index.push_back(i); } else { if (pre_count == 0) { out->emplace_back(std::vector()); } if (pre_count + context.size <= limit) { out->back().push_back(i); pre_count += context.size; } else { out->emplace_back(std::vector()); out->back().push_back(i); pre_count = context.size; } } } for (auto index : full_index) { out->emplace_back(std::vector({index})); } } template void ShardingSparseFidsOp::FeatureParallelMakeUpTask1Context( const ::monolith::io::proto::ExampleBatch &example_batch, int batch_size, std::vector *task_context_list, absl::flat_hash_map *feature_shard_count_map, absl::flat_hash_map> *table_feature_map, std::vector> *all_feature_counter, std::unordered_set *shared_feature, int *all_feature_counter_size) { std::vector example_batch_feature_index(feature_conf_.size(), -1); for (int n_i = 0; n_i < example_batch.named_feature_list_size(); ++n_i) { const auto &named_feature_list = example_batch.named_feature_list(n_i); auto &name 
= named_feature_list.name(); auto find_iter = feature_conf_.find(name); if (find_iter == feature_conf_.end()) { continue; } example_batch_feature_index[find_iter->second.feature_index] = n_i; } for (uint i = 0; i < example_batch_feature_index.size(); ++i) { int n_i = example_batch_feature_index[i]; auto &feature_counter = (*all_feature_counter)[i]; if (n_i < 0) { feature_counter.push_back(0); shared_feature->insert(i); *all_feature_counter_size += 1; continue; } TFeatureParallelTask1Context *task_context = nullptr; auto table_index = feature_cfg_list_[i]->table_index; auto feature_shard_count_map_iter = feature_shard_count_map->find(i); if (feature_shard_count_map_iter == feature_shard_count_map->end()) { task_context_list->emplace_back(TFeatureParallelTask1Context()); task_context = &(task_context_list->back()); task_context->example_batch_feature_index = n_i; task_context->feature_output_index = i; (*feature_shard_count_map)[i] = task_context; (*table_feature_map)[table_index].push_back(task_context); } else { task_context = feature_shard_count_map_iter->second; } const auto &named_feature_list = example_batch.named_feature_list(n_i); int fid_size = 0; if (named_feature_list.type() == FeatureListType::SHARED) { const auto &feature = named_feature_list.feature(0); int tmp_counter = 0; if (feature.has_fid_v1_list()) { tmp_counter = feature.fid_v1_list().value_size(); } else if (feature.has_fid_v2_list()) { tmp_counter = feature.fid_v2_list().value_size(); } fid_size += tmp_counter; feature_counter.push_back(tmp_counter); shared_feature->insert(i); *all_feature_counter_size += 1; } else { feature_counter.reserve(batch_size); *all_feature_counter_size += batch_size; for (const auto &feature : named_feature_list.feature()) { int tmp_counter = 0; if (feature.has_fid_v1_list()) { tmp_counter = feature.fid_v1_list().value_size(); } else if (feature.has_fid_v2_list()) { tmp_counter = feature.fid_v2_list().value_size(); } fid_size += tmp_counter; 
feature_counter.push_back(tmp_counter); } } if (fid_size > 0) { task_context->size += fid_size; } } } template void ShardingSparseFidsOp::FeatureParallelMakeUpTask1Context( const std::vector &examples, int batch_size, std::vector *task_context_list, absl::flat_hash_map *feature_shard_count_map, absl::flat_hash_map> *table_feature_map, std::vector> *all_feature_counter, std::unordered_set *shared_feature, int *all_feature_counter_size) { task_context_list->resize(feature_conf_.size()); for (auto &feature_counter : *all_feature_counter) { feature_counter.resize(batch_size, 0); } *all_feature_counter_size = batch_size * feature_conf_.size(); for (uint i = 0; i < task_context_list->size(); ++i) { auto *task_context = &((*task_context_list)[i]); task_context->named_feature_ptr_list.reserve(examples.size()); task_context->feature_sample_index.reserve(examples.size()); task_context->feature_output_index = i; (*feature_shard_count_map)[feature_cfg_list_[i]->feature_index] = task_context; (*table_feature_map)[feature_cfg_list_[i]->table_index].push_back( task_context); } for (uint ex_i = 0; ex_i < examples.size(); ++ex_i) { const ::monolith::io::proto::Example *example = examples[ex_i]; CHECK_NOTNULL(example); for (const auto &named_feature : example->named_feature()) { int fid_size = 0; const auto &feature = named_feature.feature(); fid_size = feature.fid_v2_list().value_size(); if (fid_size == 0) { fid_size = feature.fid_v1_list().value_size(); } if (fid_size <= 0) continue; int feature_index = -1; auto sorted_id = named_feature.sorted_id(); if (sorted_id > 0) { // 优先利用id查找,比string查找更快 if (feature_index_conf_.size()) { if (sorted_id >= feature_index_conf_.size()) { continue; } auto feature_ptr = feature_index_conf_.at(sorted_id); if (feature_ptr == nullptr) { continue; } feature_index = feature_ptr->feature_index; // CHECK_EQ(named_feature.name(), feature_ptr->feature_name); } else { LOG_EVERY_N_SEC(ERROR, 10) << "FeatureNameMapper error"; } } else { const auto &name = 
named_feature.name(); auto find_iter = feature_conf_.find(name); if (find_iter == feature_conf_.end()) { continue; } feature_index = find_iter->second.feature_index; } CHECK(feature_index >= 0 && feature_index < task_context_list->size()); auto &task_context = (*task_context_list)[feature_index]; auto &feature_counter = (*all_feature_counter)[feature_index]; feature_counter[ex_i] = fid_size; task_context.size += fid_size; task_context.named_feature_ptr_list.push_back(&named_feature); task_context.feature_sample_index.push_back(feature_index * batch_size + ex_i); } } } template void ShardingSparseFidsOp::FeatureParallelMakeUpTask1Context( const ShardingSparseFidsOp::InstanceWrapper &instance_wrapper, int batch_size, std::vector *task_context_list, absl::flat_hash_map *feature_shard_count_map, absl::flat_hash_map> *table_feature_map, std::vector> *all_feature_counter, std::unordered_set *shared_feature, int *all_feature_counter_size) { task_context_list->resize(feature_conf_.size()); for (auto &feature_counter : *all_feature_counter) { feature_counter.resize(batch_size, 0); } *all_feature_counter_size = batch_size * feature_conf_.size(); std::vector>> feature_named_feature_ptr_list(feature_conf_.size()); for (auto &elem : feature_named_feature_ptr_list) { elem.reserve(batch_size); } // fid v1 for (auto &iter : instance_wrapper.fid_v1) { auto find_iter = feature_conf_.find(iter.first); if (find_iter == feature_conf_.end()) { continue; } auto &named_feature_ptr_list = feature_named_feature_ptr_list[find_iter->second.feature_index]; for (uint ex_i = 0; ex_i < iter.second.size(); ++ex_i) { named_feature_ptr_list.push_back(std::make_pair( InstanceWrapper::FeaturePtr(&iter.second[ex_i], nullptr), ex_i)); } } // fid v2 for (uint ex_i = 0; ex_i < instance_wrapper.instances.size(); ++ex_i) { const auto *instance = instance_wrapper.instances[ex_i]; CHECK_NOTNULL(instance); for (const auto &named_feature : instance->feature()) { const auto &name = named_feature.name(); auto 
find_iter = feature_conf_.find(name); if (find_iter == feature_conf_.end()) { continue; } auto &named_feature_ptr_list = feature_named_feature_ptr_list[find_iter->second.feature_index]; named_feature_ptr_list.push_back(std::make_pair( InstanceWrapper::FeaturePtr({nullptr, &named_feature}), ex_i)); } } for (uint i = 0; i < feature_named_feature_ptr_list.size(); ++i) { auto *task_context = &((*task_context_list)[i]); task_context->instance_feature_ptr_list.reserve(batch_size); task_context->feature_sample_index.reserve(batch_size); task_context->feature_output_index = i; (*feature_shard_count_map)[feature_cfg_list_[i]->feature_index] = task_context; (*table_feature_map)[feature_cfg_list_[i]->table_index].push_back( task_context); auto &feature_counter = (*all_feature_counter)[i]; for (uint j = 0; j < feature_named_feature_ptr_list[i].size(); ++j) { auto &info = feature_named_feature_ptr_list[i][j]; auto named_feature = info.first; auto ex_i = info.second; int fid_size = 0; if (named_feature.fid_v1) { fid_size = named_feature.fid_v1->size(); } else { fid_size += named_feature.fid_v2->fid_size(); // this is a sequence feature list. 
for (const auto &fidlist : named_feature.fid_v2->fid_list()) { fid_size += fidlist.value_size(); } } if (fid_size > 0) { feature_counter[ex_i] = fid_size; task_context->size += fid_size; task_context->instance_feature_ptr_list.push_back(named_feature); task_context->feature_sample_index.push_back(i * batch_size + ex_i); } } } } template void ShardingSparseFidsOp::FeatureParallelDoTask1( const ::monolith::io::proto::ExampleBatch &example_batch, TFeatureParallelTask1Context *task_context, std::vector &nfl_fid_offset, tensorflow::TTypes::Flat fid_offset_flat, tensorflow::TTypes::Flat feature_offset_flat) { auto &shard_vec = task_context->fid_list; auto &uniq_hashtable = task_context->uniq_hashtable; const auto &named_feature_list = example_batch.named_feature_list( task_context->example_batch_feature_index); auto feature_output_index = task_context->feature_output_index; auto offset = nfl_fid_offset[feature_output_index]; if (named_feature_list.type() == FeatureListType::SHARED) { const auto &feature = named_feature_list.feature(0); if (feature.has_fid_v1_list()) { for (int i = 0; i < feature.fid_v1_list().value_size(); ++i) { auto value = convert_fid_v1_to_v2(feature.fid_v1_list().value(i)); FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } else if (feature.has_fid_v2_list()) { for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { auto value = feature.fid_v2_list().value(i); FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } } else { for (const auto &feature : named_feature_list.feature()) { if (feature.has_fid_v1_list()) { for (int i = 0; i < feature.fid_v1_list().value_size(); ++i) { auto value = convert_fid_v1_to_v2(feature.fid_v1_list().value(i)); FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } else if (feature.has_fid_v2_list()) { for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { auto value = 
feature.fid_v2_list().value(i); FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } } } task_context->feature_offset = offset - nfl_fid_offset[feature_output_index]; } template void ShardingSparseFidsOp::FeatureParallelDoTask1( const std::vector &examples, TFeatureParallelTask1Context *task_context, std::vector &nfl_fid_offset, tensorflow::TTypes::Flat fid_offset_flat, tensorflow::TTypes::Flat feature_offset_flat) { auto &shard_vec = task_context->fid_list; auto &uniq_hashtable = task_context->uniq_hashtable; auto feature_output_index = task_context->feature_output_index; int offset = 0; for (uint sub_task_index = 0; sub_task_index < task_context->named_feature_ptr_list.size(); ++sub_task_index) { const auto named_feature_ptr = task_context->named_feature_ptr_list[sub_task_index]; offset = feature_offset_flat(task_context->feature_sample_index[sub_task_index]); const auto &feature = named_feature_ptr->feature(); if (feature.has_fid_v1_list()) { for (int i = 0; i < feature.fid_v1_list().value_size(); ++i) { auto value = convert_fid_v1_to_v2(feature.fid_v1_list().value(i)); FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } else if (feature.has_fid_v2_list()) { for (int i = 0; i < feature.fid_v2_list().value_size(); ++i) { auto value = feature.fid_v2_list().value(i); FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } } task_context->feature_offset = offset - nfl_fid_offset[feature_output_index]; } template void ShardingSparseFidsOp::FeatureParallelDoTask1( const ShardingSparseFidsOp::InstanceWrapper &instance_wrapper, TFeatureParallelTask1Context *task_context, std::vector &nfl_fid_offset, tensorflow::TTypes::Flat fid_offset_flat, tensorflow::TTypes::Flat feature_offset_flat) { auto &shard_vec = task_context->fid_list; auto &uniq_hashtable = task_context->uniq_hashtable; auto feature_output_index = 
task_context->feature_output_index; int offset = 0; for (uint sub_task_index = 0; sub_task_index < task_context->instance_feature_ptr_list.size(); ++sub_task_index) { const auto &named_feature_ptr = task_context->instance_feature_ptr_list[sub_task_index]; offset = feature_offset_flat(task_context->feature_sample_index[sub_task_index]); if (named_feature_ptr.fid_v1) { for (auto value : *named_feature_ptr.fid_v1) { value = convert_fid_v1_to_v2(value); FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } else { for (const auto &value : named_feature_ptr.fid_v2->fid()) { FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } // this is a sequence feature list. for (const auto &fid_list : named_feature_ptr.fid_v2->fid_list()) { for (const auto &value : fid_list.value()) { FillFidList(value, shard_vec, uniq_hashtable, fid_offset_flat, feature_output_index, &offset); } } } } task_context->feature_offset = offset - nfl_fid_offset[feature_output_index]; } template Status ShardingSparseFidsOp::FeatureParallelParse( OpKernelContext *ctx, const TInput &input, OpOutputList *fid_list_out_list, OpOutputList *fid_list_row_splits_out_list, OpOutputList *fid_list_row_splits_size_out_list) { int batch_size = GetBatchSize(input); struct FeatureParallelTask1Context { std::vector> fid_list; MultiShardUniqHashTable uniq_hashtable; int example_batch_feature_index = -1; // use for example_batch std::vector named_feature_ptr_list; // use for example std::vector instance_feature_ptr_list; // for instance int feature_output_index = -1; std::vector feature_sample_index; // use for example int size = 0; std::vector table_offset; // (ps_num_, 0); int feature_offset = 0; }; std::vector task_context_list; absl::flat_hash_map feature_shard_count_map; absl::flat_hash_map> table_feature_map; Tensor *fid_offset_tensor, *feature_offset_tensor; std::vector tmp_tensor_list; std::vector> fid_list_row_splits_flat_list; 
std::vector nfl_fid_offset(feature_conf_.size()); { profiler::TraceMe activity([]() { return "ShardingSparseFidsOp::Alloc"; }); feature_shard_count_map.reserve(feature_conf_.size()); task_context_list.reserve(feature_conf_.size()); table_feature_map.reserve(table_cfg_list_.size()); std::vector> all_feature_counter(feature_conf_.size()); int all_feature_counter_size = 0; std::unordered_set shared_feature; FeatureParallelMakeUpTask1Context( input, batch_size, &task_context_list, &feature_shard_count_map, &table_feature_map, &all_feature_counter, &shared_feature, &all_feature_counter_size); for (auto &task_context : task_context_list) { task_context.fid_list.resize(ps_num_); task_context.uniq_hashtable.resize(ps_num_); task_context.table_offset.resize(ps_num_, 0); if (task_context.size > 0) { int reserve_size = task_context.size * 6 / 5 / ps_num_; if (unique_) { task_context.uniq_hashtable.reserve(reserve_size); } else { for (auto &fid_list_part : task_context.fid_list) { fid_list_part.reserve(reserve_size); } } } } Tensor *nfl_offset_tensor; TF_RETURN_IF_ERROR( ctx, CreateOffsetTensor(ctx, all_feature_counter, all_feature_counter_size, &nfl_offset_tensor, &feature_offset_tensor, &fid_offset_tensor, fid_list_row_splits_out_list, fid_list_row_splits_size_out_list, tmp_tensor_list, fid_list_row_splits_flat_list, &nfl_fid_offset, &shared_feature)); } auto fid_offset_flat = fid_offset_tensor->flat(); auto feature_offset_flat = feature_offset_tensor->flat(); std::vector> task_split; SplitTask( task_context_list, single_thread_feature_watermark_, &task_split); { profiler::TraceMe activity([]() { return "ShardingSparseFidsOp::AddVec"; }); activity.AppendMetadata([&task_split, &task_context_list] { return profiler::TraceMeEncode({{"task_num", task_context_list.size()}, {"split_num", task_split.size()}}); }); std::vector capacities(task_split.size(), 0); auto task_func = [this, &task_context_list, &input, &task_split, &nfl_fid_offset, &fid_offset_flat, &feature_offset_flat, 
&capacities](const int64 begin, const int64 end) { UniqHashTable uniq_hashtable; for (int64 task_index = begin; task_index < end; ++task_index) { auto &task_index_list = task_split[task_index]; for (auto index : task_index_list) { auto &task_context = task_context_list[index]; if (unique_) { uniq_hashtable.Reset(); task_context.uniq_hashtable.init(&uniq_hashtable); } FeatureParallelDoTask1(input, &task_context, nfl_fid_offset, fid_offset_flat, feature_offset_flat); } capacities[task_index] = uniq_hashtable.Capacity(); } }; ParallelRun(ctx, task_split.size(), task_func); double avg_capacity = 0; if (capacities.size() > 0) { avg_capacity = static_cast(std::accumulate(capacities.begin(), capacities.end(), 0)) / capacities.size(); } activity.AppendMetadata([&avg_capacity] { return profiler::TraceMeEncode({{"hashtable_size", avg_capacity}}); }); monolith::GetMetrics()->emit_timer("sharding_sparse_fids_op_hashtable_capacity", avg_capacity); } struct TaskContext2 { // fill fid_list TaskContext2(const TensorSliceAccessor &accessor_, std::vector *shard_fid_list_, std::vector *shard_ptr_, int size_, int offset_) : accessor(accessor_), shard_fid_list(shard_fid_list_), shard_ptr(shard_ptr_), size(size_), offset(offset_) {} // rewrite fid_offset explicit TaskContext2(FeatureParallelTask1Context *task_context_) : task1_context(task_context_), size(task_context_->feature_offset) {} TensorSliceAccessor accessor; std::vector *shard_fid_list = nullptr; std::vector *shard_ptr = nullptr; int size = -1; int offset = -1; FeatureParallelTask1Context *task1_context = nullptr; }; std::vector task2_context_list; Tensor *fid_list_table_row_length_tensor = nullptr; Tensor *fid_list_shard_row_lenth_tensor = nullptr; Tensor *fid_list_emb_row_lenth_tensor = nullptr; std::vector> fid_list_tensor_vec( table_cfg_list_.size() * ps_num_); if (version_ == 4) { int size_record_total = 0; std::vector size_record(table_cfg_list_.size() * ps_num_, 0); for (uint table_index = 0; table_index < 
table_cfg_list_.size(); ++table_index) { // auto &table_name = table_names_[table_index]; auto table_feature_map_find_iter = table_feature_map.find(table_index); std::vector *feature_vec_ptr = nullptr; if (table_feature_map_find_iter != table_feature_map.end()) { feature_vec_ptr = &(table_feature_map_find_iter->second); } for (int ps_num_i = 0; ps_num_i < ps_num_; ++ps_num_i) { int &size = size_record[table_index * ps_num_ + ps_num_i]; if (feature_vec_ptr != nullptr) { for (auto task_context_ptr : *feature_vec_ptr) { if (unique_) { size += task_context_ptr->uniq_hashtable.fid_num(ps_num_i); } else { size += task_context_ptr->fid_list[ps_num_i].size(); } } } size_record_total += size; } } Tensor *fid_list_tensor; TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("fid_list", TensorShape({ size_record_total, }), &fid_list_tensor)); fid_list_tensor->flat().setZero(); int pre_count = 0; for (uint ps_num_i = 0; ps_num_i < ps_num_; ++ps_num_i) { for (uint table_index = 0; table_index < table_cfg_list_.size(); ++table_index) { int size = size_record[table_index * ps_num_ + ps_num_i]; fid_list_tensor_vec[table_index * ps_num_ + ps_num_i] = TensorSliceAccessor( {static_cast(fid_list_tensor->data()) + pre_count, size}); // fid_list_tensor->Slice(pre_count, pre_count + size); pre_count += size; } } TF_RETURN_IF_ERROR( ctx, ctx->allocate_output("fid_list_table_row_length", tensorflow::TensorShape({ ps_num_ * table_cfg_list_.size(), }), &fid_list_table_row_length_tensor)); fid_list_table_row_length_tensor->flat().setZero(); TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("fid_list_shard_row_lenth", tensorflow::TensorShape({ ps_num_, }), &fid_list_shard_row_lenth_tensor)); fid_list_shard_row_lenth_tensor->flat().setZero(); TF_RETURN_IF_ERROR( ctx, ctx->allocate_output("fid_list_emb_row_lenth", tensorflow::TensorShape({ ps_num_ * table_cfg_list_.size(), }), &fid_list_emb_row_lenth_tensor)); fid_list_emb_row_lenth_tensor->flat().setZero(); } int index = -1; Tensor *emb_size_tensor; if 
(version_ == 5) { TF_RETURN_IF_ERROR(ctx, ctx->allocate_output("emb_size", TensorShape({ table_cfg_list_.size() * ps_num_, }), &emb_size_tensor)); } for (uint table_index = 0; table_index < table_cfg_list_.size(); ++table_index) { // auto &table_name = table_names_[table_index]; auto table_feature_map_find_iter = table_feature_map.find(table_index); std::vector *feature_vec_ptr = nullptr; if (table_feature_map_find_iter != table_feature_map.end()) { feature_vec_ptr = &(table_feature_map_find_iter->second); } for (int ps_num_i = 0; ps_num_i < ps_num_; ++ps_num_i) { auto &cur_tensor_flat = fid_list_row_splits_flat_list[table_index * ps_num_ + ps_num_i]; int size = 0; int pre_offset = 0; if (feature_vec_ptr != nullptr) { for (auto task_context_ptr : *feature_vec_ptr) { task_context_ptr->table_offset[ps_num_i] = pre_offset; int cur_fid_size = 0; if (unique_) { cur_fid_size = task_context_ptr->uniq_hashtable.fid_num(ps_num_i); } else { cur_fid_size = task_context_ptr->fid_list[ps_num_i].size(); } size += cur_fid_size; // std::cerr << "cur_fid_size: " << cur_fid_size << " size: " << size // << std::endl << std::flush; if (version_ == 3 || version_ == 4 || version_ == 5) { int emb_size = cur_fid_size * feature_cfg_list_[task_context_ptr->feature_output_index] ->dims_sum; if (version_ == 4) { auto fid_list_emb_row_lenth_flat = fid_list_emb_row_lenth_tensor->flat(); fid_list_emb_row_lenth_flat(ps_num_i * table_cfg_list_.size() + table_index) += emb_size; } pre_offset += emb_size; } else { pre_offset = size; } auto cur_tensor_flat_index = feature_cfg_list_[task_context_ptr->feature_output_index] ->feature_in_table_index + 1; cur_tensor_flat(cur_tensor_flat_index) = cur_fid_size; } } if (version_ == 5) { emb_size_tensor->flat()(table_index * ps_num_ + ps_num_i) = pre_offset; } TensorSliceAccessor cur_accessor; for (uint z = 2; z <= table_cfg_list_[table_index]->feature_count; ++z) { cur_tensor_flat(z) += cur_tensor_flat(z - 1); } if (version_ != 4) { Tensor *cur_tensor; 
TF_RETURN_IF_ERROR( ctx, fid_list_out_list->allocate(++index, tensorflow::TensorShape{size}, &cur_tensor)); if (size == 0) { std::memset(cur_tensor->data(), 0, cur_tensor->TotalBytes()); continue; } cur_accessor = TensorSliceAccessor( {static_cast(cur_tensor->data()), cur_tensor->NumElements()}); } else { cur_accessor = fid_list_tensor_vec[++index]; fid_list_table_row_length_tensor->flat()( ps_num_i * table_cfg_list_.size() + table_index) += size; fid_list_shard_row_lenth_tensor->flat()(ps_num_i) += size; } int offset = 0; for (auto task_context_ptr : *feature_vec_ptr) { std::vector *shard_fid_list = nullptr; std::vector *shard_ptr = nullptr; int tmp_size = 0; if (unique_) { shard_fid_list = &(task_context_ptr->uniq_hashtable.fid_list(ps_num_i)); tmp_size = shard_fid_list->size(); } else { shard_ptr = &(task_context_ptr->fid_list[ps_num_i]); tmp_size = shard_ptr->size(); } DCHECK_LE(offset + tmp_size, size); TaskContext2 tmp_task_context(cur_accessor, shard_fid_list, shard_ptr, tmp_size, offset); task2_context_list.emplace_back(tmp_task_context); offset += tmp_size; } } if (version_ != 2) { if (feature_vec_ptr) { for (auto task_context_ptr : *feature_vec_ptr) { TaskContext2 tmp_task_context(task_context_ptr); task2_context_list.emplace_back(tmp_task_context); } } } } std::vector> task2_split; SplitTask(task2_context_list, single_thread_assign_watermark_, &task2_split); { profiler::TraceMe activity([]() { return "ShardingSparseFidsOp::Copy"; }); auto tensor_assign_func = [this, ctx, &task2_context_list, &task2_split, &nfl_fid_offset, &fid_offset_flat]( const int64 begin, const int64 end) { for (int64 task_index = begin; task_index < end; ++task_index) { auto &task_index_list = task2_split[task_index]; for (auto index : task_index_list) { auto &task_context = task2_context_list[index]; if (task_context.task1_context) { auto &task1_context = *task_context.task1_context; auto offset = nfl_fid_offset[task1_context.feature_output_index]; /* auto table_index = 
feature_cfg_list_[task1_context.feature_output_index] ->table_index * ps_num_; for (int i = 0; i < task1_context.feature_offset; ++i, ++offset) { auto cur = fid_offset_flat(offset); cur += task1_context.table_offset.at(static_cast(cur >> 32) - table_index); fid_offset_flat(offset) = cur; }*/ auto feature_cfg = feature_cfg_list_[task1_context.feature_output_index]; for (int i = 0; i < task1_context.feature_offset; ++i, ++offset) { auto cur = fid_offset_flat(offset); cur += task1_context.table_offset.at( feature_cfg->GetPsShard(static_cast(cur >> 32))); fid_offset_flat(offset) = cur; } } else if (unique_) { CopyFidList(*task_context.shard_fid_list, task_context.offset, &task_context.accessor); } else { CopyFidList(*task_context.shard_ptr, task_context.offset, &task_context.accessor); } } } }; ParallelRun(ctx, task2_split.size(), tensor_assign_func); } return Status::OK(); } } // namespace monolith_tf } // namespace tensorflow #endif MONOLITH_NATIVE_TRAINING_DATA_KERNELS_PARSE_SPARSE_FEATURE_LIB_H_ ================================================ FILE: monolith/native_training/data/kernels/pb_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/strings/str_format.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h" #include "monolith/native_training/data/training_instance/cc/data_reader.h" #include "monolith/native_training/runtime/common/metrics.h" #include "third_party/nlohmann/json.hpp" namespace tensorflow { namespace data { namespace monolith_tf { namespace { using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using Instance = ::parser::proto::Instance; using ::tensorflow::monolith_tf::BaseStreamReader; using ::tensorflow::monolith_tf::DataFormatOptions; using ::tensorflow::monolith_tf::ExampleBatchIterator; using ::tensorflow::monolith_tf::ExampleToInstance; using ::tensorflow::monolith_tf::FeatureNameMapper; using ::tensorflow::monolith_tf::FeatureNameMapperTfBridge; using ::tensorflow::monolith_tf::FeaturePruningType; using ::tensorflow::monolith_tf::FileStreamReader; using ::tensorflow::monolith_tf::InputCompressType; using ::tensorflow::monolith_tf::InstanceToExample; using ::tensorflow::monolith_tf::PBIterator; using ::tensorflow::monolith_tf::PBIteratorWithDataFormatTrans; using ::tensorflow::monolith_tf::PBIteratorWithDataFormatTransBaseOutput; using ::tensorflow::monolith_tf::StdinStreamReader; struct DsOptions : DataFormatOptions { bool use_snappy = false; int32 
compression_type = InputCompressType::UNKNOW; int64 buffer_size = 64 * 1024 * 1024; }; } // namespace // This is the instance dataset op and used in the estimator as input fn. class PBDatasetOp : public DatasetOpKernel { public: static constexpr const char *const kDatasetType = "PbDataset"; static constexpr const char *const kFileName = "file_name"; static constexpr const char *const kBufferSize = "buffer_size"; static constexpr const char *const kUseSnappy = "use_snappy"; static constexpr const char *const kLagrangexHeader = "lagrangex_header"; static constexpr const char *const kHasSortId = "has_sort_id"; static constexpr const char *const kKafkaDump = "kafka_dump"; static constexpr const char *const kKafkaDumpPrefix = "kafka_dump_prefix"; static constexpr const char *const kInputPbType = "input_pb_type"; static constexpr const char *const kOutputPbType = "output_pb_type"; static constexpr const char *const kOutType = "out_type"; static constexpr const char *const kFeaturePruningType = "feature_pruning_type"; static constexpr const char *const kFeatureNameList = "feature_name_list"; static constexpr const char *const kFeatureIdList = "feature_id_list"; static constexpr const char *const kCompressionType = "compression_type"; explicit PBDatasetOp(OpKernelConstruction *ctx) : DatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutType, &out_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompressionType, &compression_type_)); auto creator = [this](FeatureNameMapperTfBridge **out_mapper) { TF_RETURN_IF_ERROR(FeatureNameMapperTfBridge::New(out_mapper)); return Status::OK(); }; ResourceMgr *resource_mgr = ctx->resource_manager(); OP_REQUIRES_OK(ctx, resource_mgr->LookupOrCreate( resource_mgr->default_container(), FeatureNameMapperTfBridge::kName, &mapper_, creator)); } ~PBDatasetOp() override { mapper_->Unref(); }; private: void MakeDataset(OpKernelContext *ctx, DatasetBase **output) override { tstring file_name; DsOptions options; OP_REQUIRES_OK(ctx, 
ParseScalarArgument(ctx, kFileName, &file_name)); OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, kUseSnappy, &options.use_snappy)); options.compression_type = compression_type_; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, kHasSortId, &options.has_sort_id)); OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, kKafkaDump, &options.kafka_dump)); OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kKafkaDumpPrefix, &options.kafka_dump_prefix)); OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferSize, &options.buffer_size)); OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kLagrangexHeader, &options.lagrangex_header)); tstring input_pb_type; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, kInputPbType, &input_pb_type)); tstring output_pb_type; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, kOutputPbType, &output_pb_type)); int feature_pruning_type = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kFeaturePruningType, &feature_pruning_type)); std::vector feature_name_list; OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, kFeatureNameList, &feature_name_list)); std::vector feature_id_list; OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, kFeatureIdList, &feature_id_list)); if (feature_name_list.size() != feature_id_list.size()) { LOG(FATAL) << absl::StrFormat( "feature_name_list/feature_id_list size should match, while got %ld " "vs %ld", feature_name_list.size(), feature_id_list.size()); } std::unordered_set feature_name_set(feature_name_list.begin(), feature_name_list.end()); std::unordered_set feature_id_set(feature_id_list.begin(), feature_id_list.end()); if (feature_name_list.size() != feature_name_set.size()) { LOG(FATAL) << "feature name list has duplicates, please investigate and retry !"; } if (feature_id_set.size() > feature_name_set.size()) { LOG(FATAL) << "feature_name -> feature_id should be non-injective and " "surjective, that is feature_id_set.size() should be <= " "feature_name_set.size(), please investigate and retry !"; } output_ = new Dataset(ctx, file_name, options, 
input_pb_type, output_pb_type, out_type_, feature_pruning_type, feature_name_list, feature_id_list, mapper_->GetFeatureNameMapper()); *output = output_; nlohmann::json j; j[kFileName] = file_name; j[kUseSnappy] = options.use_snappy; j[kCompressionType] = options.compression_type; j[kHasSortId] = options.has_sort_id; j[kKafkaDump] = options.kafka_dump; j[kKafkaDumpPrefix] = options.kafka_dump_prefix; j[kBufferSize] = options.buffer_size; j[kLagrangexHeader] = options.lagrangex_header; j[kInputPbType] = input_pb_type; j[kOutputPbType] = output_pb_type; j[kFeaturePruningType] = feature_pruning_type; LOG(INFO) << j.dump(); } class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext *ctx, tstring file_name, const DsOptions &options, std::string input_pb_type, std::string output_pb_type, DataType out_type, int feature_pruning_type, std::vector feature_name_list, std::vector feature_id_list, FeatureNameMapper *mapper) : DatasetBase(DatasetContext(ctx)), file_name_(std::move(file_name)), options_(options), input_pb_type_(std::move(input_pb_type)), output_pb_type_(std::move(output_pb_type)), out_type_(out_type), feature_pruning_type_(feature_pruning_type), feature_name_list_(std::move(feature_name_list)), feature_id_list_(std::move(feature_id_list)), mapper_(mapper) { absl::flat_hash_map name_to_id; absl::flat_hash_map> id_to_name; for (size_t i = 0; i < feature_name_list_.size(); ++i) { name_to_id.insert({feature_name_list_[i], feature_id_list_[i]}); id_to_name[feature_id_list_[i]].push_back(feature_name_list_[i]); } CHECK(mapper_->SetMapping(name_to_id, id_to_name)); if (input_pb_type == "examplebatch" && output_pb_type == "example") { // mapper_->TurnOn(); } // LOG_FIRST_N(INFO, 1) << "NameToId: " << mapper_->DebugString(); } std::unique_ptr MakeIteratorInternal( const string &prefix) const override { return absl::make_unique( Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)}, mapper_, input_pb_type_, output_pb_type_); } const 
DataTypeVector &output_dtypes() const override { static auto *dtypes = new DataTypeVector({out_type_}); return *dtypes; } const std::vector &output_shapes() const override { static auto *shapes = new std::vector{TensorShape({})}; return *shapes; } string DebugString() const override { return ("This is the customized Instance Dataset: " + file_name_); } Status CheckExternalState() const override { return Status::OK(); } private: Status AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b, Node **output) const override { Node *filename = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(file_name_, &filename)); Node *use_snappy = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(options_.use_snappy, &use_snappy)); Node *has_sort_id = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(options_.has_sort_id, &has_sort_id)); Node *kafka_dump = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(options_.kafka_dump, &kafka_dump)); Node *kafka_dump_prefix = nullptr; TF_RETURN_IF_ERROR( b->AddScalar(options_.kafka_dump_prefix, &kafka_dump_prefix)); Node *buffer_size = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size)); Node *lagrangex_header = nullptr; TF_RETURN_IF_ERROR( b->AddScalar(options_.lagrangex_header, &lagrangex_header)); Node *input_pb_type = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(input_pb_type_, &input_pb_type)); Node *output_pb_type = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(output_pb_type_, &output_pb_type)); Node *feature_pruning_type = nullptr; TF_RETURN_IF_ERROR( b->AddScalar(feature_pruning_type_, &feature_pruning_type)); Node *feature_name_list = nullptr; TF_RETURN_IF_ERROR(b->AddVector(feature_name_list_, &feature_name_list)); Node *feature_id_list = nullptr; TF_RETURN_IF_ERROR(b->AddVector(feature_id_list_, &feature_id_list)); AttrValue out_type; b->BuildAttrValue(out_type_, &out_type); AttrValue compression_type; b->BuildAttrValue(options_.compression_type, &compression_type); TF_RETURN_IF_ERROR(b->AddDataset( this, {filename, use_snappy, 
has_sort_id, kafka_dump, kafka_dump_prefix, buffer_size, lagrangex_header, input_pb_type, output_pb_type, feature_pruning_type, feature_name_list, feature_id_list}, {{kOutType, out_type}, {kCompressionType, compression_type}}, output)); return Status::OK(); } class Iterator : public DatasetIterator { public: explicit Iterator(const Params ¶ms, FeatureNameMapper *mapper, const tstring &input_pb_type, const tstring &output_pb_type) : DatasetIterator(params), mapper_(mapper) { mutex_lock l(mu_); offset_ = 0; input_pb_type_ = ::tensorflow::monolith_tf::data_format::StringToDataFormat( input_pb_type); output_pb_type_ = ::tensorflow::monolith_tf::data_format::StringToDataFormat( output_pb_type); if (input_pb_type_ == ::tensorflow::monolith_tf::data_format::UNKNOW || output_pb_type_ == ::tensorflow::monolith_tf::data_format::UNKNOW) { LOG(FATAL) << "dataformat error:" << input_pb_type << " or " << output_pb_type; } } class CurPBIteratorHandler { public: struct CurOutput : public PBIteratorWithDataFormatTransBaseOutput { std::vector *out_tensors; size_t size = 0; }; template Status HandleReaderNextStauts(const Status &s, const TResult &result) { return Status::OK(); } template Status HandleResult(TResult &&result, CurOutput *output) { output->size = result.ByteSize(); output->out_tensors->back().scalar()() = std::move(result); return Status::OK(); } Status HandleResult(tstring &&serialized, CurOutput *output) { output->out_tensors->back().scalar()() = std::move(serialized); return Status::OK(); } }; Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) override { out_tensors->reserve(1); mutex_lock l(mu_); if (!reader_) { TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); } out_tensors->emplace_back(ctx->allocator({}), dataset()->out_type_, TensorShape({})); PBIteratorWithDataFormatTrans cur_iter( input_pb_type_, output_pb_type_); CurPBIteratorHandler::CurOutput output; output.out_tensors = out_tensors; 
cur_iter.GetNext(reader_.get(), &output, &offset_); Status s = output.reader_status; if (s.ok()) { static monitoring::CounterCell *bytes_counter = metrics::GetTFDataBytesReadCounter(kDatasetType); bytes_counter->IncrementBy(output.size); *end_of_sequence = false; num_random_samples_++; offset_ = reader_->GetOffset(); if (num_random_samples_ % metric_emit_step_ == 0) { LOG_EVERY_N_SEC(INFO, 300) << absl::StrFormat( "metrics_emit(counter) [instance_num] emit=%llu, " "total_instance_num=%lld", metric_emit_step_, num_random_samples_); monolith::GetMetrics()->emit_counter("instance_num", metric_emit_step_); } return Status::OK(); } out_tensors->pop_back(); ResetStreamsLocked(); if (errors::IsOutOfRange(s)) { *end_of_sequence = true; int64 unsubmit_instance_num = num_random_samples_ % metric_emit_step_; if (unsubmit_instance_num > 0) { LOG(INFO) << absl::StrFormat( "metrics_emit(counter) [instance_num] emit=%lld, " "total_instance_num=%lld, end_of_sequence", unsubmit_instance_num, num_random_samples_); monolith::GetMetrics()->emit_counter("instance_num", unsubmit_instance_num); } return Status::OK(); } return s; } private: std::shared_ptr CreateNode( IteratorContext *ctx, model::Node::Args args) const override { return model::MakeSourceNode(std::move(args)); } Status SaveInternal(SerializationContext *ctx, IteratorStateWriter *writer) override { mutex_lock l(mu_); LOG(INFO) << "Save function is not supported yet."; TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"), num_random_samples_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("offset"), offset_)); return Status::OK(); } Status RestoreInternal(IteratorContext *ctx, IteratorStateReader *reader) override { mutex_lock l(mu_); LOG(INFO) << "Restore function is not supported yet."; TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"), &num_random_samples_)); int64 offset; TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset)); if (dataset()->file_name_.empty()) { 
offset_ = 0; } else { offset_ = offset; } return Status::OK(); } // Sets up reader streams to read from filename Status SetupStreamsLocked(Env *env) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::unique_ptr stream_reader; if (dataset()->file_name_.empty()) { stream_reader = std::make_unique( dataset()->options_, dataset()->options_.buffer_size); } else { std::unique_ptr f; TF_RETURN_IF_ERROR( env->NewRandomAccessFile(dataset()->file_name_, &f)); auto compression_type = FileStreamReader::GetCompressType( dataset()->options_.use_snappy, dataset()->options_.compression_type); stream_reader = std::make_unique( dataset()->options_, std::move(f), compression_type, dataset()->options_.buffer_size); } if (dataset()->input_pb_type_ == "instance" || dataset()->input_pb_type_ == "example") { reader_ = absl::make_unique( std::move(stream_reader), static_cast( dataset()->feature_pruning_type_)); } else { reader_ = absl::make_unique( std::move(stream_reader), static_cast(dataset()->feature_pruning_type_), mapper_); } return Status::OK(); } // Resets all reader streams. 
void ResetStreamsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Dropping the unique_ptr destroys the current reader; the caller must hold mu_.
  reader_.reset();
}

// NOTE(review): template argument lists on this line (std::unique_ptr, std::vector)
// appear to have been lost during extraction; the tokens are kept as found.

// Input / output protobuf formats.
::tensorflow::monolith_tf::data_format::DataFormat input_pb_type_;
::tensorflow::monolith_tf::data_format::DataFormat output_pb_type_;

// Guards the mutable reader state declared below.
mutex mu_;
// Record reader; created by SetupStreamsLocked and released by ResetStreamsLocked.
std::unique_ptr reader_ TF_GUARDED_BY(mu_);
// Number of records returned successfully so far (incremented once per record).
int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0;
// Reader offset captured after the most recent successful read.
uint64 offset_ TF_GUARDED_BY(mu_) = 0;
// An "instance_num" counter is emitted every metric_emit_step_ records.
uint64 metric_emit_step_ TF_GUARDED_BY(mu_) = 10000;
// Not owned here; passed to the reader in SetupStreamsLocked for non instance/example input.
FeatureNameMapper *mapper_ = nullptr;
};  // presumably closes the iterator class; its header is outside this chunk

// Dataset-level configuration (read by the iterator through dataset()).
tstring file_name_, input_pb_type_, output_pb_type_;
DsOptions options_;
DataType out_type_;
std::vector feature_name_list_;
std::vector feature_id_list_;
int feature_pruning_type_ = FeaturePruningType::PRUNING_RAW_FEATURE;
FeatureNameMapper *mapper_ = nullptr;
};  // presumably closes the dataset class; its header is outside this chunk

// Op-kernel level state.
Dataset *output_ = nullptr;
DataType out_type_;
int32 compression_type_;
FeatureNameMapperTfBridge *mapper_ = nullptr;
};  // presumably closes PBDatasetOp; its header is outside this chunk

namespace {
// Registers the CPU kernel for the "PBDataset" op.
REGISTER_KERNEL_BUILDER(Name("PBDataset").Device(DEVICE_CPU), PBDatasetOp);
}  // namespace
}  // namespace monolith_tf
}  // namespace data
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/kernels/ragged_feature_kernel.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "idl/matrix/proto/example.pb.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" namespace tensorflow { namespace monolith_tf { using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using NamedFeature = ::monolith::io::proto::NamedFeature; class SwitchSlotOp : public OpKernel { public: using OpKernel::OpKernel; using ConstFlatSplits = typename TTypes::ConstFlat; explicit SwitchSlotOp(OpKernelConstruction *ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("slot", &slot_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("fid_version", &fid_version_)); } void Compute(OpKernelContext *context) override { // Read the `rt_nested_splits` input & convert to Eigen tensors. 
OpInputList rt_nested_splits_in; OP_REQUIRES_OK( context, context->input_list("rt_nested_splits", &rt_nested_splits_in)); const int rt_nested_splits_len = rt_nested_splits_in.size(); OpOutputList rt_nested_splits_out; OP_REQUIRES_OK(context, context->output_list("nested_splits_out", &rt_nested_splits_out)); for (int i = 0; i < rt_nested_splits_len; ++i) { Tensor *out_splits; OP_REQUIRES_OK(context, rt_nested_splits_out.allocate( i, rt_nested_splits_in[i].shape(), &out_splits)); std::memcpy(out_splits->data(), rt_nested_splits_in[i].data(), sizeof(int64) * rt_nested_splits_in[i].NumElements()); } const Tensor &rt_dense_values_in = context->input(rt_nested_splits_len); Tensor *dense_values_out; OP_REQUIRES_OK(context, context->allocate_output("dense_values_out", rt_dense_values_in.shape(), &dense_values_out)); auto dense_values_int_ = rt_dense_values_in.flat(); auto dense_values_out_ = dense_values_out->flat(); for (int i = 0; i < dense_values_int_.size(); ++i) { if (fid_version_ == 1) { dense_values_out_(i) = convert_v1(dense_values_int_(i)); } else { dense_values_out_(i) = convert_v2(dense_values_int_(i)); } } } private: inline int64 convert_v1(int64 fid) { static int64 mask = (static_cast(1) << 55) - 1; return (static_cast(slot_) << 54) | (fid & mask); } inline int64 convert_v2(int64 fid) { static int64 mask = (static_cast(1) << 49) - 1; return (static_cast(slot_) << 48) | (fid & mask); } int slot_, fid_version_; }; enum class VariantType { PBExampleBatch, PBExample }; class SwitchSlotBatchOp : public OpKernel { public: using OpKernel::OpKernel; explicit SwitchSlotBatchOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::vector features; OP_REQUIRES_OK(ctx, ctx->GetAttr("features", &features)); std::vector slots; OP_REQUIRES_OK(ctx, ctx->GetAttr("slots", &slots)); std::vector inplaces; OP_REQUIRES_OK(ctx, ctx->GetAttr("inplaces", &inplaces)); OP_REQUIRES(ctx, features.size() == slots.size(), errors::FailedPrecondition("the length of features and slots are not 
equal") ); OP_REQUIRES(ctx, features.size() == inplaces.size(), errors::FailedPrecondition("the length of features and inplaces are not equal") ); for (int i = 0; i < features.size(); ++i) { shared_meta_.emplace(std::piecewise_construct, std::forward_as_tuple(features[i]), std::forward_as_tuple(inplaces[i], slots[i])); } OP_REQUIRES_OK(ctx, ctx->GetAttr("suffix", &suffix_)); std::string variant_type; OP_REQUIRES_OK(ctx, ctx->GetAttr("variant_type", &variant_type)); if (variant_type == "example") { variant_type_ = VariantType::PBExample; } else if (variant_type == "example_batch") { variant_type_ = VariantType::PBExampleBatch; } else { OP_REQUIRES_OK(ctx, errors::FailedPrecondition( "variant_type error, variant_type must be example or example_batch!")); } } void Compute(OpKernelContext *ctx) override { const Tensor *pb_input; OP_REQUIRES_OK(ctx, ctx->input("pb_input", &pb_input)); Tensor *pb_output; OP_REQUIRES_OK(ctx, ctx->allocate_output("pb_output", pb_input->shape(), &pb_output)); google::protobuf::Arena arena; if (variant_type_ == VariantType::PBExampleBatch) { switch_example_batch(pb_input, pb_output, &arena); } else { switch_example(pb_input, pb_output, &arena); } } private: void switch_example(const Tensor *pb_input, Tensor *pb_output, google::protobuf::Arena *arena) { auto variant_flat = pb_input->flat(); auto out_variant_flat = pb_output->flat(); for (int i=0; i < variant_flat.size(); ++i) { const Example *example = variant_flat(i).get(); Example *new_example = switch_slot(*example, arena); out_variant_flat(i) = *new_example; } } void switch_example_batch(const Tensor *pb_input, Tensor *pb_output, google::protobuf::Arena *arena) { const Variant &variant = pb_input->scalar()(); auto out_variant_scalar = pb_output->scalar(); const ExampleBatch *eb = variant.get(); ExampleBatch *new_eb = switch_slot(*eb, arena); out_variant_scalar() = *new_eb; } Example *switch_slot(const Example &example, google::protobuf::Arena *arena) { auto *base = 
google::protobuf::Arena::CreateMessage(arena); base->CopyFrom(example); for (int i = 0; i < example.named_feature_size(); ++i) { const auto &name = base->named_feature(i).name(); auto it = shared_meta_.find(name); if (it != shared_meta_.end()) { bool inplace = it->second.first; uint64_t shared_slot = it->second.second; NamedFeature *named_feature = base->mutable_named_feature(i); if (inplace) { auto *feature = named_feature->mutable_feature(); if (feature->has_fid_v1_list()) { const auto &size = feature->mutable_fid_v1_list()->value_size(); auto *data = feature->mutable_fid_v1_list()->mutable_value()->mutable_data(); for (int j = 0; j < size; ++j) { data[j] = switch_slot_v1(data[j], shared_slot); } } if (feature->has_fid_v2_list()) { const auto &size = feature->mutable_fid_v2_list()->value_size(); auto *data = feature->mutable_fid_v2_list()->mutable_value()->mutable_data(); for (int j = 0; j < size; ++j) { data[j] = switch_slot_v2(data[j], shared_slot); } } } else { auto &feature = named_feature->feature(); auto *additive_nf = base->add_named_feature(); additive_nf->set_id(named_feature->id()); additive_nf->set_name(name + "_" + suffix_); additive_nf->set_sorted_id(named_feature->sorted_id()); if (feature.has_fid_v1_list()) { auto *fid_v1_list = additive_nf->mutable_feature()->mutable_fid_v1_list(); for (const auto &value : feature.fid_v1_list().value()) { fid_v1_list->add_value(switch_slot_v1(value, shared_slot)); } } if (feature.has_fid_v2_list()) { auto *fid_v2_list = additive_nf->mutable_feature()->mutable_fid_v2_list(); for (const auto &value : feature.fid_v2_list().value()) { fid_v2_list->add_value(switch_slot_v2(value, shared_slot)); } } } } } return base; } ExampleBatch *switch_slot(const ExampleBatch &example_batch, google::protobuf::Arena *arena) { auto *base = google::protobuf::Arena::CreateMessage(arena); base->CopyFrom(example_batch); for (int i = 0; i < example_batch.named_feature_list_size(); ++i) { auto *named_feature_list = 
base->mutable_named_feature_list(i); const auto &name = named_feature_list->name(); auto it = shared_meta_.find(name); if (it != shared_meta_.end()) { bool inplace = it->second.first; uint64_t shared_slot = it->second.second; if (inplace) { for (int j = 0; j < named_feature_list->feature_size(); ++j) { auto *feature = named_feature_list->mutable_feature(j); if (feature->has_fid_v1_list()) { auto *values = named_feature_list->mutable_feature(j)->mutable_fid_v1_list()->mutable_value(); auto *data = values->mutable_data(); for (int k = 0; k < values->size(); ++k) { data[k] = switch_slot_v1(data[k], shared_slot); } } if (feature->has_fid_v2_list()) { auto *values = feature->mutable_fid_v2_list()->mutable_value(); auto *data = values->mutable_data(); for (int k = 0; k < values->size(); ++k) { data[k] = switch_slot_v2(data[k], shared_slot); } } } } else { auto *additive_nfl = base->add_named_feature_list(); additive_nfl->set_id(named_feature_list->id()); additive_nfl->set_name(name + "_" + suffix_); additive_nfl->set_type(named_feature_list->type()); for (int j = 0; j < named_feature_list->feature_size(); ++j) { auto *feature = additive_nfl->add_feature(); if (named_feature_list->feature(j).has_fid_v1_list()) { auto *fid_v1_list = feature->mutable_fid_v1_list(); for (const auto &value : named_feature_list->feature(j).fid_v1_list().value()) { fid_v1_list->add_value(switch_slot_v1(value, shared_slot)); } } if (named_feature_list->feature(j).has_fid_v2_list()) { auto *fid_v2_list = feature->mutable_fid_v2_list(); for (const auto &value : named_feature_list->feature(j).fid_v2_list().value()) { fid_v2_list->add_value(switch_slot_v2(value, shared_slot)); } } } } } } return base; } VariantType variant_type_; std::string suffix_; std::unordered_map> shared_meta_; }; class FeatureCombineOp : public OpKernel { public: using OpKernel::OpKernel; using ConstFlatSplits = typename TTypes::ConstFlat; explicit FeatureCombineOp(OpKernelConstruction *ctx) : OpKernel(ctx) { 
OP_REQUIRES_OK(ctx, ctx->GetAttr("slot", &slot_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("fid_version", &fid_version_)); } void Compute(OpKernelContext *context) override { // Read the `rt_nested_splits` input & convert to Eigen tensors. OpInputList rt_nested_splits_src1_in; OP_REQUIRES_OK(context, context->input_list("rt_nested_splits_src1", &rt_nested_splits_src1_in)); int input_cnt = rt_nested_splits_src1_in.size(); const Tensor &rt_dense_values_src1_in = context->input(input_cnt); input_cnt++; OpInputList rt_nested_splits_src2_in; OP_REQUIRES_OK(context, context->input_list("rt_nested_splits_src2", &rt_nested_splits_src2_in)); input_cnt += rt_nested_splits_src2_in.size(); const Tensor &rt_dense_values_src2_in = context->input(input_cnt); DCHECK_EQ(rt_nested_splits_src1_in.size(), rt_nested_splits_src2_in.size()); OpOutputList nested_splits_sink; OP_REQUIRES_OK(context, context->output_list("nested_splits_sink", &nested_splits_sink)); int src_idx = 0; if (rt_nested_splits_src1_in.size() == 2) { auto batch_splits_src1 = rt_nested_splits_src1_in[src_idx].flat(); auto batch_splits_src2 = rt_nested_splits_src2_in[src_idx].flat(); DCHECK_EQ(batch_splits_src1.size(), batch_splits_src2.size()); for (int i = 0; i < batch_splits_src1.size(); ++i) { DCHECK_EQ(batch_splits_src1(i), batch_splits_src2(i)); } src_idx++; } auto ins_splits_src1 = rt_nested_splits_src1_in[src_idx].flat(); auto rt_dense_values_src1 = rt_dense_values_src1_in.flat(); auto ins_splits_src2 = rt_nested_splits_src2_in[src_idx].flat(); auto rt_dense_values_src2 = rt_dense_values_src2_in.flat(); int batch_size = ins_splits_src1.size() - 1; Tensor *ins_splits_sink; OP_REQUIRES_OK(context, nested_splits_sink.allocate( src_idx, rt_nested_splits_src2_in[src_idx].shape(), &ins_splits_sink)); auto ins_splits = ins_splits_sink->flat(); ins_splits(0) = 0; for (int i = 0; i < batch_size; ++i) { int src1_start = ins_splits_src1(i); int src1_end = ins_splits_src1(i + 1); int src2_start = ins_splits_src2(i); int src2_end 
= ins_splits_src2(i + 1); ins_splits(i + 1) = ins_splits(i) + (src1_end - src1_start) * (src2_end - src2_start); } Tensor *dense_values_sink; OP_REQUIRES_OK(context, context->allocate_output("dense_values_sink", {ins_splits(batch_size)}, &dense_values_sink)); auto dense_values = dense_values_sink->flat(); int idx = 0; for (int i = 0; i < batch_size; ++i) { int src1_start = ins_splits_src1(i); int src1_end = ins_splits_src1(i + 1); int src2_start = ins_splits_src2(i); int src2_end = ins_splits_src2(i + 1); for (int j = src1_start; j < src1_end; ++j) { int64 fid1 = rt_dense_values_src1(j); for (int k = src2_start; k < src2_end; ++k) { int64 fid2 = rt_dense_values_src2(k); if (fid_version_ == 1) { dense_values(idx++) = convert_v1(combine(fid1, fid2)); } else { dense_values(idx++) = convert_v2(combine(fid1, fid2)); } } } } } private: inline int64 convert_v1(int64 fid) { static int64 mask = (static_cast(1) << 55) - 1; return (static_cast(slot_) << 54) | (fid & mask); } inline int64 convert_v2(int64 fid) { static int64 mask = (static_cast(1) << 49) - 1; return (static_cast(slot_) << 48) | (fid & mask); } int64 combine(int64 fid1, int64 fid2) { auto mu = absl::int128(fid1) * absl::int128(fid2); uint64 hi = static_cast((mu >> 64).operator long()); // NOLINT uint64 lo = static_cast(mu.operator long()); // NOLINT return static_cast(hi ^ lo); } int slot_, fid_version_; }; namespace { REGISTER_KERNEL_BUILDER(Name("FeatureCombine").Device(DEVICE_CPU), FeatureCombineOp); REGISTER_KERNEL_BUILDER(Name("SwitchSlot").Device(DEVICE_CPU), SwitchSlotOp); REGISTER_KERNEL_BUILDER(Name("SwitchSlotBatch").Device(DEVICE_CPU), SwitchSlotBatchOp); } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/scatter_label_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; class ScatterLabelOp : public OpKernel { public: explicit ScatterLabelOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("config", &config_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); if (variant_type_ != "instance" && variant_type_ != "example") { LOG(FATAL) << "Invalid 'variant_type', please choose on from " "['instance', 'example']!"; } std::vector splits = absl::StrSplit(config_, ","); CHECK_GT(splits.size(), 0); int max_label_index_ = 0; for (absl::string_view split : splits) { std::vector chnid_and_index = absl::StrSplit(split, ":"); CHECK_EQ(chnid_and_index.size(), 2); int64_t chnid = 0; int index = 0; 
CHECK(absl::SimpleAtoi(chnid_and_index[0], &chnid)); CHECK(absl::SimpleAtoi(chnid_and_index[1], &index)); chnid_to_label_index_[chnid] = index; if (max_label_index_ < index) { max_label_index_ = index; } } multi_task_num_ = max_label_index_ + 1; nlohmann::json j; j["chnid_to_label_index"] = chnid_to_label_index_; j["multi_task_num"] = multi_task_num_; LOG(INFO) << j.dump(); } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); bool is_instance = variant_type_ == "instance"; if (is_instance) { Instance instance; instance.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(instance); } else { Example example; example.CopyFrom(*input_tensor.scalar()().get()); output_tensor->scalar()() = std::move(example); } LineId *line_id = GetLineId(output_tensor, is_instance); auto label = GetLabel(output_tensor, is_instance); float label_value = internal::INVALID_LABEL; if (!label->empty()) { label_value = label->Get(0); } else { LOG_EVERY_N_SEC(ERROR, 60) << "Invalid data: label is empty, please investigate and retry!"; } label->Clear(); label->Resize(multi_task_num_, internal::INVALID_LABEL); int64_t chnid = line_id->chnid(); if (chnid_to_label_index_.count(chnid)) { int idx = chnid_to_label_index_[chnid]; label->Set(idx, label_value); } } private: static LineId *GetLineId(Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_line_id(); } else { return output_tensor->scalar()() .get() ->mutable_line_id(); } } static ::google::protobuf::RepeatedField *GetLabel( Tensor *output_tensor, bool is_instance) { if (is_instance) { return output_tensor->scalar()() .get() ->mutable_label(); } else { return output_tensor->scalar()().get()->mutable_label(); } } std::string config_; std::string variant_type_; int multi_task_num_; std::map 
chnid_to_label_index_; }; namespace { REGISTER_KERNEL_BUILDER(Name("ScatterLabel").Device(DEVICE_CPU), ScatterLabelOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/split_flow_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "monolith/native_training/data/kernels/df_resource_kernel.h" #include "monolith/native_training/data/kernels/internal/datasource_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/resource_mgr.h" namespace tensorflow { namespace data { namespace monolith_tf { using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using Item = ::tensorflow::monolith_tf::Item; using QueueResource = ::tensorflow::monolith_tf::QueueResource; using VariantType = ::tensorflow::monolith_tf::VariantType; static mutex input_mu_; class SplitFlowDatasetOp : public UnaryDatasetOpKernel { public: static constexpr const char *const kDatasetType = "dataflow_dataset"; static constexpr const char *const kDataFlow = "data_flow"; static constexpr const char *const kIndex = "index"; static constexpr const char 
*const kMaxQueueSize = "max_queue_size"; static constexpr const char *const kVariantType = "variant_type"; explicit SplitFlowDatasetOp(OpKernelConstruction *ctx); protected: void MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output) override; private: class Dataset; std::vector data_flows_; int index_; int max_queue_size_; VariantType variant_type_; }; class SplitFlowDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext *ctx, const DatasetBase *input, const std::vector &data_flows, int index, int max_queue_size, const VariantType &variant_type) : DatasetBase(DatasetContext(ctx)), input_(input), data_flows_(data_flows), index_(index), max_queue_size_(max_queue_size), variant_type_(variant_type) { input_->Ref(); } ~Dataset() override { input_->Unref(); } std::unique_ptr MakeIteratorInternal( const string &prefix) const override { return absl::make_unique( Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)}); } const DataTypeVector &output_dtypes() const override { return input_->output_dtypes(); } const std::vector &output_shapes() const override { return input_->output_shapes(); } string DebugString() const override { return "This is the customized Dataset: SplitFlowDataset"; } Status InputDatasets( std::vector *inputs) const override { inputs->push_back(input_); return Status::OK(); } Status CheckExternalState() const override { return input_->CheckExternalState(); } void SetContainer(const std::string container) { container_ = container; } std::string GetContainer() const { return container_; } protected: Status AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b, Node **output) const override { Node *input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); AttrValue data_flows_node; b->BuildAttrValue(data_flows_, &data_flows_node); AttrValue index_node; b->BuildAttrValue(index_, &index_node); AttrValue max_queue_size_node; 
b->BuildAttrValue(max_queue_size_, &max_queue_size_node); AttrValue variant_type_node; if (variant_type_ == VariantType::PBInstance) { b->BuildAttrValue("instance", &variant_type_node); } else { b->BuildAttrValue("example", &variant_type_node); } TF_RETURN_IF_ERROR( b->AddDataset(this, // dataset {input_graph_node}, // inputs {{kDataFlow, data_flows_node}, {kIndex, index_node}, {kMaxQueueSize, max_queue_size_node}, {kVariantType, variant_type_node}}, // attrs output)); // Node** return Status::OK(); } private: class Iterator : public DatasetIterator { public: explicit Iterator(const Params ¶ms) : DatasetIterator(params), mu_(std::make_shared()), output_mu_(std::make_shared()) {} ~Iterator() override { CancelThreads(); if (deregister_fn_) deregister_fn_(); } void CancelThreads() TF_LOCKS_EXCLUDED(mu_) { cancellation_manager_->StartCancel(); mutex_lock l(*mu_); cancelled_ = true; } Status Initialize(IteratorContext *ctx) override { mutex_lock l(*mu_); name_ = dataset()->data_flows_[dataset()->index_]; cancellation_manager_ = absl::make_unique(); TF_RETURN_IF_ERROR( ::tensorflow::monolith_tf::RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { CancelThreads(); }, &deregister_fn_)); IteratorContext::Params params(ctx); params.cancellation_manager = cancellation_manager_.get(); Status s = dataset()->input_->MakeIterator(IteratorContext(params), this, prefix(), &input_impl_); std::function creator = [this](QueueResource **queue) -> Status { *queue = new QueueResource(dataset()->max_queue_size_); return Status::OK(); }; { mutex_lock input_l(input_mu_); for (size_t i = 0; i < dataset()->data_flows_.size(); ++i) { // 1) get data_flow_name and hash it into uint32 std::string data_flows_name = dataset()->data_flows_[i]; uint32 df_code = static_cast(::tensorflow::monolith_tf::internal::java_hash_code(data_flows_name)); df_code = df_code << 8; // 2) get resource QueueResource *resource = nullptr; s.Update(ctx->resource_mgr()->LookupOrCreate( 
dataset()->GetContainer(), data_flows_name, &resource, creator)); df_to_queue_.emplace(df_code, resource); if (i == dataset()->index_) { data_flow_ = df_code; queue_ = resource; } } } return s; } Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) override { // std::thread::id this_id = std::this_thread::get_id(); { mutex_lock l(*mu_); if (dataset()->index_ == 0) { TF_RETURN_IF_ERROR(EnsureThreadStarted(ctx)); } } { mutex_lock output_l(*output_mu_); out_tensors->reserve(1); Item item; bool poped = false; while (!poped) { // the queue is empty and the fetch threas is cancelled or finished if (cancelled_ || prefetch_thread_finished_) { out_tensors->clear(); *end_of_sequence = true; return Status::OK(); } poped = queue_->TryPop(item, 100); } if (!poped || item.end_of_sequence) { out_tensors->clear(); *end_of_sequence = true; } else { for (const auto &tensor : item.out_tensors) { out_tensors->push_back(tensor); } *end_of_sequence = item.end_of_sequence; } } return Status::OK(); } protected: std::shared_ptr CreateNode( IteratorContext *ctx, model::Node::Args args) const override { return model::MakeUnknownRatioNode(std::move(args)); } Status SaveInternal(SerializationContext *ctx, IteratorStateWriter *writer) override { return Status::OK(); } Status RestoreInternal(IteratorContext *ctx, IteratorStateReader *reader) override { return Status::OK(); } private: const std::shared_ptr mu_; const std::shared_ptr output_mu_; std::function deregister_fn_; std::unique_ptr cancellation_manager_; bool cancelled_ TF_GUARDED_BY(*mu_) = false; bool prefetch_thread_started_ TF_GUARDED_BY(*mu_) = false; bool prefetch_thread_finished_ TF_GUARDED_BY(*mu_) = false; uint32 data_flow_; std::string name_; QueueResource *queue_; std::unique_ptr input_impl_; std::unique_ptr prefetch_thread_; std::unordered_map df_to_queue_; Status EnsureThreadStarted(IteratorContext *ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!prefetch_thread_started_) { 
prefetch_thread_started_ = true; std::string name = dataset()->data_flows_[dataset()->index_]; std::shared_ptr new_ctx = std::make_shared(*ctx); prefetch_thread_ = ctx->StartThread( name, [new_ctx, name, this]() { PrefetchThread(new_ctx, name); }); } return Status::OK(); } void PrefetchThread(const std::shared_ptr &ctx, std::string name) { while (true) { { mutex_lock l(*mu_); if (cancelled_) { prefetch_thread_finished_ = true; break; } } if (!prefetch_thread_finished_) { Item item; input_impl_->GetNext(ctx.get(), &item.out_tensors, &item.end_of_sequence); if (item.end_of_sequence) { mutex_lock l(*mu_); item.end_of_sequence = true; if (!cancelled_ && !prefetch_thread_finished_) { for (auto kv : df_to_queue_) { kv.second->Push(item); } } break; } else { uint32 code; if (dataset()->variant_type_ == VariantType::PBInstance) { code = item.out_tensors[0] .scalar()() .get() ->data_source_key(); } else { code = item.out_tensors[0] .scalar()() .get() ->data_source_key(); } bool pushed = false; do { if (cancelled_ || prefetch_thread_finished_) { break; } pushed = df_to_queue_[code]->TryPush(item); } while (!pushed); } } else { break; } } } }; const DatasetBase *const input_; std::vector data_flows_; int index_; int max_queue_size_; VariantType variant_type_; std::string container_; }; SplitFlowDatasetOp::SplitFlowDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr(kDataFlow, &data_flows_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kIndex, &index_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kMaxQueueSize, &max_queue_size_)); std::string variant_type; OP_REQUIRES_OK(ctx, ctx->GetAttr(kVariantType, &variant_type)); if (variant_type == "instance") { variant_type_ = VariantType::PBInstance; } else if (variant_type == "example") { variant_type_ = VariantType::PBExample; } else { LOG(ERROR) << "invalid variant_type: " << variant_type; ctx->SetStatus(Status(tensorflow::error::Code::INVALID_ARGUMENT, "invalid variant_type")); } } void 
SplitFlowDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output) { *output = new Dataset(ctx, input, data_flows_, index_, max_queue_size_, variant_type_); std::string container; // OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "container", &container)); static_cast(*output)->SetContainer(""); } namespace { REGISTER_KERNEL_BUILDER(Name("SplitFlowDataset").Device(DEVICE_CPU), SplitFlowDatasetOp); } // namespace } // namespace monolith_tf } // namespace data } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/string_to_variant.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "idl/matrix/proto/example.pb.h"
#include "idl/matrix/proto/proto_parser.pb.h"
#include "monolith/native_training/data/kernels/internal/datasource_utils.h"
#include "monolith/native_training/data/training_instance/cc/data_reader.h"
#include "monolith/native_training/data/training_instance/cc/pb_variant.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/coding.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

using Example = ::monolith::io::proto::Example;
using ExampleBatch = ::monolith::io::proto::ExampleBatch;
using Instance = ::parser::proto::Instance;

// Extracts the serialized protobuf payload from one input record.
//
// With `has_header` set, the record is decoded through
// ZeroCopyStringViewStreamReader::ReadPBBytes using `options`; otherwise the
// whole record is treated as the payload.
class ReadHelper {
 public:
  explicit ReadHelper(DataFormatOptions options, bool has_header)
      : options_(options), has_header_(has_header) {}

  // Fills `pb_type`, `data_source_key` and `out` for the record `in`.
  // Without a header, `*pb_type` and `*data_source_key` are set to 0 and
  // `*out` aliases `in`. Returns the reader's error if header decoding fails.
  Status GetData(absl::string_view in, uint8_t* pb_type,
                 uint32_t* data_source_key, absl::string_view* out) {
    if (has_header_) {
      ZeroCopyStringViewStreamReader r(options_, in);
      TF_RETURN_IF_ERROR(r.ReadPBBytes(pb_type, data_source_key, out));
      return Status::OK();
    }
    *pb_type = 0;
    *data_source_key = 0;
    *out = in;
    return Status::OK();
  }

 private:
  DataFormatOptions options_;
  bool has_header_;
};

// Kernel that parses every element of input 0 into a protobuf message
// (Instance, Example or ExampleBatch, selected by the `input_type` attr) and
// stores it in an output tensor with the same shape as the input.
class StringToVariantOp : public OpKernel {
 public:
  using OpKernel::OpKernel;
  using ConstFlatSplits = typename TTypes::ConstFlat;

  // Reads and validates the kernel attributes. The header-related options are
  // only read when `has_header` is true.
  explicit StringToVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_type", &variant_type_));
    std::unordered_set variant_type_set_ = {
        "instance", "example", "examplebatch", "example_batch"};
    OP_REQUIRES(
        ctx, variant_type_set_.count(variant_type_) != 0,
        errors::InvalidArgument("variant_type can only be instance, example "
                                "and examplebatch/example_batch"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("has_header", &has_header_));
    if (has_header_) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("has_sort_id", &options_.has_sort_id));
      OP_REQUIRES_OK(
          ctx, ctx->GetAttr("lagrangex_header", &options_.lagrangex_header));
      OP_REQUIRES_OK(
          ctx, ctx->GetAttr("kafka_dump_prefix", &options_.kafka_dump_prefix));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("kafka_dump", &options_.kafka_dump));
    }

    // `chnids` and `datasources` are parallel lists: the i-th channel id maps
    // to the hash code of the i-th datasource name, shifted left by 8 bits.
    std::vector chnid_list;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("chnids", &chnid_list));
    std::vector datasource_list;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("datasources", &datasource_list));
    CHECK_EQ(chnid_list.size(), datasource_list.size());
    if (!chnid_list.empty()) {
      int i = 0;
      for (const std::string& sv : datasource_list) {
        uint32 code = internal::java_hash_code(sv);
        code = code << 8;
        chnid_to_code_.emplace(chnid_list.at(i), code);
        LOG(INFO) << "chnid: " << chnid_list.at(i) << ", code: " << code;
        i++;
      }
    }

    std::string default_datasource;
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("default_datasource", &default_datasource));
    uint32 default_code = internal::java_hash_code(default_datasource);
    default_code_ = (default_code << 8);
    LOG(INFO) << "default_code: " << default_code_;
  }

  // Parses each input element and writes the resulting message to the output
  // element at the same index. Empty payloads produce a default-constructed
  // message. A payload that fails to parse aborts the process via CHECK.
  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat();

    // Create an output tensor
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat();

    uint8_t pb_type;
    uint32_t data_source_key;
    ReadHelper reader(options_, has_header_);
    for (size_t i = 0; i < input.size(); ++i) {
      const tstring& buf = input(i);
      absl::string_view res;
      OP_REQUIRES_OK(context,
                     reader.GetData(buf, &pb_type, &data_source_key, &res));
      if (variant_type_ == "instance") {
        Instance pb;
        if (res.size() > 0) {
          CHECK(pb.ParseFromArray(res.data(), res.size()));
          UpdateDatasourceKey(pb.line_id().chnid(), &data_source_key);
          pb.set_data_source_key(data_source_key);
        }
        output_flat(i) = std::move(pb);
      } else if (variant_type_ == "example") {
        Example pb;
        if (res.size() > 0) {
          CHECK(pb.ParseFromArray(res.data(), res.size()));
          UpdateDatasourceKey(pb.line_id().chnid(), &data_source_key);
          pb.set_data_source_key(data_source_key);
        }
        output_flat(i) = std::move(pb);
      } else {
        // ExampleBatch keeps the key produced by the reader unchanged.
        ExampleBatch pb;
        if (res.size() > 0) {
          CHECK(pb.ParseFromArray(res.data(), res.size()));
          pb.set_data_source_key(data_source_key);
        }
        output_flat(i) = std::move(pb);
      }
    }
  }

 private:
  std::string variant_type_;
  bool has_header_ = false;
  DataFormatOptions options_;
  std::unordered_map chnid_to_code_;
  // Initialised so the member is never read indeterminate if the constructor
  // returns early from one of its OP_REQUIRES checks.
  uint32 default_code_ = 0;

  // Overrides `*data_source_key` from the channel id. The key is left alone
  // when it came from a lagrangex header; otherwise it becomes the code mapped
  // to `chnid`, or `default_code_` when there is no mapping.
  void UpdateDatasourceKey(const int64& chnid, uint32_t* data_source_key) {
    if (has_header_ && options_.lagrangex_header) {
      return;
    }
    // A single lookup; an empty map simply falls through to the default.
    const auto it = chnid_to_code_.find(chnid);
    if (it != chnid_to_code_.end()) {
      *data_source_key = it->second;
    } else {
      *data_source_key = default_code_;
    }
  }
};

// Like StringToVariantOp, but can additionally convert between formats:
// the `input_type` attr describes the serialized payload and `output_type`
// the message stored in the output. Producing ExampleBatch from a
// non-ExampleBatch input is rejected at construction.
class StringToVariantWithTransform : public OpKernel {
 public:
  // Reads and validates the kernel attributes. The header-related options are
  // only read when `has_header` is true.
  explicit StringToVariantWithTransform(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    std::string variant_type;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_type", &variant_type));
    input_type_ = data_format::StringToDataFormat(variant_type);
    LOG(INFO) << "input_type_:" << variant_type << "," << input_type_;
    std::string output_type;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_type", &output_type));
    output_type_ = data_format::StringToDataFormat(output_type);
    LOG(INFO) << "output_type_:" << output_type << "," << output_type_;
    OP_REQUIRES(
        ctx,
        (input_type_ == data_format::INSTANCE ||
         input_type_ == data_format::EXAMPLE ||
         input_type_ == data_format::EXAMPLEBATCH) &&
            (output_type_ == data_format::INSTANCE ||
             output_type_ == data_format::EXAMPLE ||
             output_type_ == data_format::EXAMPLEBATCH),
        errors::InvalidArgument("variant_type can only be instance, example "
                                "and examplebatch/example_batch"));
    OP_REQUIRES(ctx,
                !(input_type_ != data_format::EXAMPLEBATCH &&
                  output_type_ == data_format::EXAMPLEBATCH),
                errors::InvalidArgument(
                    "not support output examplebatch input not examplebatch"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("has_header", &has_header_));
    if (has_header_) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("has_sort_id", &options_.has_sort_id));
      OP_REQUIRES_OK(
          ctx, ctx->GetAttr("lagrangex_header", &options_.lagrangex_header));
      OP_REQUIRES_OK(
          ctx, ctx->GetAttr("kafka_dump_prefix", &options_.kafka_dump_prefix));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("kafka_dump", &options_.kafka_dump));
    }

    // `chnids` and `datasources` are parallel lists: the i-th channel id maps
    // to the hash code of the i-th datasource name, shifted left by 8 bits.
    std::vector chnid_list;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("chnids", &chnid_list));
    std::vector datasource_list;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("datasources", &datasource_list));
    CHECK_EQ(chnid_list.size(), datasource_list.size());
    if (!chnid_list.empty()) {
      int i = 0;
      for (const std::string& sv : datasource_list) {
        uint32 code = internal::java_hash_code(sv);
        code = code << 8;
        chnid_to_code_.emplace(chnid_list.at(i), code);
        LOG(INFO) << "chnid: " << chnid_list.at(i) << ", code: " << code;
        i++;
      }
    }

    std::string default_datasource;
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("default_datasource", &default_datasource));
    uint32 default_code = internal::java_hash_code(default_datasource);
    default_code_ = (default_code << 8);
    LOG(INFO) << "default_code: " << default_code_;
  }

  // Parses the input elements and emits messages of `output_type_`.
  //
  // ExampleBatch input with a non-ExampleBatch output is flattened: the output
  // is a 1-D tensor holding one message per batch entry, in input order. All
  // other combinations keep the input shape and convert element by element.
  // Conversion failures are logged as warnings and do not fail the op.
  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat();
    uint8_t pb_type;
    uint32_t data_source_key;
    ReadHelper reader(options_, has_header_);
    Status s;
    std::unique_ptr arena = std::make_unique();
    if (input_type_ == data_format::EXAMPLEBATCH &&
        output_type_ != data_format::EXAMPLEBATCH) {
      // First pass: parse every batch so the total output length is known
      // before the output tensor is allocated.
      std::vector eb_list;
      eb_list.reserve(input.size());
      int total_size = 0;
      for (size_t i = 0; i < input.size(); ++i) {
        const tstring& buf = input(i);
        absl::string_view res;
        OP_REQUIRES_OK(context,
                       reader.GetData(buf, &pb_type, &data_source_key, &res));
        auto pb = google::protobuf::Arena::CreateMessage(arena.get());
        if (res.size() > 0) {
          CHECK(pb->ParseFromArray(res.data(), res.size()));
          pb->set_data_source_key(data_source_key);
        }
        eb_list.push_back(pb);
        total_size += pb->batch_size();
      }

      // Create an output tensor
      Tensor* output_tensor = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, {total_size},
                                                       &output_tensor));
      auto output_flat = output_tensor->flat();
      LOG_EVERY_N_SEC(INFO, 60)
          << "trans pb size:" << input.size() << "->" << total_size;

      // Second pass: expand every batch entry into its own output element.
      int out_index = 0;
      for (auto pb : eb_list) {
        for (int index = 0; index < pb->batch_size(); ++index) {
          // The conversion is chosen by the requested *output* type. The input
          // type is always EXAMPLEBATCH in this branch, so testing it (as the
          // code previously did) could never select the Instance path.
          if (output_type_ == data_format::INSTANCE) {
            Instance inst;
            s = ExampleBatchToInstance(pb, index, &inst);
            output_flat(out_index++) = std::move(inst);
          } else {
            Example ep;
            s = ExampleBatchToExample(pb, index, &ep,
                                      FeaturePruningType::PRUNING_RAW_FEATURE,
                                      &fake_mapper_);
            output_flat(out_index++) = std::move(ep);
          }
          if (!s.ok()) {
            LOG(WARNING) << "Trans error:" << s;
          }
        }
      }
    } else {
      // Create an output tensor
      Tensor* output_tensor = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                       &output_tensor));
      auto output_flat = output_tensor->flat();
      for (size_t i = 0; i < input.size(); ++i) {
        const tstring& buf = input(i);
        absl::string_view res;
        OP_REQUIRES_OK(context,
                       reader.GetData(buf, &pb_type, &data_source_key, &res));
        // LOG(ERROR) << "xxx " << buf.size() << "," << res.size();
        if (input_type_ == data_format::INSTANCE) {
          Instance pb;
          if (res.size() > 0) {
            CHECK(pb.ParseFromArray(res.data(), res.size()));
            UpdateDatasourceKey(pb.line_id().chnid(), &data_source_key);
            pb.set_data_source_key(data_source_key);
          }
          if (output_type_ == data_format::INSTANCE) {
            output_flat(i) = std::move(pb);
          } else {
            Example eb_pb;
            s = InstanceToExample(&pb, &eb_pb);
            if (!s.ok()) {
              LOG(WARNING) << "Trans error:" << s;
            }
            output_flat(i) = std::move(eb_pb);
          }
        } else if (input_type_ == data_format::EXAMPLE) {
          Example pb;
          if (res.size() > 0) {
            CHECK(pb.ParseFromArray(res.data(), res.size()));
            UpdateDatasourceKey(pb.line_id().chnid(), &data_source_key);
            pb.set_data_source_key(data_source_key);
          }
          if (output_type_ == data_format::EXAMPLE) {
            output_flat(i) = std::move(pb);
          } else {
            Instance inst;
            s = ExampleToInstance(&pb, &inst);
            if (!s.ok()) {
              LOG(WARNING) << "Trans error:" << s;
            }
            output_flat(i) = std::move(inst);
          }
        } else {
          // ExampleBatch in, ExampleBatch out: no conversion needed.
          auto pb = google::protobuf::Arena::CreateMessage(arena.get());
          if (res.size() > 0) {
            CHECK(pb->ParseFromArray(res.data(), res.size()));
            pb->set_data_source_key(data_source_key);
          }
          output_flat(i) = std::move(*pb);
        }
      }
    }
  }

 private:
  data_format::DataFormat input_type_, output_type_;
  bool has_header_ = false;
  DataFormatOptions options_;
  std::unordered_map chnid_to_code_;
  // Initialised so the member is never read indeterminate if the constructor
  // returns early from one of its OP_REQUIRES checks.
  uint32 default_code_ = 0;
  FeatureNameMapper fake_mapper_;

  // Overrides `*data_source_key` from the channel id. The key is left alone
  // when it came from a lagrangex header; otherwise it becomes the code mapped
  // to `chnid`, or `default_code_` when there is no mapping.
  void UpdateDatasourceKey(const int64& chnid, uint32_t* data_source_key) {
    if (has_header_ && options_.lagrangex_header) {
      return;
    }
    // A single lookup; an empty map simply falls through to the default.
    const auto it = chnid_to_code_.find(chnid);
    if (it != chnid_to_code_.end()) {
      *data_source_key = it->second;
    } else {
      *data_source_key = default_code_;
    }
  }
};

// Kernel that emits an all-zero tensor with the same shape as input 0.
class VariantToZerosOp : public OpKernel {
 public:
  explicit VariantToZerosOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat();
    output_flat.setZero();
  }
};

// Kernel that reports whether the message held by a scalar input has a
// non-zero serialized size. The message type is chosen by the `variant_type`
// attr.
class HasVariantOp : public OpKernel {
 public:
  explicit HasVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("variant_type", &variant_type_));
    std::unordered_set variant_type_set_ = {
        "instance", "example", "examplebatch", "example_batch"};
    OP_REQUIRES(
        ctx, variant_type_set_.count(variant_type_) != 0,
        errors::InvalidArgument("variant_type can only be instance, example "
                                "and examplebatch/example_batch"));
  }

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_scalar = output_tensor->scalar();

    int byte_size = 0;
    if (variant_type_ == "instance") {
      const auto* instance = input_tensor.scalar()().get();
      byte_size = instance->ByteSize();
    } else if (variant_type_ == "example") {
      const auto* example = input_tensor.scalar()().get();
      byte_size = example->ByteSize();
    } else {
      const auto* example_batch = input_tensor.scalar()().get();
      byte_size = example_batch->ByteSize();
    }
    output_scalar() = byte_size > 0;
  }

 private:
  std::string variant_type_;
};

REGISTER_KERNEL_BUILDER(Name("StringToVariant").Device(DEVICE_CPU),
                        StringToVariantOp);
REGISTER_KERNEL_BUILDER(Name("StringToVariantWithTransform").Device(DEVICE_CPU),
                        StringToVariantWithTransform);
REGISTER_KERNEL_BUILDER(Name("VariantToZeros").Device(DEVICE_CPU),
                        VariantToZerosOp);
REGISTER_KERNEL_BUILDER(Name("HasVariant").Device(DEVICE_CPU), HasVariantOp);

}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/kernels/tf_example_to_example_kernel.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/data_op_config.pb.h" #include "monolith/native_training/data/training_instance/cc/fid.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { using ::monolith::io::proto::Example; using ::monolith::io::proto::Feature; using ::monolith::io::proto::NamedFeature; using ::monolith::native_training::data::config::TFRecordFeatureDescription; class TFExampleToExampleOp : public OpKernel { public: explicit TFExampleToExampleOp(OpKernelConstruction* context) : OpKernel(context) { std::string serialized; OP_REQUIRES_OK(context, context->GetAttr("feature_description", &serialized)); OP_REQUIRES(context, feature_description_.ParseFromString(serialized), errors::InvalidArgument("Corrupted data!")); LOG(INFO) << feature_description_.DebugString(); const auto& s = feature_description_.sparse_features(); const auto& d = feature_description_.dense_features(); absl::flat_hash_set slot_ids, duplicates; for (const auto& kv : s) { sparse_features_.insert(kv.first); auto ret = slot_ids.insert(kv.second); if (!ret.second) { duplicates.insert(kv.second); } } dense_features_.insert(d.begin(), d.end()); std::set intersection; std::set_intersection(sparse_features_.begin(), sparse_features_.end(), dense_features_.begin(), dense_features_.end(), std::inserter(intersection, intersection.begin())); OP_REQUIRES(context, intersection.empty(), errors::InvalidArgument(absl::StrFormat( "%s occur in sparse_features and dense_features " "simultaneously, please investigate and retry!", absl::StrJoin(intersection, ",")))); const auto& label = 
feature_description_.label(); const auto& instance_weight = feature_description_.instance_weight(); if (!label.empty()) { OP_REQUIRES(context, !sparse_features_.contains(label), errors::InvalidArgument(absl::StrFormat( "label: {%s} should NOT occur in sparse_features, " "please investigate and retry!", label))); OP_REQUIRES(context, !dense_features_.contains(label), errors::InvalidArgument(absl::StrFormat( "label: {%s} should NOT occur in dense_features, " "please investigate and retry!", label))); } if (!instance_weight.empty()) { OP_REQUIRES( context, !sparse_features_.contains(instance_weight), errors::InvalidArgument(absl::StrFormat( "instance_weight: {%s} should NOT occur in sparse_features, " "please investigate and retry!", instance_weight))); OP_REQUIRES( context, !dense_features_.contains(instance_weight), errors::InvalidArgument(absl::StrFormat( "instance_weight: {%s} should NOT occur in dense_features, " "please investigate and retry!", instance_weight))); } OP_REQUIRES(context, duplicates.empty(), errors::InvalidArgument( absl::StrFormat("{%s} have multiple sparse feature name " "mapping, please investigate and retry!", absl::StrJoin(duplicates, ",")))); } void Compute(OpKernelContext* context) override { const Tensor& input_tensor = context->input(0); const auto& serialized = input_tensor.scalar()(); google::protobuf::Arena arena; auto* tf_example = google::protobuf::Arena::CreateMessage(&arena); OP_REQUIRES(context, tf_example->ParseFromString(serialized), errors::DataLoss("Corrupted data!")); Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto* example = google::protobuf::Arena::CreateMessage(&arena); const auto& feature_map = tf_example->features().feature(); const auto& label_name = feature_description_.label(); if (!label_name.empty() && !feature_map.contains(label_name)) { LOG(ERROR) << "label_name: " << label_name << " doest not exist in 
tf.example.features.feature()!"; } example->set_instance_weight(1.f); const auto& m = feature_description_.sparse_features(); for (const auto& kv : feature_map) { const std::string& name = kv.first; const tensorflow::Feature& f = kv.second; // label if (name == feature_description_.label()) { example->mutable_label()->CopyFrom(f.float_list().value()); continue; } // instance_weight if (name == feature_description_.instance_weight()) { if (!f.has_float_list()) { LOG(ERROR) << absl::StrFormat( "instance_weight: %s does not have float list!", name); } else if (f.float_list().value_size() != 1) { LOG(ERROR) << absl::StrFormat( "instance_weight: %s value_size should be 1", name); } else { example->set_instance_weight(f.float_list().value(0)); } continue; } // sparse & dense if (!sparse_features_.contains(name) && !dense_features_.contains(name)) { continue; } NamedFeature* named_feature = example->add_named_feature(); named_feature->set_name(name); // TODO(zhangbiao.david): set_sorted_id()? // named_feature->set_sorted_id(); Feature* feature = named_feature->mutable_feature(); if (sparse_features_.contains(name)) { int32_t slot_id = m.at(name); named_feature->set_id(slot_id); std::vector fids; if (f.has_int64_list()) { fids.reserve(f.int64_list().value_size()); for (int64_t value : f.int64_list().value()) { fids.push_back(FIDV2(slot_id, value)); } } else if (f.has_float_list()) { fids.reserve(f.float_list().value_size()); for (float value : f.float_list().value()) { int64_t hash_value = CalcHashValue(value); fids.push_back(FIDV2(slot_id, hash_value)); } } else { LOG(ERROR) << "Only supports int64/float32 sparse features!"; } for (FIDV2 fid : fids) { feature->mutable_fid_v2_list()->mutable_value()->Add(fid); } } else if (dense_features_.contains(name)) { if (f.has_int64_list()) { feature->mutable_int64_list()->mutable_value()->CopyFrom( f.int64_list().value()); } else if (f.has_float_list()) { feature->mutable_float_list()->mutable_value()->CopyFrom( 
f.float_list().value()); } else if (f.has_bytes_list()) { feature->mutable_bytes_list()->mutable_value()->CopyFrom( f.bytes_list().value()); } } } output_tensor->scalar()() = std::move(*example); } private: int64_t CalcHashValue(float value) const { return static_cast(std::log2(std::abs(value) + 1)); } TFRecordFeatureDescription feature_description_; absl::flat_hash_set sparse_features_; absl::flat_hash_set dense_features_; }; namespace { REGISTER_KERNEL_BUILDER(Name("MonolithTFExampleToExample").Device(DEVICE_CPU), TFExampleToExampleOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/transform_dataset_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/kernels/transform_dataset_kernel.h" #include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/transform/cc/transforms.h" #include "monolith/native_training/runtime/common/linalg_utils.h" #include "third_party/nlohmann/json.hpp" namespace tensorflow { namespace data { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using EFeature = ::monolith::io::proto::Feature; using LineId = ::idl::matrix::proto::LineId; using Action = google::protobuf::RepeatedField; using ::monolith::common::IsAlmostEqual; using monolith::native_training::data::TransformConfig; using tensorflow::monolith_tf::NewTransformFromConfig; using tensorflow::monolith_tf::TransformInterface; // See documentation in ../../ops/dataset_ops.cc for a high-level // description of the following op. 
/* static */ constexpr const char *const TransformDatasetOp::kDatasetType; /* static */ constexpr const char *const TransformDatasetOp::kInputDataset; /* static */ constexpr const char *const TransformDatasetOp::kConfig; /* static */ constexpr const char *const TransformDatasetOp::kVariantType; class TransformDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext *ctx, const DatasetBase *input, std::string config_serialized, std::string variant_type) : DatasetBase(DatasetContext(ctx)), input_(input), config_serialized_(std::move(config_serialized)), variant_type_(std::move(variant_type)) { input_->Ref(); OP_REQUIRES(ctx, config_.ParseFromString(config_serialized_), errors::InvalidArgument("Unable to parse config. Make sure it " "is serialized version of " "TransformConfig.")); transform_ = NewTransformFromConfig(config_); } ~Dataset() override { input_->Unref(); } std::unique_ptr MakeIteratorInternal( const string &prefix) const override { return absl::make_unique( Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)}); } const DataTypeVector &output_dtypes() const override { return input_->output_dtypes(); } const std::vector &output_shapes() const override { return input_->output_shapes(); } string DebugString() const override { return "This is the customized Dataset: Mixup"; } Status InputDatasets( std::vector *inputs) const override { inputs->push_back(input_); return Status::OK(); } Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: Status AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b, Node **output) const override { Node *input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); AttrValue config_node; b->BuildAttrValue(config_serialized_, &config_node); AttrValue variant_type_node; b->BuildAttrValue(variant_type_, &variant_type_node); TF_RETURN_IF_ERROR(b->AddDataset( this, {input_graph_node}, {{kConfig, config_node}, 
{kVariantType, variant_type_node}}, output)); return Status::OK(); } private: class Iterator : public DatasetIterator { public: explicit Iterator(const Params ¶ms) : DatasetIterator(params) {} Status Initialize(IteratorContext *ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) override { out_tensors->clear(); out_tensors->reserve(1); tensorflow::mutex_lock l(mu_); Status status; if (dataset()->variant_type_ == "instance") { status = NextInternalImpl(ctx, out_tensors, end_of_sequence); } else { status = NextInternalImpl(ctx, out_tensors, end_of_sequence); } return status; } template Status NextInternalImpl(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) { while (!*end_of_sequence) { std::vector batch_variant; TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, &batch_variant, end_of_sequence)); if (!*end_of_sequence) { T *instance_or_example = GetCurrent(&batch_variant.back()); std::shared_ptr instance_or_example_ptr; instance_or_example_ptr.reset(instance_or_example, [](...) 
{}); std::vector> instance_or_example_list; dataset()->transform_->Transform(instance_or_example_ptr, &instance_or_example_list); if (!instance_or_example_list.empty()) { CHECK_EQ(instance_or_example_list.size(), 1); out_tensors->push_back(batch_variant.back()); return Status::OK(); } } } return Status::OK(); } protected: std::shared_ptr CreateNode( IteratorContext *ctx, model::Node::Args args) const override { return model::MakeUnknownRatioNode(std::move(args)); } Status SaveInternal(SerializationContext *ctx, IteratorStateWriter *writer) override { return Status::OK(); } Status RestoreInternal(IteratorContext *ctx, IteratorStateReader *reader) override { return Status::OK(); } private: template inline T *GetCurrent(Tensor *t) { Variant *variant = &t->scalar()(); return variant->get(); } tensorflow::mutex mu_; std::unique_ptr input_impl_ TF_GUARDED_BY(mu_); }; const DatasetBase *const input_; std::string config_serialized_; TransformConfig config_; std::string variant_type_; std::unique_ptr transform_; }; TransformDatasetOp::TransformDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) { std::string config_serialized; OP_REQUIRES_OK(ctx, ctx->GetAttr(kConfig, &config_serialized)); OP_REQUIRES(ctx, config_.ParseFromString(config_serialized), errors::InvalidArgument("Unable to parse config. 
Make sure it " "is serialized version of " "TransformConfig.")); OP_REQUIRES_OK(ctx, ctx->GetAttr(kVariantType, &variant_type_)); LOG(INFO) << "variant_type: " << variant_type_ << ", config: \n" << config_.DebugString(); } void TransformDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output) { *output = new Dataset(ctx, input, config_.SerializeAsString(), variant_type_); } namespace { REGISTER_KERNEL_BUILDER(Name("TransformDataset").Device(DEVICE_CPU), TransformDatasetOp) } // namespace } // namespace monolith_tf } // namespace data } // namespace tensorflow ================================================ FILE: monolith/native_training/data/kernels/transform_dataset_kernel.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_TRANSFORM_DATASET_KERNEL_H_
#define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_TRANSFORM_DATASET_KERNEL_H_

#include "tensorflow/core/framework/dataset.h"

#include "monolith/native_training/data/transform/transform_config.pb.h"

namespace tensorflow {
namespace data {
namespace monolith_tf {

// Unary dataset op kernel configured by a TransformConfig and a variant type
// string. The produced dataset is the nested `Dataset` class, which is only
// declared here.
class TransformDatasetOp : public UnaryDatasetOpKernel {
 public:
  // Dataset type name and the names of the op's input and attrs.
  static constexpr const char* const kDatasetType = "transform";
  static constexpr const char* const kInputDataset = "input_dataset";
  static constexpr const char* const kConfig = "config";
  static constexpr const char* const kVariantType = "variant_type";

  explicit TransformDatasetOp(OpKernelConstruction* ctx);

 protected:
  // Stores the newly created dataset in `*output`.
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override;

 private:
  class Dataset;
  // Kernel configuration held for use by MakeDataset.
  std::string variant_type_;
  monolith::native_training::data::TransformConfig config_;
};

}  // namespace monolith_tf
}  // namespace data
}  // namespace tensorflow
#endif  // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_KERNELS_TRANSFORM_DATASET_KERNEL_H_

================================================
FILE: monolith/native_training/data/kernels/variant_filter_kernel.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "monolith/native_training/data/kernels/feature_name_mapper_tf_bridge.h" #include "monolith/native_training/data/kernels/internal/relational_utils.h" #include "monolith/native_training/data/kernels/internal/value_filter_by_feature.h" #include "monolith/native_training/data/kernels/internal/value_filter_by_line_id.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant.h" #include "third_party/nlohmann/json.hpp" namespace tensorflow { namespace monolith_tf { using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using Example = ::monolith::io::proto::Example; using LineId = ::idl::matrix::proto::LineId; using tensorflow::monolith_tf::internal::LineIdValueFilter; using tensorflow::monolith_tf::internal::FeatureValueFilter; class SetFilterOp : public OpKernel { public: explicit SetFilterOp(OpKernelConstruction *context) : OpKernel(context) { std::vector filter_fids; OP_REQUIRES_OK(context, context->GetAttr("filter_fids", &filter_fids)); filter_fids_.insert(filter_fids.begin(), filter_fids.end()); std::vector has_fids; OP_REQUIRES_OK(context, context->GetAttr("has_fids", &has_fids)); has_fids_.insert(has_fids.begin(), has_fids.end()); std::vector select_fids; OP_REQUIRES_OK(context, context->GetAttr("select_fids", &select_fids)); select_fids_.insert(select_fids.begin(), select_fids.end()); std::vector has_actions; OP_REQUIRES_OK(context, context->GetAttr("has_actions", &has_actions)); has_actions_.insert(has_actions.begin(), has_actions.end()); 
OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); OP_REQUIRES_OK(context, context->GetAttr("req_time_min", &req_time_min_)); std::vector select_slots; OP_REQUIRES_OK(context, context->GetAttr("select_slots", &select_slots)); for (int32_t slot : select_slots) { CHECK_GE(slot, 0); } select_slots_.insert(select_slots.begin(), select_slots.end()); auto creator = [this](FeatureNameMapperTfBridge **out_mapper) { TF_RETURN_IF_ERROR(FeatureNameMapperTfBridge::New(out_mapper)); return Status::OK(); }; ResourceMgr *resource_mgr = context->resource_manager(); OP_REQUIRES_OK(context, resource_mgr->LookupOrCreate( resource_mgr->default_container(), FeatureNameMapperTfBridge::kName, &mapper_, creator)); if (variant_type_ == "example") { std::vector> valid_ids; for (uint32_t slot : select_slots_) { valid_ids.emplace_back(slot, slot); } for (uint64_t fid : filter_fids_) { valid_ids.emplace_back(slot_id_v1(fid), slot_id_v2(fid)); } for (uint64_t fid : has_fids_) { valid_ids.emplace_back(slot_id_v1(fid), slot_id_v2(fid)); } for (uint64_t fid : select_fids_) { valid_ids.emplace_back(slot_id_v1(fid), slot_id_v2(fid)); } OP_REQUIRES_OK(context, mapper_->RegisterValidIds(valid_ids)); } } ~SetFilterOp() override { mapper_->Unref(); } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->scalar(); output() = IsInstanceOfInterest(input_tensor); } private: bool IsInstanceOfInterest(const Tensor &input_tensor) { auto input = input_tensor.scalar(); if (variant_type_ == "instance") { const auto *instance = input().get(); return monolith_tf::IsInstanceOfInterest( *instance, filter_fids_, has_fids_, select_fids_, has_actions_, req_time_min_, select_slots_); } else { const auto *example = input().get(); return monolith_tf::IsInstanceOfInterest( *example, filter_fids_, 
has_fids_, select_fids_, has_actions_, req_time_min_, select_slots_); } } std::set filter_fids_; std::set has_fids_; std::set select_fids_; std::set has_actions_; std::string variant_type_ = "instance"; int req_time_min_; std::set select_slots_; FeatureNameMapperTfBridge *mapper_; }; class FeatureValueFilterOp : public OpKernel { public: explicit FeatureValueFilterOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("field_name", &field_name_)); OP_REQUIRES_OK(context, context->GetAttr("op", &op_)); OP_REQUIRES_OK(context, context->GetAttr("float_operand", &float_operand_)); OP_REQUIRES_OK(context, context->GetAttr("int_operand", &int_operand_)); OP_REQUIRES_OK(context, context->GetAttr("string_operand", &string_operand_)); OP_REQUIRES_OK(context, context->GetAttr("operand_filepath", &operand_filepath_)); OP_REQUIRES_OK(context, context->GetAttr("keep_empty", &keep_empty_)); OP_REQUIRES_OK(context, context->GetAttr("field_type", &field_type_)); OP_REQUIRES(context, field_type_ == "int64" || field_type_ == "float" || field_type_ == "double" || field_type_ == "bytes", errors::Unknown( "field_type unknown! 
need to be int64/float/double/bytes")); feature_value_filter_ = std::make_unique( field_name_, field_type_, op_, float_operand_, int_operand_, string_operand_, operand_filepath_, keep_empty_); } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); const Variant &variant = input_tensor.scalar()(); OP_REQUIRES(context, variant.TypeId() == TypeIndex::Make(), errors::InvalidArgument("input must be Example proto")); Tensor *output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->scalar(); // only support Example input output() = feature_value_filter_->IsInstanceOfInterest( context->env(), *(input_tensor.scalar()().get())); } private: std::string field_name_; std::string op_; // gt, ge, eq, lt, le, neq, between bool keep_empty_ = false; std::string operand_filepath_; std::vector float_operand_; std::vector int_operand_; std::vector string_operand_; std::unique_ptr feature_value_filter_; std::string field_type_; }; class ValueFilterOp : public OpKernel { public: explicit ValueFilterOp(OpKernelConstruction *context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("field_name", &field_name_)); OP_REQUIRES_OK(context, context->GetAttr("op", &op_)); OP_REQUIRES_OK(context, context->GetAttr("float_operand", &float_operand_)); OP_REQUIRES_OK(context, context->GetAttr("int_operand", &int_operand_)); OP_REQUIRES_OK(context, context->GetAttr("string_operand", &string_operand_)); OP_REQUIRES_OK(context, context->GetAttr("operand_filepath", &operand_filepath_)); OP_REQUIRES_OK(context, context->GetAttr("keep_empty", &keep_empty_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); line_id_value_filter_ = std::make_unique( field_name_, op_, float_operand_, int_operand_, string_operand_, operand_filepath_, keep_empty_); } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = 
context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->scalar(); const LineId &line_id = GetLineId(input_tensor); output() = line_id_value_filter_->IsInstanceOfInterest(context->env(), line_id); } private: const LineId &GetLineId(const Tensor &input_tensor) { if (variant_type_ == "instance") { return input_tensor.scalar()().get()->line_id(); } else { return input_tensor.scalar()().get()->line_id(); } } std::string field_name_; std::string op_; // gt, ge, eq, lt, le, neq, between bool keep_empty_ = false; std::string operand_filepath_; std::vector float_operand_; std::vector int_operand_; std::vector string_operand_; std::unique_ptr line_id_value_filter_; std::string variant_type_; }; class SpecialStrategyOp : public OpKernel { public: explicit SpecialStrategyOp(OpKernelConstruction *context) : OpKernel(context) { std::vector special_strategy; OP_REQUIRES_OK(context, context->GetAttr("special_strategies", &special_strategy)); std::vector sample_rate; OP_REQUIRES_OK(context, context->GetAttr("sample_rates", &sample_rate)); std::vector label; OP_REQUIRES_OK(context, context->GetAttr("labels", &label)); OP_REQUIRES( context, special_strategy.size() == sample_rate.size(), errors::InvalidArgument( "length of sample_rates must identity with special_strategies")); OP_REQUIRES( context, special_strategy.size() == label.size() || label.size() == 0, errors::InvalidArgument( "length of labels must identity with special_strategies or zero")); for (size_t i = 0; i < special_strategy.size(); ++i) { strategy_to_rate_.emplace(special_strategy[i], sample_rate[i]); } if (label.size() > 0) { for (size_t i = 0; i < special_strategy.size(); ++i) { strategy_to_label_.emplace(special_strategy[i], label[i]); } } OP_REQUIRES_OK(context, context->GetAttr("strategy_list", &strategy_list_)); OP_REQUIRES(context, strategy_list_.size() > 0, errors::InvalidArgument("strategy_list 
cannot be empty")); OP_REQUIRES_OK( context, context->GetAttr("keep_empty_strategy", &keep_empty_strategy_)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); } void Compute(OpKernelContext *context) override { Tensor *input_tensor = const_cast(&(context->input(0))); Tensor *output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, input_tensor->shape(), &output_tensor)); auto output = output_tensor->scalar(); output() = DoCompute(input_tensor); } private: const LineId &GetLineId(Tensor *input_tensor) { if (variant_type_ == "instance") { return input_tensor->scalar()().get()->line_id(); } else { return input_tensor->scalar()().get()->line_id(); } } float *GetLabel(Tensor *input_tensor, int index = 0) { if (variant_type_ == "instance") { return input_tensor->scalar()() .get() ->mutable_label() ->Mutable(index); } else { return input_tensor->scalar()() .get() ->mutable_label() ->Mutable(index); } } bool DoCompute(Tensor *input_tensor) { const LineId &line_id = GetLineId(input_tensor); const auto &strategies = line_id.special_strategies(); if (strategies.size() > 1) { LOG(INFO) << "Size of special_strategies is bigger than one, pls. 
check!"; } if (strategies.size() == 0) { // for unknow samples, drop if (keep_empty_strategy_) { // for special_strategies_neg_ins_keep_normal return true; } else { return false; } } else { for (auto &special_strategy : strategy_list_) { auto found = std::find(strategies.begin(), strategies.end(), special_strategy); if (found != strategies.end()) { auto rit = strategy_to_rate_.find(special_strategy); if (rit != strategy_to_rate_.end()) { float rate = rit->second; bool flag = false; if (rate == 1.0) { flag = true; } else { if (random_neg_sample_(generator_) <= rate) { flag = true; } } if (strategy_to_label_.size() > 0 && flag) { auto lit = strategy_to_label_.find(special_strategy); if (lit != strategy_to_label_.end()) { float new_label = lit->second; float *old_label = GetLabel(input_tensor); *old_label = new_label; } } return flag; } else { return true; } } } } } std::default_random_engine generator_; std::uniform_real_distribution random_neg_sample_; bool keep_empty_strategy_ = true; std::unordered_map strategy_to_rate_; std::unordered_map strategy_to_label_; std::vector strategy_list_; std::string variant_type_; }; class NegativeSampleOp : public OpKernel { public: explicit NegativeSampleOp(OpKernelConstruction *context) : OpKernel(context) { std::vector priorities; std::vector actions; std::vector per_action_drop_rate; OP_REQUIRES_OK(context, context->GetAttr("drop_rate", &drop_rate_)); OP_REQUIRES_OK(context, context->GetAttr("label_index", &label_index_)); OP_REQUIRES_OK(context, context->GetAttr("threshold", &threshold_)); OP_REQUIRES_OK(context, context->GetAttr("priorities", &priorities)); OP_REQUIRES_OK(context, context->GetAttr("actions", &actions)); OP_REQUIRES_OK(context, context->GetAttr("per_action_drop_rate", &per_action_drop_rate)); OP_REQUIRES_OK(context, context->GetAttr("variant_type", &variant_type_)); OP_REQUIRES(context, actions.size() == per_action_drop_rate.size(), errors::Unknown("internal error")); for (size_t i = 0; i < actions.size(); 
i++) { action_drop_rate_map_.emplace(actions[i], per_action_drop_rate[i]); } for (size_t i = 0; i < priorities.size(); i++) { action_priorities_map_.emplace(priorities[i], i); } if (actions.size() > 0) { enable_drop_by_action_ = true; } } void Compute(OpKernelContext *context) override { const Tensor &input_tensor = context->input(0); Tensor *output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->scalar(); float label = GetLabel(input_tensor); if (label < threshold_) { float sample_drop_rate = drop_rate_; if (enable_drop_by_action_) { sample_drop_rate = GetNegDropRate(input_tensor); } thread_local std::mt19937 gen((std::random_device())()); float random = gen() % 1000 / 1000.0; output() = random < sample_drop_rate ? false : true; } else { output() = true; } } private: float drop_rate_ = 0.0; int label_index_ = 0; float threshold_ = 0.0; bool enable_drop_by_action_ = false; std::unordered_map action_priorities_map_; std::unordered_map action_drop_rate_map_; std::string variant_type_ = "instance"; float GetLabel(const Tensor &input_tensor) { auto input = input_tensor.scalar(); if (variant_type_ == "instance") { const Instance *instance = input().get(); return instance->label(label_index_); } else { const Example *example = input().get(); return example->label(label_index_); } return 0; } const LineId *GetLineId(const Tensor &input_tensor) { auto input = input_tensor.scalar(); if (variant_type_ == "instance") { const Instance *instance = input().get(); return &instance->line_id(); } else { const Example *example = input().get(); return &example->line_id(); } } int FindMostPriorAction(const Tensor &input_tensor) { const LineId *line_id = GetLineId(input_tensor); CHECK(line_id != nullptr); int most_prior = INT_MAX; int record_action = -1; for (int action : line_id->actions()) { auto it = action_priorities_map_.find(action); if (it != action_priorities_map_.end() && it->second < 
most_prior) { most_prior = it->second; record_action = action; } } return record_action; } float GetNegDropRate(const Tensor &input_tensor) { int prior_action = FindMostPriorAction(input_tensor); if (prior_action > 0) { auto it = action_drop_rate_map_.find(prior_action); if (it != action_drop_rate_map_.end()) { return it->second; } } return drop_rate_; } }; namespace { REGISTER_KERNEL_BUILDER(Name("SetFilter").Device(DEVICE_CPU), SetFilterOp); REGISTER_KERNEL_BUILDER(Name("FeatureValueFilter").Device(DEVICE_CPU), FeatureValueFilterOp); REGISTER_KERNEL_BUILDER(Name("ValueFilter").Device(DEVICE_CPU), ValueFilterOp); REGISTER_KERNEL_BUILDER(Name("SpecialStrategy").Device(DEVICE_CPU), SpecialStrategyOp); REGISTER_KERNEL_BUILDER(Name("NegativeSample").Device(DEVICE_CPU), NegativeSampleOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/multi_flow_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging import os import getpass import random import numpy as np import tensorflow as tf from struct import pack, unpack from datetime import datetime, timedelta from idl.matrix.proto.proto_parser_pb2 import Instance from monolith.native_training.data.parsers import parse_instances from monolith.native_training.data.datasets import PBDataset, PbType uids = [674432, 9754221, 7665435, 98797865, 778754432] item_ids = [8767554565, 574220985, 65548979, 5358521231] actions = [1, 2] device_types = ['pc', 'mobile', 'cloud'] slots = [1, 200, 5, 7, 9] NUM_INSTANCE = 4096 MODEL_DIR = os.path.join(os.environ["TEST_TMPDIR"], 'model_dir', 'multi_flow') class MultiFlowTest(tf.test.TestCase): @classmethod def setUpClass(cls): mask = (1 << 54) - 1 start = int(datetime.now().timestamp()) stop = int((datetime.now() + timedelta(days=1)).timestamp()) if not tf.io.gfile.exists(MODEL_DIR): tf.io.gfile.makedirs(MODEL_DIR) ofile = os.path.join(MODEL_DIR, 'data.pb') print(ofile, flush=True) if not tf.io.gfile.exists(ofile): with tf.io.gfile.GFile(ofile, 'wb') as ostream: for _ in range(NUM_INSTANCE): inst = Instance() for slot in slots: h = random.randrange(start, stop) fid = (slot << 54) | (h & mask) inst.fid.append(fid) line_id = inst.line_id line_id.uid = random.choice(uids) line_id.item_id = random.choice(item_ids) line_id.req_time = random.randrange(start, stop) line_id.device_type = random.choice(device_types) line_id.actions.append(random.choice(actions)) lgx_header = cls.mk_kgx_header(dataflow=line_id.device_type) data = inst.SerializeToString() ostream.write(file_content=lgx_header) ostream.write(file_content=pack(f' 0.5: line_id.actions.append(choice(pos_acts)) label.append(1) cls.cid_status[cid]['p'] += 1 else: line_id.actions.append(choice(neg_acts)) label.append(-1) cls.cid_status[cid]['n'] += 1 if variant_type == 'example': label_nf = sample.named_feature.add() label_nf.name = '__LABEL__' label_nf.feature.float_list.value.extend(label) lid = 
sample.named_feature.add() lid.name = '__LINE_ID__' lid.feature.bytes_list.value.append(line_id.SerializeToString()) else: sample.label.extend(label) sample.line_id.CopyFrom(line_id) es = sample.SerializeToString() ostream.write(pack(' 0 neg = element_out['label'] < 0 pos_cnt += np.sum(pos) neg_cnt += np.sum(neg) count += element_out['label'].shape[0] for cid in self.cid_status: select_channel = element_out[channel_feature_name] == cid np.sum(np.logical_and(select_channel, pos)) np.sum(np.logical_and(select_channel, neg)) except tf.errors.OutOfRangeError: break self.assertEqual(count, pos_cnt + neg_cnt) expect_pos, expect_neg = 0, 0 for pn_dict in self.cid_status.values(): expect_pos += pn_dict['p'] expect_neg += pn_dict['n'] self.assertEqual(expect_pos + expect_neg, num_sample) if not throw_origin: self.assertEqual(pos_cnt, expect_pos) else: self.assertEqual(pos_cnt, 0) if not throw_origin and not throw_origin_neg: if per_channel: pass else: min_gen = (expect_pos - start_num) * neg_num max_gen = expect_pos * neg_num real_gen = count - num_sample self.assertTrue(min_gen <= real_gen <= max_gen) print(count, pos_cnt, neg_cnt, expect_pos, expect_neg, flush=True) logging.info("The number of batch is: {}".format(count)) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/data/ops/feature_utils_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

// Op interface registrations only: inputs, outputs, attributes and (where
// given) a shape function. Ops registered without SetShapeFn carry no shape
// information here.

// int64 in, int64 out, parameterized by a slot number.
REGISTER_OP("ExtractFid")
    .Input("input: int64")
    .Attr("slot: int")
    .Output("output: int64");

// Variant -> variant; output shape equals input shape.
REGISTER_OP("FeatureHash")
    .Input("input: variant")
    .Attr("names: list(string)")
    .Output("output: variant")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Variant -> bool predicate configured by fid / action / slot lists.
REGISTER_OP("SetFilter")
    .Input("input: variant")
    .Attr("filter_fids: list(int)")
    .Attr("has_fids: list(int)")
    .Attr("select_fids: list(int)")
    .Attr("has_actions: list(int)")
    .Attr("req_time_min: int")
    .Attr("select_slots: list(int)")
    .Attr("variant_type: string")
    .Output("output: bool");

// Variant -> bool predicate on a named field; keep_empty defaults to false.
REGISTER_OP("FeatureValueFilter")
    .Input("input: variant")
    .Attr("field_name: string")
    .Attr("op: string")
    .Attr("float_operand: list(float)")
    .Attr("int_operand: list(int)")
    .Attr("string_operand: list(string)")
    .Attr("operand_filepath: string")
    .Attr("field_type: string")
    .Attr("keep_empty: bool = false")
    .Output("output: bool");

// Same attribute set as FeatureValueFilter except that it takes variant_type
// instead of field_type.
REGISTER_OP("ValueFilter")
    .Input("input: variant")
    .Attr("field_name: string")
    .Attr("op: string")
    .Attr("float_operand: list(float)")
    .Attr("int_operand: list(int)")
    .Attr("string_operand: list(string)")
    .Attr("operand_filepath: string")
    .Attr("keep_empty: bool = false")
    .Attr("variant_type: string")
    .Output("output: bool");

// Variant -> variant; output shape equals input shape.
REGISTER_OP("AddAction")
    .Input("input: variant")
    .Attr("field_name: string")
    .Attr("op: string")
    .Attr("float_operand: list(float)")
    .Attr("int_operand: list(int)")
    .Attr("string_operand: list(string)")
    .Attr("variant_type: string")
    .Attr("actions: list(int)")
    .Output("output: variant")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Variant -> variant; output shape equals input shape.
REGISTER_OP("AddLabel")
    .Input("input: variant")
    .Attr("config: string")
    .Attr("negative_value: float")
    .Attr("sample_rate: float")
    .Attr("variant_type: string")
    .Output("output: variant")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// String -> variant; output shape equals input shape.
REGISTER_OP("MonolithTFExampleToExample")
    .Input("input: string")
    .Attr("feature_description: string")
    .Output("output: variant")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Variant -> variant; output shape equals input shape.
REGISTER_OP("ScatterLabel")
    .Input("input: variant")
    .Attr("config: string")
    .Attr("variant_type: string")
    .Output("output: variant")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Variant -> bool; note the output is named "valid", not "output".
REGISTER_OP("FilterByLabel")
    .Input("input: variant")
    .Attr("label_threshold: list(float)")
    .Attr("filter_equal: bool")
    .Attr("variant_type: string")
    .Output("valid: bool");

// Variant -> bool; keep_empty_strategy defaults to true.
REGISTER_OP("SpecialStrategy")
    .Input("input: variant")
    .Attr("special_strategies: list(int)")
    .Attr("sample_rates: list(float)")
    .Attr("labels: list(float)")
    .Attr("strategy_list: list(int)")
    .Attr("keep_empty_strategy: bool = true")
    .Attr("variant_type: string")
    .Output("output: bool");

// Variant -> bool; label_index defaults to 0 and threshold to 0.0.
REGISTER_OP("NegativeSample")
    .Input("input: variant")
    .Attr("drop_rate: float")
    .Attr("label_index: int = 0")
    .Attr("threshold: float = 0.0")
    .Attr("variant_type: string")
    .Attr("priorities: list(int)")
    .Attr("actions: list(int)")
    .Attr("per_action_drop_rate: list(float)")
    .Output("output: bool");

// Variant -> variant.
REGISTER_OP("LabelUpperBound")
    .Input("input: variant")
    .Attr("label_upper_bounds: list(float)")
    .Attr("variant_type: string")
    .Output("output: variant");

// Variant -> variant.
REGISTER_OP("LabelNormalization")
    .Input("input: variant")
    .Attr("norm_methods: list(string)")
    .Attr("norm_values: list(float)")
    .Attr("variant_type: string")
    .Output("output: variant");

// Variant -> variant.
REGISTER_OP("UseFieldAsLabel")
    .Input("input: variant")
    .Attr("field_name: string")
    .Attr("overwrite_invalid_value: bool")
    .Attr("label_threshold: float")
    .Attr("variant_type: string")
    .Output("output: variant");

// RAGGED_RANK split tensors plus one dense-values tensor in, the same layout
// out; every output takes the shape of the input at the same index.
REGISTER_OP("SwitchSlot")
    .Input("rt_nested_splits: RAGGED_RANK * int64")
    .Input("rt_dense_values: int64")
    .Output("nested_splits_out: RAGGED_RANK * int64")
    .Output("dense_values_out: int64")
    .Attr("slot: int >=1")
    .Attr("fid_version: int")
    .Attr("RAGGED_RANK: int >= 1")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      int rank;
      TF_RETURN_IF_ERROR(ctx->GetAttr("RAGGED_RANK", &rank));
      for (int i = 0; i < rank; ++i) {
        ctx->set_output(i, ctx->input(i));
      }
      ctx->set_output(rank, ctx->input(rank));
      return Status::OK();
    });

// Variant -> variant; output shape equals input shape.
REGISTER_OP("SwitchSlotBatch")
    .Input("pb_input: variant")
    .Output("pb_output: variant")
    .Attr("features: list(string)")
    .Attr("slots: list(int)")
    .Attr("inplaces: list(bool)")
    .Attr("suffix: string")
    .Attr("variant_type: string")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Two ragged inputs (splits + values each) in, one ragged output. The shape
// function copies the shapes of inputs 0..RAGGED_RANK, i.e. the first
// source's splits followed by its dense values.
REGISTER_OP("FeatureCombine")
    .Input("rt_nested_splits_src1: RAGGED_RANK * int64")
    .Input("rt_dense_values_src1: int64")
    .Input("rt_nested_splits_src2: RAGGED_RANK * int64")
    .Input("rt_dense_values_src2: int64")
    .Output("nested_splits_sink: RAGGED_RANK * int64")
    .Output("dense_values_sink: int64")
    .Attr("slot: int >=1")
    .Attr("fid_version: int")
    .Attr("RAGGED_RANK: int >= 1")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      int rank;
      TF_RETURN_IF_ERROR(ctx->GetAttr("RAGGED_RANK", &rank));
      for (int i = 0; i < rank; ++i) {
        ctx->set_output(i, ctx->input(i));
      }
      ctx->set_output(rank, ctx->input(rank));
      return Status::OK();
    });

// No inputs; produces a scalar resource handle.
REGISTER_OP("ItemPoolCreate")
    .Output("pool: resource")
    .Attr("start_num: int")
    .Attr("max_item_num_per_channel: int")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    // NOTE(review): the flattened source reads `// .SetIsStateful()` at this
    // point; it is kept as a commented-out call. Confirm against the
    // original file that the call is not meant to be active.
    // .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Resource in, scalar resource out.
REGISTER_OP("ItemPoolRandomFill")
    .Input("ipool: resource")
    .Output("opool: resource")
    .SetShapeFn(shape_inference::ScalarShape);
// Resource and global step in, scalar resource out.
REGISTER_OP("ItemPoolCheck")
    .Input("ipool: resource")
    .Input("global_step: int64")
    .Output("opool: resource")
    .Attr("model_path: string")
    .Attr("nshards: int")
    .Attr("buffer_size: int")
    .SetShapeFn(shape_inference::ScalarShape);

// Resource and global step in, scalar resource out; random_sleep_ms
// defaults to 0.
REGISTER_OP("ItemPoolSave")
    .Input("ipool: resource")
    .Input("global_step: int64")
    .Output("opool: resource")
    .Attr("model_path: string")
    .Attr("nshards: int")
    .Attr("random_sleep_ms: int=0")
    // NOTE(review): the flattened source reads `// .SetIsStateful()` at this
    // point; it is kept as a commented-out call. Confirm against the
    // original file that the call is not meant to be active.
    // .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Resource and global step in, scalar resource out; random_sleep_ms
// defaults to 0.
REGISTER_OP("ItemPoolRestore")
    .Input("ipool: resource")
    .Input("global_step: int64")
    .Output("opool: resource")
    .Attr("model_path: string")
    .Attr("buffer_size: int")
    .Attr("nshards: int")
    .Attr("random_sleep_ms: int=0")
    // NOTE(review): same ambiguity as in ItemPoolSave -- kept as a
    // commented-out call; confirm against the original file.
    // .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Variant -> scalar variant; all enable_* flags default to false and
// rank_num to 18.
REGISTER_OP("FillMultiRankOutput")
    .Input("input: variant")
    .Attr("variant_type: string")
    .Attr("enable_draw_as_rank: bool = false")
    .Attr("enable_chnid_as_rank: bool = false")
    .Attr("enable_lineid_rank_as_rank: bool = false")
    .Attr("rank_num: int = 18")
    .Output("output: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Variant -> scalar variant.
REGISTER_OP("UseF100MultiHead")
    .Input("input: variant")
    .Attr("variant_type: string")
    .Output("output: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// int32/int64 tensor -> tensor of the same type and shape, configured by
// parallel from_value / to_value lists and a default_value.
REGISTER_OP("MapId")
    .Input("input: T")
    .Attr("from_value: list(int)")
    .Attr("to_value: list(int)")
    .Attr("default_value: int")
    .Output("output: T")
    .Attr("T: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Variant -> variant; output shape equals input shape.
REGISTER_OP("MultiLabelGen")
    .Input("input: variant")
    .Attr("task_num: int")
    .Attr("head_to_index: string")
    .Attr("head_field: string")
    .Attr("action_priority: string")
    .Attr("pos_actions: list(int)")
    .Attr("neg_actions: list(int)")
    .Attr("use_origin_label: bool")
    .Attr("pos_label: float")
    .Attr("neg_label: float")
    .Attr("variant_type: string")
    .Output("output: variant")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// String -> variant; output shape equals input shape. Marked
// SetDoNotOptimize.
REGISTER_OP("StringToVariant")
    .Input("input: string")
    .Attr("input_type: string")
    .Attr("has_header: bool")
    .Attr("has_sort_id: bool")
    .Attr("lagrangex_header: bool")
    .Attr("kafka_dump_prefix: bool")
    .Attr("kafka_dump: bool")
    .Attr("chnids: list(int)")
    .Attr("datasources: list(string)")
    .Attr("default_datasource: string")
    .Output("output: variant")
    .SetDoNotOptimize()
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Like StringToVariant but with an extra output_type attribute; the output
// is declared as a 1-D tensor of unknown length rather than mirroring the
// input shape.
REGISTER_OP("StringToVariantWithTransform")
    .Input("input: string")
    .Attr("input_type: string")
    .Attr("output_type: string")
    .Attr("has_header: bool")
    .Attr("has_sort_id: bool")
    .Attr("lagrangex_header: bool")
    .Attr("kafka_dump_prefix: bool")
    .Attr("kafka_dump: bool")
    .Attr("chnids: list(int)")
    .Attr("datasources: list(string)")
    .Attr("default_datasource: string")
    .Output("output: variant")
    .SetDoNotOptimize()
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->MakeShape({ctx->UnknownDim()}));
      return Status::OK();
    });

// Variant -> int64; output shape equals input shape. Marked
// SetDoNotOptimize.
REGISTER_OP("VariantToZeros")
    .Input("input: variant")
    .Output("output: int64")
    .SetDoNotOptimize()
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// Variant -> scalar bool.
REGISTER_OP("HasVariant")
    .Input("input: variant")
    .Output("output: bool")
    .Attr("variant_type: string")
    .SetShapeFn([](shape_inference::InferenceContext *c) {
      c->set_output(0, c->Scalar());
      return Status::OK();
    });

// Topics and metadata strings in, scalar resource handle out; every
// attribute has a default.
REGISTER_OP("KafkaGroupReadableInit")
    .Input("topics: string")
    .Input("metadata: string")
    .Output("resource: resource")
    .Attr("input_pb_type: string = ''")
    .Attr("output_pb_type: string = ''")
    .Attr("has_sort_id: bool = false")
    .Attr("lagrangex_header: bool = false")
    .Attr("kafka_dump_prefix: bool = false")
    .Attr("kafka_dump: bool = false")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn([](shape_inference::InferenceContext *c) {
      c->set_output(0, c->Scalar());
      return Status::OK();
    });

// Outputs: message and key are 1-D of unknown length; continue_fetch is a
// scalar.
REGISTER_OP("KafkaGroupReadableNext")
    .Input("input: resource")
    .Input("index: int64")
    .Input("message_poll_timeout: int64")
    .Input("stream_timeout: int64")
    .Output("message: string")
    .Output("key: string")
    .Output("continue_fetch: int64")
    .SetShapeFn([](shape_inference::InferenceContext *c) {
      c->set_output(0, c->MakeShape({c->UnknownDim()}));
      c->set_output(1, c->MakeShape({c->UnknownDim()}));
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

// V2 emits variant messages and has no key output; message is 1-D of
// unknown length and continue_fetch is a scalar.
REGISTER_OP("KafkaGroupReadableNextV2")
    .Input("input: resource")
    .Input("index: int64")
    .Input("message_poll_timeout: int64")
    .Input("stream_timeout: int64")
    .Output("message: variant")
    .Output("continue_fetch: int64")
    .SetShapeFn([](shape_inference::InferenceContext *c) {
      c->set_output(0, c->MakeShape({c->UnknownDim()}));
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

// Splits (int32/int64) and int64 values in, float32 mask out. When splits
// has rank 1 the mask shape is the splits shape with dimension 0 reduced by
// one; otherwise it is 1-D of unknown length.
REGISTER_OP("MonolithGenFidMask")
    .Input("splits: T")
    .Input("values: int64")
    .Output("mask: float32")
    .Attr("fid: int")
    .Attr("T: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      shape_inference::ShapeHandle shape = ctx->input(0);
      if (ctx->Rank(shape) == 1) {
        shape_inference::DimensionHandle new_dim;
        TF_RETURN_IF_ERROR(ctx->Subtract(ctx->Dim(shape, 0), 1, &new_dim));
        shape_inference::ShapeHandle out_shape;
        TF_RETURN_IF_ERROR(ctx->ReplaceDim(shape, 0, new_dim, &out_shape));
        ctx->set_output(0, out_shape);
      } else {
        ctx->set_output(0, ctx->MakeShape({ctx->UnknownDim()}));
      }
      return Status::OK();
    });

}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/ops/parse_input_data_ops.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "idl/matrix/proto/example.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { namespace monolith_tf { namespace { Status ShapeFn(shape_inference::InferenceContext *ctx) { int batch_size = ctx->Value(ctx->Dim(ctx->input(0), 0)); std::vector shapes; std::vector dtypes; TF_RETURN_IF_ERROR(ctx->GetAttr("shapes", &shapes)); TF_RETURN_IF_ERROR(ctx->GetAttr("dtypes", &dtypes)); if (batch_size > 0) { // know batch_size for (size_t i = 0; i < dtypes.size(); ++i) { if (i >= shapes.size()) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } else { DataType dtype = dtypes[i]; int shape = shapes[i]; if (shape == -1) { if (dtype != DataType::DT_INT64) { return errors::InvalidArgument( "If shape is -1, then dtype must be int64"); } ctx->set_output(i, ctx->Vector(batch_size + 1)); } else { ctx->set_output(i, ctx->Matrix(batch_size, shape)); } } } } else { // batch_size unknown for (size_t i = 0; i < dtypes.size(); ++i) { if (i >= shapes.size()) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } else { int shape = shapes[i]; if (shape > 0) { ctx->set_output(i, ctx->Matrix(ctx->UnknownDim(), shape)); } else { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } } } } return Status::OK(); } REGISTER_OP("ParseInstances") .Input("pb_input: T") .Output("tensors: dtypes") .Attr("fidv1_features: list(int)") .Attr("fidv2_features: list(string)") .Attr("names: list(string)") .Attr("shapes: list(int)") .Attr("dtypes: list(type)") .Attr("extra_names: list(string)") .Attr("T: {variant, 
string}") .SetDoNotOptimize() // Source dataset ops must disable constant folding. .SetShapeFn(ShapeFn); REGISTER_OP("ParseInstancesV2") .Input("pb_input: T") .Output("tensors: dtypes") .Output("sparse_features: variant") .Attr("fidv1_features: list(int)") .Attr("fidv2_features: list(string)") .Attr("names: list(string)") .Attr("shapes: list(int)") .Attr("dtypes: list(type)") .Attr("extra_names: list(string)") .Attr("T: {variant, string}") .SetDoNotOptimize() // Source dataset ops must disable constant folding. .SetShapeFn([](shape_inference::InferenceContext *ctx) { auto status = ShapeFn(ctx); std::vector dtypes; TF_RETURN_IF_ERROR(ctx->GetAttr("dtypes", &dtypes)); ctx->set_output(dtypes.size(), ctx->input(0)); return status; }); REGISTER_OP("ParseExamples") .Input("pb_input: T") .Output("tensors: dtypes") .Attr("names: list(string)") .Attr("shapes: list(int)") .Attr("dtypes: list(type)") .Attr("extra_names: list(string)") .Attr("T: {variant, string}") .SetDoNotOptimize() // Source dataset ops must disable constant folding. .SetShapeFn(ShapeFn); REGISTER_OP("ParseExamplesV2") .Input("pb_input: T") .Output("tensors: dtypes") .Output("sparse_features: variant") .Attr("names: list(string)") .Attr("shapes: list(int)") .Attr("dtypes: list(type)") .Attr("extra_names: list(string)") .Attr("T: {variant, string}") .SetDoNotOptimize() // Source dataset ops must disable constant folding. .SetShapeFn([](shape_inference::InferenceContext *ctx) { auto status = ShapeFn(ctx); std::vector dtypes; TF_RETURN_IF_ERROR(ctx->GetAttr("dtypes", &dtypes)); ctx->set_output(dtypes.size(), ctx->input(0)); return status; }); REGISTER_OP("ParseExampleBatch") .Input("pb_input: T") .Output("tensors: dtypes") .Attr("names: list(string)") .Attr("shapes: list(int)") .Attr("dtypes: list(type)") .Attr("extra_names: list(string)") .Attr("T: {variant, string}") .SetDoNotOptimize() // Source dataset ops must disable constant folding. 
.SetShapeFn([](shape_inference::InferenceContext *ctx) { std::vector shapes; std::vector dtypes; TF_RETURN_IF_ERROR(ctx->GetAttr("shapes", &shapes)); TF_RETURN_IF_ERROR(ctx->GetAttr("dtypes", &dtypes)); for (size_t i = 0; i < dtypes.size(); ++i) { if (i >= shapes.size()) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } else { int shape = shapes[i]; if (shape > 0) { ctx->set_output(i, ctx->Matrix(ctx->UnknownDim(), shape)); } else { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } } } return Status::OK(); }); REGISTER_OP("ParseExampleBatchV2") .Input("pb_input: T") .Output("tensors: dtypes") .Output("sparse_features: variant") .Attr("names: list(string)") .Attr("shapes: list(int)") .Attr("dtypes: list(type)") .Attr("extra_names: list(string)") .Attr("T: {variant, string}") .SetDoNotOptimize() // Source dataset ops must disable constant folding. .SetShapeFn([](shape_inference::InferenceContext *ctx) { // same as ParseExampleBatch std::vector shapes; std::vector dtypes; TF_RETURN_IF_ERROR(ctx->GetAttr("shapes", &shapes)); TF_RETURN_IF_ERROR(ctx->GetAttr("dtypes", &dtypes)); for (size_t i = 0; i < dtypes.size(); ++i) { if (i >= shapes.size()) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } else { int shape = shapes[i]; if (shape > 0) { ctx->set_output(i, ctx->Matrix(ctx->UnknownDim(), shape)); } else { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } } } // add sparse_features ctx->set_output(dtypes.size(), ctx->Scalar()); return Status::OK(); }); REGISTER_OP("ParseExampleBatchList") .Input("inputs: N * variant") .Output("tensors: dtypes") .Attr("label_config: string") .Attr("names: list(string)") .Attr("shapes: list(int)") .Attr("dtypes: list(type)") .Attr("extra_names: list(string)") .Attr("positive_label: float") .Attr("negative_label: float") .Attr("N: int") .SetDoNotOptimize() // Source dataset ops must disable constant folding. 
.SetShapeFn([](shape_inference::InferenceContext *ctx) { std::vector shapes; std::vector dtypes; TF_RETURN_IF_ERROR(ctx->GetAttr("shapes", &shapes)); TF_RETURN_IF_ERROR(ctx->GetAttr("dtypes", &dtypes)); for (size_t i = 0; i < dtypes.size(); ++i) { if (i >= shapes.size()) { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } else { int shape = shapes[i]; if (shape > 0) { ctx->set_output(i, ctx->Matrix(ctx->UnknownDim(), shape)); } else { ctx->set_output(i, ctx->Vector(ctx->UnknownDim())); } } } return Status::OK(); }); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/ops/pb_dataset_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
namespace {

// Shape function for ops whose first input must have rank at most 0.
// Validates input 0, then delegates to shape_inference::ScalarShape.
Status MonolithRank0InputScalarHandle(shape_inference::InferenceContext* ctx) {
  shape_inference::ShapeHandle ignored;
  TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 0, &ignored));
  return shape_inference::ScalarShape(ctx);
}

// Shape function for ops whose first input must have rank at most 1.
// Validates input 0, then delegates to shape_inference::ScalarShape.
Status MonolithRank1InputScalarHandle(shape_inference::InferenceContext* ctx) {
  shape_inference::ShapeHandle ignored;
  TF_RETURN_IF_ERROR(ctx->WithRankAtMost(ctx->input(0), 1, &ignored));
  return shape_inference::ScalarShape(ctx);
}

}  // namespace

// PBDataset: `file_name` may have rank at most 1.
REGISTER_OP("PBDataset")
    .Input("file_name: string")
    .Input("use_snappy: bool")
    .Input("has_sort_id: bool")
    .Input("kafka_dump: bool")
    .Input("kafka_dump_prefix: bool")
    .Input("buffer_size: int64")
    .Input("lagrangex_header: bool")
    .Input("input_pb_type: string")
    .Input("output_pb_type: string")
    .Input("feature_pruning_type: int32")
    .Input("feature_name_list: string")
    .Input("feature_id_list: int32")
    .Attr("out_type: {variant, string}")
    .Attr("compression_type: int = 0")
    .Output("handle: variant")
    .SetDoNotOptimize()  // Keep this op out of constant folding.
    .SetShapeFn(MonolithRank1InputScalarHandle);

// ParquetDataset: `file_name` may have rank at most 1.
REGISTER_OP("ParquetDataset")
    // Basic inputs.
    .Input("file_name: string")
    .Input("output_pb_type: string")  // example or example_batch
    .Attr("batch_size: int")  // valid when output_pb_type is 'example_batch'
    // Column selection.
    .Attr("select_columns: list(string)")
    .Attr("select_columns_type: list(string)")
    .Attr("drop_remainder: bool")
    // Output.
    .Output("handle: variant")
    .SetDoNotOptimize()
    .SetShapeFn(MonolithRank1InputScalarHandle);

// InstanceReweightDataset: takes a variant input of rank at most 0.
REGISTER_OP("InstanceReweightDataset")
    .Input("input: variant")
    .Attr("method: int")
    .Attr("actions: list(int)")
    .Attr("weights: list(int)")
    .Attr("labels: list(int)")
    .Attr("priorities: list(int)")
    .Attr("variant_type: string")
    .Output("handle: variant")
    .SetDoNotOptimize()  // Keep this op out of constant folding.
    .SetShapeFn(MonolithRank0InputScalarHandle);

// InstanceNegativeGenDataset: takes a variant input of rank at most 0 plus a
// `pool` resource.
REGISTER_OP("InstanceNegativeGenDataset")
    .Input("input: variant")
    .Input("pool: resource")
    .Attr("neg_num: int >= 1")
    .Attr("per_channel: bool")
    .Attr("channel_feature: string")
    .Attr("item_features: list(string)")
    .Attr("label_index: int >= 0")
    .Attr("positive_label: int")
    .Attr("negative_label: int")
    .Attr("negative_action: int")
    .Attr("action_priority: string")
    .Attr("positive_actions: list(int)")
    .Attr("index_feature: string")
    .Attr("throw_origin: bool")
    .Attr("throw_origin_neg: bool")
    .Attr("cache_only_pos: bool")
    .Attr("cache_negative_actions: list(int)")
    .Attr("real_neg_instance_weight: float")
    .Attr("sampled_neg_instance_weight: float")
    .Attr("unbias_sampled_neg: bool")
    .Attr("origin_neg_in_pool_proba: float")
    .Attr("neg_sample_declay_factor: float")
    .Attr("easy_hard_ratio: float")
    .Attr("variant_type: string")
    .Output("handle: variant")
    .SetDoNotOptimize()
    .SetShapeFn(MonolithRank0InputScalarHandle);

// SplitFlowDataset: takes a variant input of rank at most 0.
REGISTER_OP("SplitFlowDataset")
    .Input("input: variant")
    .Attr("data_flow: list(string)")
    .Attr("index: int")
    .Attr("max_queue_size: int")
    .Attr("variant_type: string")
    .Output("handle: variant")
    .SetDoNotOptimize()
    .SetShapeFn(MonolithRank0InputScalarHandle);

// MergeFlowDataset: N variant inputs; only input 0 is rank-checked.
REGISTER_OP("MergeFlowDataset")
    .Input("inputs: N * variant")
    .Attr("data_flow: list(string)")
    .Attr("max_queue_size: int")
    .Attr("variant_type: string")
    .Attr("N: int >= 1")
    .Output("handle: variant")
    .SetDoNotOptimize()
    .SetShapeFn(MonolithRank0InputScalarHandle);

// DynamicMatchingFilesDataset: no validation of `patterns`; the shape
// function is shape_inference::ScalarShape.
REGISTER_OP("DynamicMatchingFilesDataset")
    .Input("patterns: string")
    .Output("handle: variant")
    .SetDoNotOptimize()
    .SetShapeFn(shape_inference::ScalarShape);

// MonolithCacheOneDataset: takes a variant input of rank at most 0. Unlike
// the other ops here it does not call SetDoNotOptimize().
REGISTER_OP("MonolithCacheOneDataset")
    .Input("input: variant")
    .Output("handle: variant")
    .SetShapeFn(MonolithRank0InputScalarHandle);

// TransformDataset: takes a variant input of rank at most 0.
REGISTER_OP("TransformDataset")
    .Input("input: variant")
    .Attr("config: string")
    .Attr("variant_type: string")
    .Output("handle: variant")
    .SetDoNotOptimize()  // Keep this op out of constant folding.
    .SetShapeFn(MonolithRank0InputScalarHandle);

}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/parse_sparse_feature_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict from copy import deepcopy import os import time from typing import Callable, Dict, Iterable, List, Set, Tuple import tensorflow as tf from absl import logging import numpy as np from struct import unpack import random from monolith.native_training.data.parsers import parse_instances, parse_examples, parse_example_batch, \ sharding_sparse_fids, get_default_parser_ctx, ParserCtx from monolith.native_training.model_export.data_gen_utils import gen_fids_v1, gen_fids_v2, fill_line_id, \ gen_instance, FeatureMeta from idl.matrix.proto.example_pb2 import Example, ExampleBatch, FeatureConfigs, FeatureConfig, FeatureListType from idl.matrix.proto.line_id_pb2 import LineId features = { 'f_spm_1': 301, 'f_spm_3': 303, 'f_spm_2': 302, 'f_spm_4': 304, 'f_user_id': 1, 'f_user_ctx_network': 61, 'f_user_id-f_page': 504, 'f_scm': 306, 'f_goods_id': 200, 'f_goods_sale_number_1000': 225, 'f_goods_praise_cnt': 229, 'f_spm': 300, 'f_page': 305, 'f_is_dup': 310, 'f_user_ctx_platform': 52, 'f_goods_title_terms': 209, 'f_goods_tags_terms': 211, 'f_user_test09_array_int32': 554, 'f_user_test15_array_float': 540, 'f_user_test14_array_bool': 543, 'f_user_test12_array_uint64': 551, 'f_user_test10_array_int64': 549 } class DataOpsV2Test(tf.test.TestCase): def __init__(self, *args, **kwargs): super(DataOpsV2Test, self).__init__(*args, **kwargs) self.mask = (1 << 48) - 1 self.version = 3 def fid_v1_to_v2(self, fid_v1): slot_id = (fid_v1 >> 54) fid_v2 = ((slot_id << 48) | (self.mask & fid_v1)) return slot_id, fid_v2 def fill_row_split(self, ps_num, t_cfg, fid_list, row_split): for ps_i in range(ps_num): lenth = 0 for feature_name in t_cfg["feature_list"]: key = feature_name + ":" + str(ps_i) if key in fid_list: lenth += len(fid_list[key]) row_split[t_cfg["table_name"] + ":" + str(ps_i)].append(lenth) def get_pre_output_offset(self, shard, f_cfg): if self.version == 2: return f_cfg["pre_output_index"] + shard * f_cfg[ "table_feature_count"] + 
f_cfg["feature_in_table_index"] else: return f_cfg["pre_output_index"] + shard def get_feature_cfg(self, raw_feature_cfgs, ps_num): feature_cfg = defaultdict(dict) table_cfg = defaultdict(dict) for feature_name, cfg in raw_feature_cfgs.feature_configs.items(): feature_cfg[feature_name] = { "feature_name": feature_name, "feature_index": -1, "table_name": cfg.table, "table_index": -1, "feature_in_table_index": -1, "table_feature_count": 0, "pre_output_index": 0, "dims_sum": sum(cfg.slice_dims), } if cfg.table not in table_cfg: table_cfg[cfg.table] = { "table_name": cfg.table, "feature_list": [], "table_index": -1, "feature_count": 0, } table_name_sort = sorted(table_cfg.keys()) for idx, name in enumerate(table_name_sort): table_cfg[name]["table_index"] = idx feature_name_sort = sorted(feature_cfg.keys()) for idx, name in enumerate(feature_name_sort): f_cfg = feature_cfg[name] t_cfg = table_cfg[f_cfg["table_name"]] f_cfg["feature_index"] = idx f_cfg["table_index"] = t_cfg["table_index"] f_cfg["feature_in_table_index"] = len(t_cfg["feature_list"]) t_cfg["feature_list"].append(name) pre_index = 0 for idx, name in enumerate(table_name_sort): t_cfg = table_cfg[name] t_cfg["feature_count"] = len(t_cfg["feature_list"]) for feature_name in t_cfg["feature_list"]: f_cfg = feature_cfg[feature_name] f_cfg[ "pre_output_index"] = pre_index if self.version == 2 else idx * ps_num f_cfg["table_feature_count"] = t_cfg["feature_count"] pre_index += max(t_cfg["feature_count"], 1) * ps_num logging.info(f"show feature_cfg: {feature_cfg}") logging.info(f"show table_cfg: {table_cfg}") return feature_cfg, table_cfg, feature_name_sort, table_name_sort def handle_feature(self, fid_v1_list, fid_v2_list, f_cfg, t_cfg, ps_num, fid_offset_list, fid_offset_list2, fid_map_t, fid_map_unique_map, fid_map_unique_t): value_list = [] if len(fid_v1_list) != 0: for fid in fid_v1_list: slot_id, fid_v2 = self.fid_v1_to_v2(fid) value_list.append(fid_v2) elif len(fid_v2_list) != 0: value_list = fid_v2_list for 
value in value_list: shard = value % ps_num key = f_cfg["feature_name"] + ":" + str(shard) fid_offset_list.append((len(fid_map_t[key]), shard)) fid_map_t[key].append(value) if value not in fid_map_unique_map[key]: fid_map_unique_map[key][value] = len(fid_map_unique_map[key]) fid_map_unique_t[key].append(value) fid_offset_list2.append((fid_map_unique_map[key][value], shard)) def get_offset_result(self, feature_name_sort, table_name_sort, ps_num, feature_cfg, table_cfg, fid_offset_map, fid_offset_map_unique, fid_map_t_in, fid_map_unique_t_in, sparse_feature_shared=set()): fid_map_t = defaultdict(list) fid_map_table_pre_offset = defaultdict(lambda: 0) fid_map_feature_pre_offset = defaultdict(lambda: 0) fid_map_unique_t = defaultdict(list) fid_map_unique_table_pre_offset = defaultdict(lambda: 0) fid_map_unique_feature_pre_offset = defaultdict(lambda: 0) for table_name in table_name_sort: t_cfg = table_cfg[table_name] for ps_i in range(ps_num): to_key = table_name + ":" + str(ps_i) table_pre_offset = fid_map_table_pre_offset[to_key] unique_table_pre_offset = fid_map_unique_table_pre_offset[to_key] for feature_name in t_cfg["feature_list"]: key = feature_name + ":" + str(ps_i) f_cfg = feature_cfg[feature_name] dims_sum = f_cfg["dims_sum"] fid_map_t[to_key].extend(fid_map_t_in[key]) fid_map_feature_pre_offset[key] = table_pre_offset table_pre_offset += dims_sum * len(fid_map_t_in[key]) fid_map_unique_t[to_key].extend(fid_map_unique_t_in[key]) fid_map_unique_feature_pre_offset[key] = unique_table_pre_offset unique_table_pre_offset += dims_sum * len(fid_map_unique_t_in[key]) feature_offset_t = [] nfl_offset_t = [] fid_offset_list = [] fid_offset_list_unique = [] def rewrie_fid_offset(feature_name, pre_offset_dict, f_cfg, fid_list, fid_offset_list_): for mix_offset in fid_list: fid_offset = mix_offset[0] shard = mix_offset[1] feat_offset = self.get_pre_output_offset(shard, f_cfg) if self.version == 3 or self.version == 4: fid_offset *= f_cfg['dims_sum'] fid_offset += 
pre_offset_dict[feature_name + ":" + str(shard)] fid_offset_list_.append(feat_offset << 32 | fid_offset) #logging.info(f"xxxxx {pre_offset}") for sparse_key in feature_name_sort: f_cfg = feature_cfg[sparse_key] if sparse_key in sparse_feature_shared: nfl_offset = len(feature_offset_t) | 1 << 31 #pass else: nfl_offset = len(feature_offset_t) nfl_offset_t.append(nfl_offset) if sparse_key not in fid_offset_map: continue for fid_list in fid_offset_map[sparse_key]: feature_offset_t.append(len(fid_offset_list)) rewrie_fid_offset(sparse_key, fid_map_feature_pre_offset, f_cfg, fid_list, fid_offset_list) for fid_list in fid_offset_map_unique[sparse_key]: rewrie_fid_offset(sparse_key, fid_map_unique_feature_pre_offset, f_cfg, fid_list, fid_offset_list_unique) feature_offset_t.append(len(fid_offset_list)) nfl_offset_t.append(len(feature_offset_t)) print('==' * 10 + "fid_map_t" + '==' * 10) # print(fid_map_t) print('==' * 10 + "fid_map_unique_t" + '==' * 10) # print(fid_map_unique_t) print('==' * 10 + "fid_offset_map" + '==' * 10) # print(fid_offset_map) return nfl_offset_t, feature_offset_t, fid_offset_list, fid_offset_list_unique, fid_map_t, fid_map_unique_t def diff_test(self, input_type, parse_func, input_str_list, feature_name_sort, table_name_sort, ps_num, feature_cfg, table_cfg, feature_cfgs, sparse_features, dense_features, extra_features, nfl_offset_t, feature_offset_t, fid_offset_list, fid_offset_list_unique, fid_map_t, fid_map_row_split_t, fid_map_unique_t, fid_map_row_split_unique_t): if self.version == 4: fid_list_table_row_length_t = [0] * (ps_num * len(table_name_sort)) fid_list_shard_row_lenth_t = [0] * ps_num fid_list_emb_row_lenth_t = [0] * (ps_num * len(table_name_sort)) fid_list_table_row_length_unique_t = [0] * (ps_num * len(table_name_sort)) fid_list_shard_row_lenth_unique_t = [0] * ps_num fid_list_emb_row_lenth_unique_t = [0] * (ps_num * len(table_name_sort)) fid_map_t2, fid_map_row_split_t2, fid_map_unique_t2, fid_map_row_split_unique_t2 = [], [], [], 
[] table_count = len(table_name_sort) for ps_i in range(ps_num): for table_i in range(table_count): table_name = table_name_sort[table_i] #t_cfg = table_cfg[table_name] to_key = table_name + ":" + str(ps_i) fid_map_t2.extend(fid_map_t[to_key]) fid_map_row_split_t2.extend(fid_map_row_split_t[to_key]) fid_list_table_row_length_t[ps_i * table_count + table_i] = len( fid_map_t[to_key]) fid_list_shard_row_lenth_t[ps_i] += len(fid_map_t[to_key]) fid_map_unique_t2.extend(fid_map_unique_t[to_key]) fid_map_row_split_unique_t2.extend(fid_map_row_split_unique_t[to_key]) fid_list_table_row_length_unique_t[ps_i * table_count + table_i] = len( fid_map_unique_t[to_key]) fid_list_shard_row_lenth_unique_t[ps_i] += len( fid_map_unique_t[to_key]) emb_dim_sum, emb_dim_sum_unique = 0, 0 t_cfg = table_cfg[table_name] assert len(t_cfg["feature_list"]) + 1 == len( fid_map_row_split_t[to_key]) for feature_idx in range(len(t_cfg["feature_list"])): feature_name = t_cfg["feature_list"][feature_idx] f_cfg = feature_cfg[feature_name] dims_sum = f_cfg["dims_sum"] emb_dim_sum += dims_sum * ( fid_map_row_split_t[to_key][feature_idx + 1] - fid_map_row_split_t[to_key][feature_idx]) emb_dim_sum_unique += dims_sum * ( fid_map_row_split_unique_t[to_key][feature_idx + 1] - fid_map_row_split_unique_t[to_key][feature_idx]) fid_list_emb_row_lenth_t[ps_i * table_count + table_i] = \ emb_dim_sum fid_list_emb_row_lenth_unique_t[ps_i * table_count + table_i] = \ emb_dim_sum_unique fid_map_t, fid_map_row_split_t, fid_map_unique_t, fid_map_row_split_unique_t = \ fid_map_t2, fid_map_row_split_t2, fid_map_unique_t2, fid_map_row_split_unique_t2 with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True input_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) parsed_results_base, parsed_results = parse_func(input_placeholder) example_batch_varint = parsed_results.pop( ParserCtx.sharding_sparse_fids_sparse_features_key) 
parallel_flag_list = [0, 1] fid_map_list = [] fid_map_unique_list = [] for parallel_flag in parallel_flag_list: fid_map, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_map_row_split, fid_map_row_split_size, fid_list_emb_row_lenth, \ fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( example_batch_varint, ps_num, feature_cfgs, False, input_type, parallel_flag, version=self.version) fid_map_list.append([ fid_map, fid_offset, feature_offset, nfl_offset, fid_map_row_split ]) fid_map_unique, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_map_unique_row_split, fid_map_unique_row_split_size, fid_list_emb_row_lenth_unique, \ fid_list_table_row_length_unique, fid_list_shard_row_lenth_unique = sharding_sparse_fids( example_batch_varint, ps_num, feature_cfgs, True, input_type, parallel_flag, version=self.version) fid_map_unique_list.append([ fid_map_unique, fid_offset, feature_offset, nfl_offset, fid_map_unique_row_split ]) with self.session(config=config) as sess: parsed_results_base1, parsed_results1 = sess.run( fetches=[parsed_results_base, parsed_results], feed_dict={input_placeholder: input_str_list}) def diff(k, a, b, sort=False): if not isinstance(a[0], list) and sort: a.sort() b.sort() #print("diff:a {} {}".format(k, a), flush=True) #print("diff:b {} {}".format(k, b), flush=True) assert (len(a) == len(b)) if (len(a) == 0): return for i in range(len(a)): if isinstance(a[i], list): assert isinstance(b[i], list) diff(k + "/" + str(i), a[i], b[i], sort) else: assert (a[i] == b[i]), f"{i}: {a[i]} / {b[i]}" #print('==' * 10 + "parsed_results_base1" + '==' * 10, flush=True) #print('==' * 10 + "parsed_results1" + '==' * 10, flush=True) # .numpy() for k, v in parsed_results_base1.items(): if k in sparse_features: continue if k in dense_features + extra_features: if k not in parsed_results1: print("no find {} in 
parse_example_batch_v2".format(k)) assert (False) diff(k, v.tolist(), parsed_results1[k].tolist()) #.numpy() else: print("no need {}".format(k), flush=True) assert (False) for k, v in parsed_results1.items(): if k not in dense_features + extra_features: print("no need {}".format(k), flush=True) assert (False) for fid_map_index in range(len(parallel_flag_list)): fid_map = fid_map_list[fid_map_index] fid_map_unique = fid_map_unique_list[fid_map_index] fid_map1_list, fid_map_unique1_list = sess.run( fetches=[fid_map, fid_map_unique], feed_dict={input_placeholder: input_str_list}) #print('==' * 10 + "fid_map1" + '==' * 10, flush=True) #print(fid_map1, flush=True) #print('==' * 10 + "fid_map_unique1" + '==' * 10, flush=True) #print(fid_map_unique1, flush=True) #print('==' * 10 + "fid_offset2" + '==' * 10, flush=True) #print(list(fid_offset2), flush=True) #print('==' * 10 + "fid_offset_list_unique" + '==' * 10, flush=True) #print(list(fid_offset_list_unique), flush=True) fid_map1, fid_offset1, feature_offset1, nfl_offset1, fid_map_row_split1 = fid_map1_list fid_map2, fid_offset2, feature_offset2, nfl_offset2, fid_map_unique_row_split1 = fid_map_unique1_list print('==' * 10 + "diff fidoffset " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) diff("nfl_offset", nfl_offset1, nfl_offset_t) diff("nfl_offset2", nfl_offset2, nfl_offset_t) diff("feature_offset", feature_offset1, feature_offset_t) diff("feature_offset2", feature_offset2, feature_offset_t) diff("fid_offset2", fid_offset2, fid_offset_list_unique) diff("fid_offset", fid_offset1, fid_offset_list) if isinstance(fid_map_t, Dict): assert (len(fid_map_t) == len(fid_map1)) assert (len(fid_map_unique_t) == len(fid_map2)) def fid_diff(a, b): if isinstance(a, Dict): for k, v in a.items(): assert (k in b) diff(k, v.tolist(), b[k]) #.numpy() for k, v in b.items(): assert (k in a) else: diff("main", a.tolist(), b) print('==' * 10 + "diff fid_map1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, 
flush=True) #print(f"xxxx fid_map1 {fid_map1} ", flush=True) #print(f"xxxx fid_map_t {fid_map_t} ", flush=True) fid_diff(fid_map1, fid_map_t) #print(f"xxxx fid_map_row_split1 {fid_map_row_split1}", flush=True) #print(f"xxxx fid_map_row_split_t {fid_map_row_split_t}", flush=True) fid_diff(fid_map_row_split1, fid_map_row_split_t) print('==' * 10 + "diff fid_map_unique1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) #print(f"xxxx fid_map2 {fid_map2} ", flush=True) #print(f"xxxx fid_map_unique_t {fid_map_unique_t} ", flush=True) #print(f"xxxx fid_map_unique_row_split {fid_map_unique_row_split} ", # flush=True) fid_diff(fid_map2, fid_map_unique_t) fid_diff(fid_map_unique_row_split1, fid_map_row_split_unique_t) if self.version == 4: fid_list_table_row_length_1, fid_list_shard_row_lenth_1, \ fid_list_table_row_length_unique_1, fid_list_shard_row_lenth_unique_1, \ fid_list_emb_row_lenth_1, fid_list_emb_row_lenth_unique_1 = \ sess.run(fetches=[fid_list_table_row_length, fid_list_shard_row_lenth, \ fid_list_table_row_length_unique, fid_list_shard_row_lenth_unique, \ fid_list_emb_row_lenth, fid_list_emb_row_lenth_unique], \ feed_dict={input_placeholder: input_str_list}) diff("", fid_list_table_row_length_1.tolist(), fid_list_table_row_length_t) diff("", fid_list_shard_row_lenth_1.tolist(), fid_list_shard_row_lenth_t) diff("", fid_list_emb_row_lenth_1.tolist(), fid_list_emb_row_lenth_t) diff("", fid_list_table_row_length_unique_1.tolist(), fid_list_table_row_length_unique_t) diff("", fid_list_shard_row_lenth_unique_1.tolist(), fid_list_shard_row_lenth_unique_t) diff("", fid_list_emb_row_lenth_unique_1.tolist(), fid_list_emb_row_lenth_unique_t) def testExampleBatchSharding(self): file_name = "monolith/native_training/data/training_instance/examplebatch.data" sparse_features = list(features.keys()) with open(file_name, 'rb') as stream: stream.read(8) # strip lagrangex_header size = unpack(" 10: break try: stream.read(8) # strip has_sort_id stream.read(8) # 
strip kafka_dump size = unpack("> 54) fid_v2 = ((slot_id << 48) | (mask & fid)) value_list.append(fid_v2) elif len(feature.fid_v2_list.value) != 0: value_list = feature.fid_v2_list.value for value in value_list: shard = value % ps_num key = table_name + ":" + str(shard) fid_offset = (table_index * ps_num + shard) << 32 fid_offset_list.append(fid_offset | len(fid_map_t[key])) fid_map_t[key].append(value) if value not in fid_map_unique_map[key]: fid_map_unique_map[key][value] = len(fid_map_unique_map[key]) fid_map_unique_t[key].append(value) fid_offset_list2.append(fid_offset | fid_map_unique_map[key][value]) fid_offset_map = defaultdict(list) fid_offset_map_unique = defaultdict(list) sparse_feature_shared = set() example_batch_feature_map = {} for named_feature_list in example_batch.named_feature_list: if named_feature_list.name not in feature_cfgs.feature_configs or \ named_feature_list.name not in sparse_features: continue example_batch_feature_map[named_feature_list.name] = named_feature_list for sparse_key in sparse_features: if sparse_key not in example_batch_feature_map: continue named_feature_list = example_batch_feature_map[sparse_key] table_name = feature_cfgs.feature_configs[named_feature_list.name].table table_index = table_name_index_map[table_name] if named_feature_list.type == FeatureListType.SHARED: sparse_feature_shared.add(named_feature_list.name) feature = named_feature_list.feature[0] fid_offset_list = [] fid_offset_list2 = [] handle_feature(feature, table_name, table_index, fid_offset_list, fid_offset_list2) fid_offset_map[named_feature_list.name].append(fid_offset_list) fid_offset_map_unique[named_feature_list.name].append(fid_offset_list2) else: for feature in named_feature_list.feature: fid_offset_list = [] fid_offset_list2 = [] handle_feature(feature, table_name, table_index, fid_offset_list, fid_offset_list2) fid_offset_map[named_feature_list.name].append(fid_offset_list) fid_offset_map_unique[named_feature_list.name].append( 
fid_offset_list2) feature_offset_t = [] nfl_offset_t = [] fid_offset_list = [] fid_offset_list_unique = [] for sparse_key in sparse_features: if sparse_key in sparse_feature_shared: nfl_offset = len(feature_offset_t) | 1 << 31 #pass else: nfl_offset = len(feature_offset_t) nfl_offset_t.append(nfl_offset) if sparse_key not in fid_offset_map: continue for fid_list in fid_offset_map[sparse_key]: feature_offset_t.append(len(fid_offset_list)) fid_offset_list.extend(fid_list) for fid_list in fid_offset_map_unique[sparse_key]: fid_offset_list_unique.extend(fid_list) feature_offset_t.append(len(fid_offset_list)) nfl_offset_t.append(len(feature_offset_t)) print('==' * 10 + "fid_map_t" + '==' * 10) #print(fid_map_t) print('==' * 10 + "fid_map_unique_t" + '==' * 10) #print(fid_map_unique_t) print('==' * 10 + "fid_offset_map" + '==' * 10) #print(fid_offset_map) #print('==' * 10 + "example_batch" + '==' * 10) #print(example_batch) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True examples_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=(None,)) get_default_parser_ctx().enable_fused_layout = False parsed_results_base = parse_example_batch( examples_placeholder, sparse_features=sparse_features, dense_features=dense_features, dense_feature_shapes=dense_feature_shapes, dense_feature_types=dense_feature_types, extra_features=extra_features, extra_feature_shapes=extra_feature_shapes) get_default_parser_ctx().enable_fused_layout = True parsed_results = parse_example_batch( examples_placeholder, sparse_features=[], dense_features=dense_features, dense_feature_shapes=dense_feature_shapes, dense_feature_types=dense_feature_types, extra_features=extra_features, extra_feature_shapes=extra_feature_shapes) example_batch_varint = parsed_results.pop( ParserCtx.sharding_sparse_fids_sparse_features_key) parallel_flag_list = [0, 1] fid_map_list = [] fid_map_unique_list = [] for parallel_flag in 
parallel_flag_list: fid_map, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_list_row_splits, fid_list_row_splits_size, fid_list_emb_row_lenth, \ fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( example_batch_varint, ps_num, feature_cfgs, False, "examplebatch", parallel_flag, version=1) fid_map_list.append([fid_map, fid_offset, feature_offset, nfl_offset]) fid_map_unique, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_list_row_splits, fid_list_row_splits_size, fid_list_emb_row_lenth, \ fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( example_batch_varint, ps_num, feature_cfgs, True, "examplebatch", parallel_flag, version=1) fid_map_unique_list.append( [fid_map_unique, fid_offset, feature_offset, nfl_offset]) with self.session(config=config) as sess: parsed_results_base1, parsed_results1 = sess.run( fetches=[parsed_results_base, parsed_results], feed_dict={examples_placeholder: [eb_str]}) def diff(k, a, b, sort=False): if not isinstance(a[0], list) and sort: a.sort() b.sort() #print("diff:a {} {}".format(k, a), flush=True) #print("diff:b {} {}".format(k, b), flush=True) assert (len(a) == len(b)) if (len(a) == 0): return for i in range(len(a)): if isinstance(a[i], list): assert isinstance(b[i], list) diff(k + "/" + str(i), a[i], b[i], sort) else: assert (a[i] == b[i]), f"{i}: {a[i]} / {b[i]}" #print('==' * 10 + "parsed_results_base1" + '==' * 10, flush=True) #print('==' * 10 + "parsed_results1" + '==' * 10, flush=True) # .numpy() for k, v in parsed_results_base1.items(): if k in sparse_features: continue if k in dense_features + extra_features: if k not in parsed_results1: print("no find {} in parse_example_batch_v2".format(k)) assert (False) diff(k, v.tolist(), parsed_results1[k].tolist()) #.numpy() else: print("no need {}".format(k), flush=True) assert (False) for k, v in parsed_results1.items(): if k not 
in dense_features + extra_features: print("no need {}".format(k), flush=True) assert (False) for fid_map_index in range(len(parallel_flag_list)): fid_map = fid_map_list[fid_map_index] fid_map_unique = fid_map_unique_list[fid_map_index] fid_map1_list, fid_map_unique1_list = sess.run( fetches=[fid_map, fid_map_unique], feed_dict={examples_placeholder: [eb_str]}) #print('==' * 10 + "fid_map1" + '==' * 10, flush=True) #print(fid_map1, flush=True) #print('==' * 10 + "fid_map_unique1" + '==' * 10, flush=True) #print(fid_map_unique1, flush=True) fid_map1, fid_offset1, feature_offset1, nfl_offset1 = fid_map1_list fid_map2, fid_offset2, feature_offset2, nfl_offset2 = fid_map_unique1_list print('==' * 10 + "diff fidoffset " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) diff("nfl_offset", nfl_offset1, nfl_offset_t) diff("nfl_offset2", nfl_offset2, nfl_offset_t) diff("feature_offset", feature_offset1, feature_offset_t) diff("feature_offset2", feature_offset2, feature_offset_t) diff("fid_offset", fid_offset1, fid_offset_list) diff("fid_offset2", fid_offset2, fid_offset_list_unique) assert (len(fid_map_t) == len(fid_map1)) assert (len(fid_map_unique_t) == len(fid_map2)) def fid_diff(a, b): for k, v in a.items(): assert (k in b) diff(k, v.tolist(), b[k], True) #.numpy() for k, v in b.items(): assert (k in a) print('==' * 10 + "diff fid_map1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) fid_diff(fid_map1, fid_map_t) print('==' * 10 + "diff fid_map_unique1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) fid_diff(fid_map2, fid_map_unique_t) #assert (False) def testExampleSharding(self): sparse_features_set = set() dense_features = ['label'] dense_feature_shapes = [2] dense_feature_types = [tf.float32] extra_features = ['uid', 'req_time', 'item_id', 'actions'] extra_feature_shapes = [1, 1, 1, 1] example_str_list = [] example_list = [] file_name = "monolith/native_training/data/training_instance/example.pb" with 
open(file_name, 'rb') as stream: while (True): if len(example_str_list) > 10: break try: stream.read(8) # strip has_sort_id stream.read(8) # strip kafka_dump size = unpack("> 54) fid_v2 = ((slot_id << 48) | (mask & fid)) value_list.append(fid_v2) elif len(feature.fid_v2_list.value) != 0: value_list = feature.fid_v2_list.value for value in value_list: shard = value % ps_num key = table_name + ":" + str(shard) fid_offset = (table_index * ps_num + shard) << 32 fid_offset_list.append(fid_offset | len(fid_map_t[key])) fid_map_t[key].append(value) if value not in fid_map_unique_map[key]: fid_map_unique_map[key][value] = len(fid_map_unique_map[key]) fid_map_unique_t[key].append(value) fid_offset_list2.append(fid_offset | fid_map_unique_map[key][value]) fid_offset_map = defaultdict(list) fid_offset_map_unique = defaultdict(list) for sparse_key in sparse_features: for example in example_list: find_named_feature = None for named_feature in example.named_feature: if named_feature.name == sparse_key: find_named_feature = named_feature break fid_offset_list = [] fid_offset_list2 = [] if find_named_feature: table_name = feature_cfgs.feature_configs[sparse_key].table table_index = table_name_index_map[table_name] handle_feature(find_named_feature.feature, table_name, table_index, fid_offset_list, fid_offset_list2) fid_offset_map[sparse_key].append(fid_offset_list) fid_offset_map_unique[sparse_key].append(fid_offset_list2) feature_offset_t = [] nfl_offset_t = [] fid_offset_list = [] fid_offset_list_unique = [] for sparse_key in sparse_features: nfl_offset = len(feature_offset_t) nfl_offset_t.append(nfl_offset) if sparse_key not in fid_offset_map: continue for fid_list in fid_offset_map[sparse_key]: feature_offset_t.append(len(fid_offset_list)) fid_offset_list.extend(fid_list) for fid_list in fid_offset_map_unique[sparse_key]: fid_offset_list_unique.extend(fid_list) feature_offset_t.append(len(fid_offset_list)) nfl_offset_t.append(len(feature_offset_t)) print('==' * 10 + 
"fid_map_t" + '==' * 10, flush=True) #print(fid_map_t, flush=True) print('==' * 10 + "fid_map_unique_t" + '==' * 10, flush=True) #print(fid_map_unique_t, flush=True) print('==' * 10 + "fid_offset_map" + '==' * 10) #print(fid_offset_map) #example_tensor = tf.convert_to_tensor(example_str_list) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() config.graph_options.rewrite_options.disable_meta_optimizer = True examples_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=(None)) get_default_parser_ctx().enable_fused_layout = False parsed_results_base = parse_examples( examples_placeholder, sparse_features=sparse_features, dense_features=dense_features, dense_feature_shapes=dense_feature_shapes, dense_feature_types=dense_feature_types, extra_features=extra_features, extra_feature_shapes=extra_feature_shapes) get_default_parser_ctx().enable_fused_layout = True parsed_results = parse_examples(examples_placeholder, sparse_features=[], dense_features=dense_features, dense_feature_shapes=dense_feature_shapes, dense_feature_types=dense_feature_types, extra_features=extra_features, extra_feature_shapes=extra_feature_shapes) examples_varint = parsed_results.pop( ParserCtx.sharding_sparse_fids_sparse_features_key) parallel_flag_list = [0, 1] fid_map_list = [] fid_map_unique_list = [] for parallel_flag in parallel_flag_list: fid_map, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_list_row_splits, fid_list_row_splits_size, fid_list_emb_row_lenth, \ fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( examples_varint, ps_num, feature_cfgs, False, "example", parallel_flag, version=1) fid_map_list.append([fid_map, fid_offset, feature_offset, nfl_offset]) fid_map_unique, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_list_row_splits, fid_list_row_splits_size, fid_list_emb_row_lenth, \ fid_list_table_row_length, 
fid_list_shard_row_lenth = sharding_sparse_fids( examples_varint, ps_num, feature_cfgs, True, "example", parallel_flag, version=1) fid_map_unique_list.append( [fid_map_unique, fid_offset, feature_offset, nfl_offset]) with self.session(config=config) as sess: parsed_results_base1, parsed_results1 = sess.run( fetches=[parsed_results_base, parsed_results], feed_dict={examples_placeholder: example_str_list}) def diff(k, a, b, sort=False): if not isinstance(a[0], list) and sort: a.sort() b.sort() def print_func(): print("diff:a {} {}".format(k, a), flush=True) print("diff:b {} {}".format(k, b), flush=True) return "{}, {}".format(len(a), len(b)) assert (len(a) == len(b)), print_func() if (len(a) == 0): return for i in range(len(a)): if isinstance(a[i], list): assert isinstance(b[i], list), print_func() diff(k + "/" + str(i), a[i], b[i], sort) else: assert (a[i] == b[i]), print_func() #print('==' * 10 + "parsed_results_base1" + '==' * 10, flush=True) #print(parsed_results_base1, flush=True) #print('==' * 10 + "parsed_results1" + '==' * 10, flush=True) #print(parsed_results1, flush=True) # .numpy() for k, v in parsed_results_base1.items(): if k in sparse_features: continue if k in dense_features + extra_features: if k not in parsed_results1: print("no find {} in parse_example_batch_v2".format(k), flush=True) assert (False) diff(k, v.tolist(), parsed_results1[k].tolist()) #.numpy() else: print("no need {}".format(k), flush=True) assert (False) for k, v in parsed_results1.items(): if k not in dense_features + extra_features: print("no need {}".format(k), flush=True) assert (False) for fid_map_index in range(len(parallel_flag_list)): fid_map = fid_map_list[fid_map_index] fid_map_unique = fid_map_unique_list[fid_map_index] fid_map1_list, fid_map_unique1_list = sess.run( fetches=[fid_map, fid_map_unique], feed_dict={examples_placeholder: example_str_list}) #print('==' * 10 + "fid_map1" + '==' * 10, flush=True) #print(fid_map1, flush=True) #print('==' * 10 + "fid_map_unique1" + 
'==' * 10, flush=True) #print(fid_map_unique1, flush=True) fid_map1, fid_offset1, feature_offset1, nfl_offset1 = fid_map1_list fid_map2, fid_offset2, feature_offset2, nfl_offset2 = fid_map_unique1_list print('==' * 10 + "diff fidoffset " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) diff("nfl_offset", list(nfl_offset1), nfl_offset_t) diff("nfl_offset2", list(nfl_offset2), nfl_offset_t) diff("feature_offset", list(feature_offset1), feature_offset_t) diff("feature_offset2", list(feature_offset2), feature_offset_t) diff("fid_offset", list(fid_offset1), fid_offset_list) diff("fid_offset2", list(fid_offset2), fid_offset_list_unique) print('==' * 10 + "fid_map1" + '==' * 10, flush=True) print('==' * 10 + "fid_map_unique1" + '==' * 10, flush=True) assert (len(fid_map_t) == len(fid_map1)) assert (len(fid_map_unique_t) == len(fid_map2)) def fid_diff(a, b): for k, v in a.items(): assert (k in b) diff(k, v.tolist(), b[k], True) #.numpy() for k, v in b.items(): assert (k in a) print('==' * 10 + "diff fid_map1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) fid_diff(fid_map1, fid_map_t) print('==' * 10 + "diff fid_map_unique1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) fid_diff(fid_map2, fid_map_unique_t) #assert (False) def testInstanceSharding(self): fidv1_features = [1, 200, 3, 5, 9, 203, 205] fidv2_features = ["fc_v2_1", "fc_v2_2", "fc_v2_3"] dense_features = ['label'] dense_feature_shapes = [2] dense_feature_types = [tf.float32] extra_features = ['uid', 'req_time', 'item_id', 'actions'] extra_feature_shapes = [1, 1, 1, 2] instance_str_list = [] instance_list = [] while (len(instance_str_list) < 128): instance = gen_instance( fidv1_features=fidv1_features, fidv2_features=[], dense_features=[FeatureMeta('label', shape=2, dtype=tf.float32)], extra_features=[ FeatureMeta('actions', shape=2), FeatureMeta('uid'), FeatureMeta('req_time', dtype=tf.int32), FeatureMeta('item_id'), ]) instance_list.append(instance) 
instance_str_list.append(instance.SerializeToString()) instance2 = deepcopy(instance) for slot, feature_name in enumerate(fidv2_features): feature = instance2.feature.add() feature.name = feature_name feature.fid.extend(gen_fids_v2(1000 + slot, 10)) instance_list.append(instance2) instance_str_list.append(instance2.SerializeToString()) instance3 = deepcopy(instance2) del instance3.fid[:] instance_list.append(instance3) instance_str_list.append(instance3.SerializeToString()) def gen_slot_feature_name(slot_id): return f"slot_{slot_id}" sparse_features = sorted( fidv2_features + [gen_slot_feature_name(slot_id) for slot_id in fidv1_features]) print(sparse_features, flush=True) feature_cfgs = FeatureConfigs() index = 0 ps_num = 3 table_name_index_map = {} for sparse_key in sparse_features: cfg = FeatureConfig() cfg.table = 'table_{}'.format(index % 3) table_name_index_map[cfg.table] = -1 feature_cfgs.feature_configs[sparse_key].CopyFrom(cfg) index += 1 sparse_features.sort() table_name_list = list(table_name_index_map.keys()) table_name_list.sort() for index, table_name in enumerate(table_name_list): table_name_index_map[table_name] = index fid_map_t = defaultdict(list) fid_map_unique_t = defaultdict(list) fid_map_unique_map = defaultdict(dict) mask = (1 << 48) - 1 def handle_feature(value_list, table_name, table_index, fid_offset_list, fid_offset_list2): for value in value_list: shard = value % ps_num key = table_name + ":" + str(shard) fid_offset = (table_index * ps_num + shard) << 32 fid_offset_list.append(fid_offset | len(fid_map_t[key])) fid_map_t[key].append(value) if value not in fid_map_unique_map[key]: fid_map_unique_map[key][value] = len(fid_map_unique_map[key]) fid_map_unique_t[key].append(value) fid_offset_list2.append(fid_offset | fid_map_unique_map[key][value]) def slot_id_v1(fid): return fid >> 54 intance_tmp_dict = defaultdict(list) for instance in instance_list: fid_v2_list = defaultdict(list) for fid in instance.fid: slot_id = slot_id_v1(fid) 
sparse_key = gen_slot_feature_name(slot_id) if sparse_key not in sparse_features: continue fid_v2 = ((slot_id << 48) | (mask & fid)) fid_v2_list[sparse_key].append(fid_v2) for feature in instance.feature: sparse_key = feature.name if sparse_key not in sparse_features: continue fid_v2_list[sparse_key] = feature.fid for sparse_key in sparse_features: if sparse_key not in fid_v2_list: fid_v2_list[sparse_key] = [] fid_list = fid_v2_list[sparse_key] intance_tmp_dict[sparse_key].append(fid_list) fid_offset_map = defaultdict(list) fid_offset_map_unique = defaultdict(list) for sparse_key in sparse_features: for fid_list in intance_tmp_dict[sparse_key]: fid_offset_list = [] fid_offset_list2 = [] table_name = feature_cfgs.feature_configs[sparse_key].table table_index = table_name_index_map[table_name] handle_feature(fid_list, table_name, table_index, fid_offset_list, fid_offset_list2) fid_offset_map[sparse_key].append(fid_offset_list) fid_offset_map_unique[sparse_key].append(fid_offset_list2) feature_offset_t = [] nfl_offset_t = [] fid_offset_list = [] fid_offset_list_unique = [] for sparse_key in sparse_features: nfl_offset = len(feature_offset_t) nfl_offset_t.append(nfl_offset) if sparse_key not in fid_offset_map: continue for fid_list in fid_offset_map[sparse_key]: feature_offset_t.append(len(fid_offset_list)) fid_offset_list.extend(fid_list) for fid_list in fid_offset_map_unique[sparse_key]: fid_offset_list_unique.extend(fid_list) feature_offset_t.append(len(fid_offset_list)) nfl_offset_t.append(len(feature_offset_t)) print('==' * 10 + "fid_map_t" + '==' * 10, flush=True) #print(fid_map_t, flush=True) print('==' * 10 + "fid_map_unique_t" + '==' * 10, flush=True) #print(fid_map_unique_t, flush=True) print('==' * 10 + "fid_offset_map" + '==' * 10) print("xxxxx:", len(feature_offset_t), len(instance_list)) #print(fid_offset_map) #example_tensor = tf.convert_to_tensor(example_str_list) with tf.Graph().as_default(): config = tf.compat.v1.ConfigProto() 
config.graph_options.rewrite_options.disable_meta_optimizer = True examples_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=(None)) get_default_parser_ctx().enable_fused_layout = False parsed_results_base = parse_instances( examples_placeholder, fidv1_features=fidv1_features, fidv2_features=fidv2_features, dense_features=dense_features, dense_feature_shapes=dense_feature_shapes, dense_feature_types=dense_feature_types, extra_features=extra_features, extra_feature_shapes=extra_feature_shapes) get_default_parser_ctx().enable_fused_layout = True parsed_results = parse_instances( examples_placeholder, fidv1_features=fidv1_features, fidv2_features=fidv2_features, dense_features=dense_features, dense_feature_shapes=dense_feature_shapes, dense_feature_types=dense_feature_types, extra_features=extra_features, extra_feature_shapes=extra_feature_shapes) examples_varint = parsed_results.pop( ParserCtx.sharding_sparse_fids_sparse_features_key) parallel_flag_list = [0, 1] fid_map_list = [] fid_map_unique_list = [] for parallel_flag in parallel_flag_list: fid_map, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_list_row_splits, fid_list_row_splits_size, fid_list_emb_row_lenth, \ fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( examples_varint, ps_num, feature_cfgs, False, "instance", parallel_flag, version=1) fid_map_list.append([fid_map, fid_offset, feature_offset, nfl_offset]) fid_map_unique, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, \ fid_list_row_splits, fid_list_row_splits_size, fid_list_emb_row_lenth, \ fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( examples_varint, ps_num, feature_cfgs, True, "instance", parallel_flag, version=1) fid_map_unique_list.append( [fid_map_unique, fid_offset, feature_offset, nfl_offset]) with self.session(config=config) as sess: parsed_results_base1, parsed_results1 = 
sess.run( fetches=[parsed_results_base, parsed_results], feed_dict={examples_placeholder: instance_str_list}) def diff(k, a, b, sort=False): if not isinstance(a[0], list) and sort: a.sort() b.sort() def print_func(): print("diff:a {} {}".format(k, a), flush=True) print("diff:b {} {}".format(k, b), flush=True) return "{}, {}".format(len(a), len(b)) assert (len(a) == len(b)), print_func() if (len(a) == 0): return for i in range(len(a)): if isinstance(a[i], list): assert isinstance(b[i], list), print_func() diff(k + "/" + str(i), a[i], b[i], sort) else: assert (a[i] == b[i]), print_func() print('==' * 10 + "parsed_results_base1" + '==' * 10, flush=True) print(parsed_results_base1, flush=True) print('==' * 10 + "parsed_results1" + '==' * 10, flush=True) print(parsed_results1, flush=True) # .numpy() for k, v in parsed_results_base1.items(): if k in sparse_features: continue if k in dense_features + extra_features: if k not in parsed_results1: print("no find {} in parse_example_batch_v2".format(k), flush=True) assert (False) diff(k, v.tolist(), parsed_results1[k].tolist()) #.numpy() else: print("no need {}".format(k), flush=True) assert (False) for k, v in parsed_results1.items(): if k not in dense_features + extra_features: print("no need {}".format(k), flush=True) assert (False) for fid_map_index in range(len(parallel_flag_list)): fid_map = fid_map_list[fid_map_index] fid_map_unique = fid_map_unique_list[fid_map_index] fid_map1_list, fid_map_unique1_list = sess.run( fetches=[fid_map, fid_map_unique], feed_dict={examples_placeholder: instance_str_list}) #print('==' * 10 + "fid_map1" + '==' * 10, flush=True) #print('==' * 10 + "fid_map_unique1" + '==' * 10, flush=True) fid_map1, fid_offset1, feature_offset1, nfl_offset1 = fid_map1_list fid_map2, fid_offset2, feature_offset2, nfl_offset2 = fid_map_unique1_list print('==' * 10 + "diff fidoffset " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) diff("nfl_offset", list(nfl_offset1), nfl_offset_t) 
diff("nfl_offset2", list(nfl_offset2), nfl_offset_t) diff("feature_offset", list(feature_offset1), feature_offset_t) diff("feature_offset2", list(feature_offset2), feature_offset_t) diff("fid_offset", list(fid_offset1), fid_offset_list) diff("fid_offset2", list(fid_offset2), fid_offset_list_unique) print('==' * 10 + "fid_map1" + '==' * 10, flush=True) #print(fid_map1, flush=True) print('==' * 10 + "fid_map_unique1" + '==' * 10, flush=True) #print(fid_map_unique1, flush=True) assert (len(fid_map_t) == len(fid_map1)) assert (len(fid_map_unique_t) == len(fid_map2)) def fid_diff(a, b): for k, v in a.items(): assert (k in b) diff(k, v.tolist(), b[k], True) #.numpy() for k, v in b.items(): assert (k in a) print('==' * 10 + "diff fid_map1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) fid_diff(fid_map1, fid_map_t) print('==' * 10 + "diff fid_map_unique1 " + str(parallel_flag_list[fid_map_index]) + '==' * 10, flush=True) fid_diff(fid_map2, fid_map_unique_t) #assert (False) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/data/parsers.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging, flags import os import struct from copy import deepcopy from typing import Dict, List, Iterable, Callable from collections import deque from dataclasses import dataclass import traceback import tensorflow as tf from idl.matrix.proto.line_id_pb2 import LineId from idl.matrix.proto.example_pb2 import FeatureConfigs from monolith.utils import get_libops_path from monolith.native_training import logging_ops from monolith.native_training import native_task_context from monolith.native_training.monolith_export import monolith_export from monolith.native_training.data.feature_list import get_feature_name_and_slot, FeatureList from monolith.native_training.data.data_op_config_pb2 import LabelConf, TaskLabelConf from monolith.native_training.data.feature_list import add_feature, is_example_batch from monolith.native_training.runtime.ops import gen_monolith_ops from monolith.native_training.utils import add_to_collections from tensorflow.python.framework import ops from tensorflow.python.framework import common_shapes FLAGS = flags.FLAGS parse_instance_ops = gen_monolith_ops _line_id_descriptor = LineId.DESCRIPTOR _default_parser_ctx = None FLAGS = flags.FLAGS @dataclass class ShardingSparseFidsOpParams: num_ps: int use_native_multi_hash_table: bool unique: Callable transfer_float16: bool sub_table_name_to_config: Dict feature_configs: FeatureConfigs enable_gpu_emb: bool use_gpu: bool class ParserCtx(object): enable_resource_constrained_roughsort = False sharding_sparse_fids_features_prefix = "__sharding_sparse_fids__" sharding_sparse_fids_sparse_features_key = "__sharding_sparse_fids__sparse_features" def __init__(self, enable_fused_layout: bool = False): self._old_parser_ctx = None self.parser_type = None self.enable_fused_layout = enable_fused_layout self.sharding_sparse_fids_op_params: ShardingSparseFidsOpParams = None self._ctx_kv = {} def __enter__(self): global _default_parser_ctx self._old_parser_ctx = _default_parser_ctx _default_parser_ctx 
= self return self def __exit__(self, exc_type, exc_val, exc_tb): global _default_parser_ctx _default_parser_ctx = self._old_parser_ctx self._old_parser_ctx = None @classmethod def sharding_sparse_fids_features_insert_to_features(cls, inputs, features): if isinstance(inputs, Dict): for k, v in inputs.items(): if isinstance(v, Dict): for sub_k, sub_v in v.items(): features[cls.sharding_sparse_fids_features_prefix + k + "/" + sub_k] = sub_v elif isinstance(v, List): raise ValueError("not support") else: features[cls.sharding_sparse_fids_features_prefix + k] = v else: raise ValueError("not support") @classmethod def sharding_sparse_fids_features_parse_from_features(cls, features): outputs = {} pop_key_list = [] for k, v in features.items(): if k.startswith(cls.sharding_sparse_fids_features_prefix): name = k[len(cls.sharding_sparse_fids_features_prefix):] level_ouput = outputs level_names = name.split("/") for level_name in level_names[:-1]: if level_name not in level_ouput: level_ouput[level_name] = {} level_ouput = level_ouput[level_name] level_ouput[level_names[-1]] = v pop_key_list.append(k) for k in pop_key_list: features.pop(k) return outputs def set(self, key, value): self._ctx_kv[key] = value def get(self, key, default_value=None): return self._ctx_kv.get(key, default_value) def get_default_parser_ctx() -> ParserCtx: global _default_parser_ctx if _default_parser_ctx is None: _default_parser_ctx = ParserCtx(False) return _default_parser_ctx class ProtoType: TYPE_BOOL: int = 8 TYPE_BYTES: int = 12 TYPE_DOUBLE: int = 1 TYPE_ENUM: int = 14 TYPE_FIXED32: int = 7 TYPE_FIXED64: int = 6 TYPE_FLOAT: int = 2 TYPE_GROUP: int = 10 TYPE_INT32: int = 5 TYPE_INT64: int = 3 TYPE_MESSAGE: int = 11 TYPE_SFIXED32: int = 15 TYPE_SFIXED64: int = 16 TYPE_SINT32: int = 17 TYPE_SINT64: int = 18 TYPE_STRING: int = 9 TYPE_UINT32: int = 13 TYPE_UINT64: int = 4 UNKNOWN = {TYPE_BOOL, TYPE_ENUM, TYPE_GROUP, TYPE_MESSAGE} STRING = {TYPE_BYTES, TYPE_STRING} FLOAT = {TYPE_FLOAT, TYPE_DOUBLE} 
INT = { TYPE_INT32, TYPE_INT64, TYPE_SINT32, TYPE_SINT64, TYPE_UINT32, TYPE_UINT64, TYPE_FIXED32, TYPE_FIXED64, TYPE_SFIXED32, TYPE_SFIXED64 } @classmethod def get_tf_type(cls, proto_type: int): if proto_type in cls.INT: return tf.int64 elif proto_type in cls.FLOAT: return tf.float32 elif proto_type in cls.STRING: return tf.string else: raise Exception('proto_type {} is not support'.format(proto_type)) def _add_dense_features(names: List[str], shapes: List[int], types: List[tf.compat.v1.dtypes.DType], dense_features: List[str], dense_feature_shapes: List[int], dense_feature_types: List[tf.compat.v1.dtypes.DType]): assert dense_features is not None assert dense_feature_shapes is not None assert len(dense_features) == len(dense_feature_shapes) assert all([s > 0 for s in dense_feature_shapes]) if dense_feature_types is None: dense_feature_types = [tf.float32] * len(dense_features) else: assert len(dense_features) == len(dense_feature_types) names.extend(dense_features) shapes.extend(dense_feature_shapes) types.extend(dense_feature_types) def _add_extra_features(names: List[str], shapes: List[int], types: List[tf.compat.v1.dtypes.DType], extra_features: List[str], extra_feature_shapes: List[int]): assert extra_features is not None assert extra_feature_shapes is not None assert len(extra_features) == len(extra_feature_shapes) assert all([s > 0 for s in extra_feature_shapes]) extra_dtypes = [] for name in extra_features: try: extra_dtypes.append( ProtoType.get_tf_type(_line_id_descriptor.fields_by_name[name].type)) except: raise Exception(f"{name} is not in line id, pls check!") names.extend(extra_features) shapes.extend(extra_feature_shapes) types.extend(extra_dtypes) def _assemble(sparse_features, names, shapes, types, out_list, batch_size: int = None): assert len(out_list) == len(types) features = {} for i, name in enumerate(names): if name in sparse_features: value = out_list[i + len(names)] if batch_size: batch_size = batch_size[0] if isinstance(batch_size, (list, 
tuple)) else batch_size split = tf.reshape(out_list[i], shape=(batch_size + 1,)) else: split = out_list[i] features[name] = tf.RaggedTensor.from_row_splits(value, split, validate=False) else: features[name] = out_list[i] return features def parse_instances(tensor: tf.Tensor, fidv1_features: List[int] = None, fidv2_features: List[str] = None, dense_features: List[str] = None, dense_feature_shapes: List[int] = None, dense_feature_types: List[tf.compat.v1.dtypes.DType] = None, extra_features: List[str] = None, extra_feature_shapes: List[int] = None): """从Tensor中解析instance Example格式中, 所有特征均存于feature中, 没有平铺的特征. Sparse特征由于长度不定, 输出RaggedTensor, 其它特征输出Tensor Args: tensor (:obj:`tf.Tensor`): 输入样本 fidv1_features (:obj:`List[int]`): 在Instance中, fidv1_features是平铺的, 所以用slot指定, 可以是部分slot fidv2_features (:obj:`List[str]`): 在Instance中, fidv2_features存放于feature中, 可以用名字指定, 可以是部分特征名 dense_features (:obj:`List[str]`): 稠密特征(或Label)名称, 可以有多个, 也可以有不同类型 dense_feature_shapes (:obj:`List[int]`): 稠密特征名称的shape dense_feature_types (:obj:`List[dtype]`): 稠密特征名称的数据类型, 默认为`tf.float32` extra_features (:obj:`List[str]`): 主要指LineId中的字段, 可以有多个, Monolith会自动从LineId中提取数据类型 extra_feature_shapes (:obj:`List[int]`): extra特征名称的shape Returns: Dict[str, Tensor] 解析出特征名到特征的字典 """ if ParserCtx.enable_resource_constrained_roughsort: if extra_features is None: extra_features = ["item_id"] extra_feature_shapes = [1] elif "item_id" not in extra_features: extra_features.append("item_id") extra_feature_shapes.append(1) if dense_features: assert dense_feature_shapes is not None assert len(dense_feature_shapes) == len(dense_features) if dense_feature_types: assert len(dense_feature_types) == len(dense_features) else: dense_feature_types = [tf.float32] * len(dense_features) get_default_parser_ctx().parser_type = 'instance' add_to_collections('fidv1_features', fidv1_features) add_to_collections('fidv2_features', fidv2_features) add_to_collections('dense_features', dense_features) add_to_collections('dense_feature_shapes', 
dense_feature_shapes) add_to_collections('dense_feature_types', dense_feature_types) add_to_collections('extra_features', extra_features) add_to_collections('extra_feature_shapes', extra_feature_shapes) add_to_collections('variant_type', 'instance') get_default_parser_ctx().set('fidv1_features', fidv1_features) get_default_parser_ctx().set('fidv2_features', fidv2_features) names, shapes, types = [], [], [] if not get_default_parser_ctx().enable_fused_layout: sparse_features = [] if fidv1_features is not None: names.extend( [get_feature_name_and_slot(slot)[0] for slot in fidv1_features]) if all(isinstance(feature_name, str) for feature_name in fidv1_features): try: feature_list = FeatureList.parse() fidv1_features = [ feature_list.get(feature_name).slot for feature_name in fidv1_features ] except: raise RuntimeError("fidv1_features error") shapes.extend([-1] * len(fidv1_features)) types.extend([tf.int64] * len(fidv1_features)) if fidv2_features is not None: names.extend(fidv2_features) shapes.extend([-1] * len(fidv2_features)) types.extend([tf.int64] * len(fidv2_features)) sparse_features.extend(names) if dense_features is not None: _add_dense_features(names, shapes, types, dense_features, dense_feature_shapes, dense_feature_types) if extra_features is not None: _add_extra_features(names, shapes, types, extra_features, extra_feature_shapes) if get_default_parser_ctx().enable_fused_layout: if len(names) == 0: names.append("__FAKE_FEATURE__") shapes.append(1) types.append(tf.float32) out_list, instances = parse_instance_ops.parse_instances_v2( tensor, [], [], names, shapes, types, extra_features or []) features = _assemble([], names, shapes, types, out_list) parser_ctx = get_default_parser_ctx() if parser_ctx.sharding_sparse_fids_op_params is not None and (parser_ctx.sharding_sparse_fids_op_params.use_gpu or FLAGS.dataset_use_dataservice): sharding_sparse_fids_with_context(instances, features, parser_ctx) else: 
features[ParserCtx.sharding_sparse_fids_sparse_features_key] = instances if "__FAKE_FEATURE__" in features: del features["__FAKE_FEATURE__"] return features else: types.extend([tf.int64] * len(sparse_features)) assert len(names) == len(set(names)), "deplicate names, pls check!" out_list = parse_instance_ops.parse_instances(tensor, fidv1_features or [], fidv2_features or [], names, shapes, types, extra_features or []) return _assemble(sparse_features, names, shapes, types, out_list) @monolith_export def parse_examples(tensor: tf.Tensor, sparse_features: List[str], dense_features: List[str] = None, dense_feature_shapes: List[int] = None, dense_feature_types: List[tf.compat.v1.dtypes.DType] = None, extra_features: List[str] = None, extra_feature_shapes: List[int] = None): """从Tensor中解析example Example格式中, 所有特征均存于feature中, 没有平铺特征. Sparse特征由于长度不定, 输出RaggedTensor, 其它特征输出Tensor Args: tensor (:obj:`tf.Tensor`): 输入样本 sparse_features (:obj:`List[str]`): 稀疏特征名称, 可以有多个 dense_features (:obj:`List[str]`): 稠密特征(或Label)名称, 可以有多个, 也可以有不同类型 dense_feature_shapes (:obj:`List[int]`): 稠密特征名称的shape dense_feature_types (:obj:`List[dtype]`): 稠密特征名称的数据类型, 默认为`tf.float32` extra_features (:obj:`List[str]`): 主要指LineId中的字段, 可以有多个, Monolith会自动从LineId中提取数据类型 extra_feature_shapes (:obj:`List[int]`): extra特征名称的shape Returns: Dict[str, Tensor] 解析出特征名到特征的字典 """ if dense_features: assert dense_feature_shapes is not None assert len(dense_feature_shapes) == len(dense_features) if dense_feature_types: assert len(dense_feature_types) == len(dense_features) else: dense_feature_types = [tf.float32] * len(dense_features) get_default_parser_ctx().parser_type = 'example' add_to_collections('sparse_features', sparse_features) add_to_collections('dense_features', dense_features) add_to_collections('dense_feature_shapes', dense_feature_shapes) add_to_collections('dense_feature_types', dense_feature_types) add_to_collections('extra_features', extra_features) add_to_collections('extra_feature_shapes', 
extra_feature_shapes) add_to_collections('variant_type', 'example') get_default_parser_ctx().set('sparse_features', sparse_features) if is_example_batch(): add_feature(sparse_features) if dense_features: if 'label' in dense_features: add_feature('__LABEL__') add_feature([feat for feat in dense_features if feat != 'label']) if extra_features: add_feature('__LINE_ID__') names, shapes, types = [], [], [] if not get_default_parser_ctx().enable_fused_layout: assert sparse_features is not None names.extend(sparse_features) shapes.extend([-1] * len(sparse_features)) types.extend([tf.int64] * len(sparse_features)) if dense_features is not None: _add_dense_features(names, shapes, types, dense_features, dense_feature_shapes, dense_feature_types) if extra_features is not None: _add_extra_features(names, shapes, types, extra_features, extra_feature_shapes) assert len(names) == len(set(names)), "deplicate names, pls check!" if get_default_parser_ctx().enable_fused_layout: if len(names) == 0: names.append("__FAKE_FEATURE__") shapes.append(1) types.append(tf.float32) out_list, examples = parse_instance_ops.parse_examples_v2( tensor, names, shapes, types, extra_features or []) features = _assemble([], names, shapes, types, out_list) parser_ctx = get_default_parser_ctx() if parser_ctx.sharding_sparse_fids_op_params is not None and (parser_ctx.sharding_sparse_fids_op_params.use_gpu or FLAGS.dataset_use_dataservice): sharding_sparse_fids_with_context(examples, features, parser_ctx) else: features[ParserCtx.sharding_sparse_fids_sparse_features_key] = examples if "__FAKE_FEATURE__" in features: del features["__FAKE_FEATURE__"] return features else: types.extend([tf.int64] * len(sparse_features)) out_list = parse_instance_ops.parse_examples(tensor, names, shapes, types, extra_features or []) return _assemble(sparse_features, names, shapes, types, out_list) @monolith_export def parse_example_batch( tensor: tf.Tensor, sparse_features: List[str], dense_features: List[str] = None, 
dense_feature_shapes: List[int] = None, dense_feature_types: List[tf.compat.v1.dtypes.DType] = None, extra_features: List[str] = None, extra_feature_shapes: List[int] = None): """从Tensor中解析example_batch Example_batch格式中, 所有特征均存于feature中, 没有平铺特征. Sparse特征由于长度不定, 输出RaggedTensor, 其它特征输出Tensor Args: tensor (:obj:`tf.Tensor`): 输入样本 sparse_features (:obj:`List[str]`): 稀疏特征名称, 可以有多个 dense_features (:obj:`List[str]`): 稠密特征(或Label)名称, 可以有多个, 也可以有不同类型 dense_feature_shapes (:obj:`List[int]`): 稠密特征名称的shape dense_feature_types (:obj:`List[dtype]`): 稠密特征名称的数据类型, 默认为`tf.float32` extra_features (:obj:`List[str]`): 主要指LineId中的字段, 可以有多个, Monolith会自动从LineId中提取数据类型 extra_feature_shapes (:obj:`List[int]`): extra特征名称的shape Returns: Dict[str, Tensor] 解析出特征名到特征的字典 """ if dense_features: assert dense_feature_shapes is not None assert len(dense_feature_shapes) == len(dense_features) if dense_feature_types: assert len(dense_feature_types) == len(dense_features) else: dense_feature_types = [tf.float32] * len(dense_features) get_default_parser_ctx().parser_type = 'examplebatch' add_to_collections('sparse_features', sparse_features) add_to_collections('dense_features', dense_features) add_to_collections('dense_feature_shapes', dense_feature_shapes) add_to_collections('dense_feature_types', dense_feature_types) add_to_collections('extra_features', extra_features) add_to_collections('extra_feature_shapes', extra_feature_shapes) add_to_collections('variant_type', 'example_batch') get_default_parser_ctx().set('sparse_features', sparse_features) if is_example_batch(): add_feature(sparse_features) if dense_features: if 'label' in dense_features: add_feature('__LABEL__') add_feature([feat for feat in dense_features if feat != 'label']) if extra_features: add_feature('__LINE_ID__') names, shapes, types = [], [], [] if not get_default_parser_ctx().enable_fused_layout: assert sparse_features is not None names.extend(sparse_features) shapes.extend([-1] * len(sparse_features)) types.extend([tf.int64] * 
len(sparse_features)) if dense_features is not None: _add_dense_features(names, shapes, types, dense_features, dense_feature_shapes, dense_feature_types) if extra_features is not None: _add_extra_features(names, shapes, types, extra_features, extra_feature_shapes) batch_size = get_default_parser_ctx().get('batch_size') assert len(names) == len(set(names)), "deplicate names, pls check!" if get_default_parser_ctx().enable_fused_layout: if len(names) == 0: names.append("__FAKE_FEATURE__") shapes.append(1) types.append(tf.float32) out_list, example_batch = parse_instance_ops.parse_example_batch_v2( tensor, names, shapes, types, extra_features or []) features = _assemble([], names, shapes, types, out_list) parser_ctx = get_default_parser_ctx() if parser_ctx.sharding_sparse_fids_op_params is not None and (parser_ctx.sharding_sparse_fids_op_params.use_gpu or FLAGS.dataset_use_dataservice): sharding_sparse_fids_with_context(example_batch, features, parser_ctx) else: features[ ParserCtx.sharding_sparse_fids_sparse_features_key] = example_batch if "__FAKE_FEATURE__" in features: del features["__FAKE_FEATURE__"] return features else: types.extend([tf.int64] * len(sparse_features)) out_list = parse_instance_ops.parse_example_batch(tensor, names, shapes, types, extra_features or []) return _assemble(sparse_features, names, shapes, types, out_list, batch_size=batch_size) @monolith_export def sharding_sparse_fids(tensor: tf.Tensor, ps_num: int, feature_cfgs: FeatureConfigs, unique: bool, input_type: str, parallel_flag: int = 0, fid_list_ret_list: bool = False, version: int = 5): assert input_type in ["example", "examplebatch", "example_batch", "instance"] input_type = 'examplebatch' if input_type == 'example_batch' else input_type table_name_list = [] for cfg in feature_cfgs.feature_configs.values(): if cfg.table not in table_name_list: table_name_list.append(cfg.table) table_name_list.sort() ps_num = 1 if ps_num == 0 else ps_num logging.info( f"num of multi_type_hashtable is 
{ps_num} {len(table_name_list)}: [{table_name_list}]" ) table_count = len(table_name_list) * ps_num fid_list_emb_row_lenth = None fid_list_table_row_length = None fid_list_shard_row_lenth = None fid_list_row_splits_size = None nfl_size = None feature_size = None fid_size = None emb_size = None (tensor,), start_ts = logging_ops.tensors_timestamp([tensor]) if version == 5: fid_list, fid_list_row_splits, fid_list_row_splits_size, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size = parse_instance_ops.sharding_sparse_fids_v5( pb_input=tensor, ps_num=ps_num, feature_cfgs=feature_cfgs.SerializeToString(), N=table_count, unique=unique, input_type=input_type, parallel_flag=parallel_flag) elif version == 4: fid_list, fid_list_row_splits, fid_list_table_row_length, fid_list_shard_row_lenth, fid_list_emb_row_lenth, fid_offset, feature_offset, nfl_offset, batch_size = parse_instance_ops.sharding_sparse_fids_v4( pb_input=tensor, ps_num=ps_num, feature_cfgs=feature_cfgs.SerializeToString(), unique=unique, input_type=input_type, parallel_flag=parallel_flag) elif version == 3: fid_list, fid_list_row_splits, fid_offset, feature_offset, nfl_offset, batch_size = parse_instance_ops.sharding_sparse_fids_v3( pb_input=tensor, ps_num=ps_num, feature_cfgs=feature_cfgs.SerializeToString(), N=table_count, unique=unique, input_type=input_type, parallel_flag=parallel_flag, single_thread_feature_watermark=4*80000) fid_list_row_splits_size = [None] * table_count elif version == 2: fid_list, fid_list_row_splits, fid_offset, feature_offset, nfl_offset, batch_size = parse_instance_ops.sharding_sparse_fids_v2( pb_input=tensor, ps_num=ps_num, feature_cfgs=feature_cfgs.SerializeToString(), N=table_count, unique=unique, input_type=input_type, parallel_flag=parallel_flag) fid_list_row_splits_size = [None] * table_count else: fid_list, fid_offset, feature_offset, nfl_offset, batch_size = parse_instance_ops.sharding_sparse_fids( pb_input=tensor, ps_num=ps_num, 
feature_cfgs=feature_cfgs.SerializeToString(), N=table_count, unique=unique, input_type=input_type, parallel_flag=parallel_flag) fid_list_row_splits = [None] * table_count fid_list_row_splits_size = [None] * table_count (fid_offset,), end_ts = logging_ops.tensors_timestamp([fid_offset]) def emit_sharding_sparse_timer_ops(interval): return [ logging_ops.emit_timer( "sharding_sparse_fids", tf.cast(interval, tf.float32), tags={ "model_name": native_task_context.get().model_name }) ] with tf.control_dependencies(emit_sharding_sparse_timer_ops(end_ts - start_ts)): tf.no_op() if version != 4: assert len(fid_list) == table_count assert len(fid_list_row_splits) == table_count if version == 5: assert len(fid_list_row_splits_size) == table_count if fid_list_ret_list or version == 4: return fid_list, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, fid_list_row_splits, fid_list_row_splits_size, fid_list_emb_row_lenth, fid_list_table_row_length, fid_list_shard_row_lenth ret = {} ret_row_split = {} ret_row_split_size = {} index = 0 for table_idx in range(len(table_name_list)): table_name = table_name_list[table_idx] for ps_index in range(ps_num): ret[table_name + ":" + str(ps_index)] = fid_list[index] ret_row_split[table_name + ":" + str(ps_index)] = fid_list_row_splits[index] ret_row_split_size[table_name + ":" + str(ps_index)] = fid_list_row_splits_size[index] index += 1 return ret, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, feature_size, fid_size, emb_size, ret_row_split, ret_row_split_size, fid_list_emb_row_lenth, fid_list_table_row_length, fid_list_shard_row_lenth def sharding_sparse_fids_with_context(sparse_features: tf.Tensor, features, parser_ctx: ParserCtx = None): if parser_ctx is None: parser_ctx = get_default_parser_ctx() shards, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, \ feature_size, fid_size, emb_size, shards_row_split, shards_row_split_size, \ fid_list_emb_row_lenth, 
fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( sparse_features, ps_num=parser_ctx.sharding_sparse_fids_op_params.num_ps, feature_cfgs=parser_ctx.sharding_sparse_fids_op_params.feature_configs, unique=parser_ctx.sharding_sparse_fids_op_params.unique(), input_type=parser_ctx.parser_type, fid_list_ret_list=parser_ctx.sharding_sparse_fids_op_params.enable_gpu_emb, parallel_flag=1 if parser_ctx.sharding_sparse_fids_op_params.use_gpu else 0, version=4 if parser_ctx.sharding_sparse_fids_op_params.enable_gpu_emb else 5) if parser_ctx.sharding_sparse_fids_op_params.enable_gpu_emb: ''' table_count = len(shards) // parser_ctx.sharding_sparse_fids_op_params.num_ps def tensor_list_to_ragged_tensor(tensor_list): if len(tensor_list) == 1: shard_ragged_tensor = tf.RaggedTensor.from_row_starts( tensor_list[0], tf.constant([0], dtype=tf.int32), validate=False) else: shard_ragged_tensor = tf.ragged.stack(tensor_list) return shard_ragged_tensor shards_new_order_list = [] for ps_i in range(parser_ctx.sharding_sparse_fids_op_params.num_ps): for table_i in range(table_count): shards_new_order_list.append(shards[ps_i + table_i * parser_ctx.sharding_sparse_fids_op_params.num_ps]) shard_ragged_tensor = tensor_list_to_ragged_tensor(shards_new_order_list) shards_value = shard_ragged_tensor.values shards_table_row_lengths = shard_ragged_tensor.row_lengths() shards_row_lengths = tf.reduce_sum(tf.reshape(shards_table_row_lengths, [parser_ctx.sharding_sparse_fids_op_params.num_ps, table_count]), axis=-1) ''' shards_value = shards shards_row_lengths = fid_list_shard_row_lenth shards_table_row_lengths = fid_list_table_row_length parser_ctx.sharding_sparse_fids_features_insert_to_features( { "shards_value": shards_value, "shards_row_lengths": shards_row_lengths, "shards_table_row_lengths": shards_table_row_lengths, "fid_offset": fid_offset, "feature_offset": feature_offset, "nfl_offset": nfl_offset, "batch_size": batch_size, "fid_list_emb_row_lenth": 
fid_list_emb_row_lenth, }, features) else: features_dict = { "shards": shards, "fid_offset": fid_offset, "feature_offset": feature_offset, "nfl_offset": nfl_offset, "batch_size": batch_size, "nfl_size": nfl_size, "feature_size": feature_size, "fid_size": fid_size, "emb_size": emb_size, } if parser_ctx.sharding_sparse_fids_op_params.use_native_multi_hash_table: features_dict.update({"shards_row_split": shards_row_split}) features_dict.update({"shards_row_split_size": shards_row_split_size}) parser_ctx.sharding_sparse_fids_features_insert_to_features( features_dict, features) def parse_example_batch_list( tensor: List[tf.Tensor], label_config: str = None, positive_label: float = 1.0, negative_label: float = 0.0, names: List[str] = None, shapes: List[int] = None, dtypes: List[tf.dtypes.DType] = None, extra_features: List[str] = None) -> Dict[str, tf.Tensor]: names, shapes, dtypes = list(names), list(shapes), list(dtypes) get_default_parser_ctx().parser_type = 'examplebatch' label_conf = LabelConf() if label_config is not None and len(label_config) > 0: tasks = label_config.split(';') names.append('label') shapes.append(len(tasks)) dtypes.append(tf.float32) for task in tasks: task_conf = label_conf.conf.add() pos_actions, neg_actions = task.split(':') pos_actions_list = [ int(pos) for pos in pos_actions.split(',') if len(pos) > 0 ] neg_actions_list = [ int(neg) for neg in neg_actions.split(',') if len(neg) > 0 ] task_conf.pos_actions.extend(pos_actions_list) task_conf.neg_actions.extend(neg_actions_list) sparse_features = [] for i, name in enumerate(names): if shapes[i] == -1: sparse_features.append(name) dtypes.append(tf.int64) assert len(names) == len(set(names)), "deplicate names, pls check!" 
out_list = parse_instance_ops.parse_example_batch_list( tensor, label_config=label_conf.SerializeToString(), names=names, shapes=shapes, dtypes=dtypes, extra_names=extra_features, positive_label=positive_label, negative_label=negative_label) return _assemble(sparse_features, names, shapes, dtypes, out_list) ================================================ FILE: monolith/native_training/data/test_data/BUILD ================================================ package(default_visibility = [ "//monolith/integration_test:__subpackages__", "//monolith/native_training:__subpackages__", ]) filegroup( name = "test_feature_lists", srcs = [ ], ) ================================================ FILE: monolith/native_training/data/test_data/mhy.conf ================================================ feed_name=area; shared=true; feature_id=368235 feed_name=att_traced; shared=true; feature_id=368245 feed_name=bhv_scm; shared=true; feature_id=368239 feed_name=bhv_spm; shared=true; feature_id=368240 feed_name=bhv_spm_1; shared=true; feature_id=368241 feed_name=bhv_spm_2; shared=true; feature_id=368242 feed_name=bhv_spm_3; shared=true; feature_id=368243 feed_name=bhv_spm_4; shared=true; feature_id=368244 feed_name=bhv_time_hour; shared=true; feature_id=368237 feed_name=bhv_time_monthday; shared=true; feature_id=368246 feed_name=bhv_time_weekday; shared=true; feature_id=368238 feed_name=city; shared=true; feature_id=368233 feed_name=client_version; shared=true; feature_id=368227 feed_name=country; shared=true; feature_id=368231 feed_name=device_model; shared=true; feature_id=368229 feed_name=district; shared=true; feature_id=368234 feed_name=doc_author_fans_10; feature_id=368214 feed_name=doc_author_id; feature_id=368198 feed_name=doc_author_level; feature_id=368213 feed_name=doc_author_name; feature_id=368212 feed_name=doc_cate1; feature_id=368195 feed_name=doc_cate2; feature_id=368196 feed_name=doc_cate3; feature_id=368197 feed_name=doc_collect_cnt_10; feature_id=368201 
feed_name=doc_collection; feature_id=368215 feed_name=doc_comment_cnt_10; feature_id=368203 feed_name=doc_content_length_2; feature_id=368219 feed_name=doc_create_time; feature_id=368220 feed_name=doc_detail_pic_num; feature_id=368209 feed_name=doc_expire_time; feature_id=368222 feed_name=doc_id; feature_id=368192 feed_name=doc_id_post_click_180d; shared=true; feature_id=368733 feed_name=doc_id_post_click_1d; shared=true; feature_id=368730 feed_name=doc_id_post_click_1h; shared=true; feature_id=368728 feed_name=doc_id_post_click_30d; shared=true; feature_id=368732 feed_name=doc_id_post_click_6h; shared=true; feature_id=368729 feed_name=doc_id_post_click_7d; shared=true; feature_id=368731 feed_name=doc_id_post_favorite_180d; shared=true; feature_id=368739 feed_name=doc_id_post_favorite_1d; shared=true; feature_id=368736 feed_name=doc_id_post_favorite_1h; shared=true; feature_id=368734 feed_name=doc_id_post_favorite_30d; shared=true; feature_id=368738 feed_name=doc_id_post_favorite_6h; shared=true; feature_id=368735 feed_name=doc_id_post_favorite_7d; shared=true; feature_id=368737 feed_name=doc_id_post_praise_180d; shared=true; feature_id=368751 feed_name=doc_id_post_praise_1d; shared=true; feature_id=368748 feed_name=doc_id_post_praise_1h; shared=true; feature_id=368746 feed_name=doc_id_post_praise_30d; shared=true; feature_id=368750 feed_name=doc_id_post_praise_6h; shared=true; feature_id=368747 feed_name=doc_id_post_praise_7d; shared=true; feature_id=368749 feed_name=doc_id_post_share_180d; shared=true; feature_id=368745 feed_name=doc_id_post_share_1d; shared=true; feature_id=368742 feed_name=doc_id_post_share_1h; shared=true; feature_id=368740 feed_name=doc_id_post_share_30d; shared=true; feature_id=368744 feed_name=doc_id_post_share_6h; shared=true; feature_id=368741 feed_name=doc_id_post_share_7d; shared=true; feature_id=368743 feed_name=doc_keyword; feature_id=368211 feed_name=doc_location_tag; feature_id=368217 feed_name=doc_pic_url; feature_id=368206 
feed_name=doc_praise_cnt_10; feature_id=368202 feed_name=doc_pub_time; feature_id=368221 feed_name=doc_rating; feature_id=368205 feed_name=doc_related_goods_ids; feature_id=368210 feed_name=doc_share_cnt_10; feature_id=368200 feed_name=doc_source_id; feature_id=368204 feed_name=doc_tags; feature_id=368199 feed_name=doc_title_length; feature_id=368218 feed_name=doc_title_terms; feature_id=368193 feed_name=doc_topic_tag; feature_id=368216 feed_name=doc_type; feature_id=368194 feed_name=doc_video_duration_10; feature_id=368208 feed_name=doc_video_url; feature_id=368207 feed_name=fake_context_id; shared=true; feature_id=368230 feed_name=goods_exposure_cnt_lt; feature_id=449545 feed_name=goods_is_prepublic; feature_id=431592 feed_name=goods_op_rec_status; feature_id=368247 feed_name=goods_quality_score; feature_id=431591 feed_name=goods_rec_scene_id; feature_id=408247 feed_name=network; shared=true; feature_id=368228 feed_name=os; shared=true; feature_id=368225 feed_name=os_version; shared=true; feature_id=368226 feed_name=page; shared=true; feature_id=368223 feed_name=platform; shared=true; feature_id=368224 feed_name=province; shared=true; feature_id=368232 feed_name=time; shared=true; feature_id=368236 feed_name=user_age; shared=true; feature_id=368181 feed_name=user_area; shared=true; feature_id=368189 feed_name=user_city; shared=true; feature_id=368187 feed_name=user_country; shared=true; feature_id=368185 feed_name=user_device_id; shared=true; feature_id=368183 feed_name=user_district; shared=true; feature_id=368188 feed_name=user_gender; shared=true; feature_id=368182 feed_name=user_id; shared=true; feature_id=368180 feed_name=user_is_prepublic; shared=true; feature_id=431587 feed_name=user_lt_doc_author_id_cart_cp; shared=true; feature_id=368361 feed_name=user_lt_doc_author_id_click_cp; shared=true; feature_id=368332 feed_name=user_lt_doc_author_id_conversion_cp; shared=true; feature_id=368385 feed_name=user_lt_doc_author_id_favorite_cp; shared=true; 
feature_id=368333 feed_name=user_lt_doc_author_id_praise_cp; shared=true; feature_id=368334 feed_name=user_lt_doc_author_id_query_cp; shared=true; feature_id=368335 feed_name=user_lt_doc_cate1_cart_cp; shared=true; feature_id=368352 feed_name=user_lt_doc_cate1_click_cp; shared=true; feature_id=368320 feed_name=user_lt_doc_cate1_conversion_cp; shared=true; feature_id=368376 feed_name=user_lt_doc_cate1_favorite_cp; shared=true; feature_id=368321 feed_name=user_lt_doc_cate1_praise_cp; shared=true; feature_id=368322 feed_name=user_lt_doc_cate1_query_cp; shared=true; feature_id=368323 feed_name=user_lt_doc_cate2_cart_cp; shared=true; feature_id=368355 feed_name=user_lt_doc_cate2_click_cp; shared=true; feature_id=368324 feed_name=user_lt_doc_cate2_conversion_cp; shared=true; feature_id=368379 feed_name=user_lt_doc_cate2_favorite_cp; shared=true; feature_id=368325 feed_name=user_lt_doc_cate2_praise_cp; shared=true; feature_id=368326 feed_name=user_lt_doc_cate2_query_cp; shared=true; feature_id=368327 feed_name=user_lt_doc_cate3_cart_cp; shared=true; feature_id=368358 feed_name=user_lt_doc_cate3_click_cp; shared=true; feature_id=368328 feed_name=user_lt_doc_cate3_conversion_cp; shared=true; feature_id=368382 feed_name=user_lt_doc_cate3_favorite_cp; shared=true; feature_id=368329 feed_name=user_lt_doc_cate3_praise_cp; shared=true; feature_id=368330 feed_name=user_lt_doc_cate3_query_cp; shared=true; feature_id=368331 feed_name=user_lt_doc_id_cart_cp; shared=true; feature_id=368346 feed_name=user_lt_doc_id_click_cp; shared=true; feature_id=368312 feed_name=user_lt_doc_id_conversion_cp; shared=true; feature_id=368370 feed_name=user_lt_doc_id_favorite_cp; shared=true; feature_id=368313 feed_name=user_lt_doc_id_praise_cp; shared=true; feature_id=368314 feed_name=user_lt_doc_id_query_cp; shared=true; feature_id=368315 feed_name=user_lt_doc_keyword_cart_cp; shared=true; feature_id=368367 feed_name=user_lt_doc_keyword_click_cp; shared=true; feature_id=368340 
feed_name=user_lt_doc_keyword_conversion_cp; shared=true; feature_id=368391 feed_name=user_lt_doc_keyword_favorite_cp; shared=true; feature_id=368341 feed_name=user_lt_doc_keyword_praise_cp; shared=true; feature_id=368342 feed_name=user_lt_doc_keyword_query_cp; shared=true; feature_id=368343 feed_name=user_lt_doc_tags_cart_cp; shared=true; feature_id=368364 feed_name=user_lt_doc_tags_click_cp; shared=true; feature_id=368336 feed_name=user_lt_doc_tags_conversion_cp; shared=true; feature_id=368388 feed_name=user_lt_doc_tags_favorite_cp; shared=true; feature_id=368337 feed_name=user_lt_doc_tags_praise_cp; shared=true; feature_id=368338 feed_name=user_lt_doc_tags_query_cp; shared=true; feature_id=368339 feed_name=user_lt_doc_title_terms_cart_cp; shared=true; feature_id=368349 feed_name=user_lt_doc_title_terms_click_cp; shared=true; feature_id=368316 feed_name=user_lt_doc_title_terms_conversion_cp; shared=true; feature_id=368373 feed_name=user_lt_doc_title_terms_favorite_cp; shared=true; feature_id=368317 feed_name=user_lt_doc_title_terms_praise_cp; shared=true; feature_id=368318 feed_name=user_lt_doc_title_terms_query_cp; shared=true; feature_id=368319 feed_name=user_membership_level; shared=true; feature_id=368184 feed_name=user_province; shared=true; feature_id=368186 feed_name=user_quality_score; shared=true; feature_id=431588 feed_name=user_recent_click_doc_cate1_180d; shared=true; feature_id=368607 feed_name=user_recent_click_doc_cate1_1d; shared=true; feature_id=368604 feed_name=user_recent_click_doc_cate1_1h; shared=true; feature_id=368602 feed_name=user_recent_click_doc_cate1_30d; shared=true; feature_id=368606 feed_name=user_recent_click_doc_cate1_6h; shared=true; feature_id=368603 feed_name=user_recent_click_doc_cate1_7d; shared=true; feature_id=368605 feed_name=user_recent_click_doc_cate2_180d; shared=true; feature_id=368619 feed_name=user_recent_click_doc_cate2_1d; shared=true; feature_id=368616 feed_name=user_recent_click_doc_cate2_1h; shared=true; 
feature_id=368614 feed_name=user_recent_click_doc_cate2_30d; shared=true; feature_id=368618 feed_name=user_recent_click_doc_cate2_6h; shared=true; feature_id=368615 feed_name=user_recent_click_doc_cate2_7d; shared=true; feature_id=368617 feed_name=user_recent_click_doc_cate3_180d; shared=true; feature_id=368613 feed_name=user_recent_click_doc_cate3_1d; shared=true; feature_id=368610 feed_name=user_recent_click_doc_cate3_1h; shared=true; feature_id=368608 feed_name=user_recent_click_doc_cate3_30d; shared=true; feature_id=368612 feed_name=user_recent_click_doc_cate3_6h; shared=true; feature_id=368609 feed_name=user_recent_click_doc_cate3_7d; shared=true; feature_id=368611 feed_name=user_recent_click_doc_id_180d; shared=true; feature_id=368625 feed_name=user_recent_click_doc_id_1d; shared=true; feature_id=368622 feed_name=user_recent_click_doc_id_1h; shared=true; feature_id=368620 feed_name=user_recent_click_doc_id_30d; shared=true; feature_id=368624 feed_name=user_recent_click_doc_id_6h; shared=true; feature_id=368621 feed_name=user_recent_click_doc_id_7d; shared=true; feature_id=368623 feed_name=user_recent_click_doc_keyword_180d; shared=true; feature_id=368601 feed_name=user_recent_click_doc_keyword_1d; shared=true; feature_id=368598 feed_name=user_recent_click_doc_keyword_1h; shared=true; feature_id=368596 feed_name=user_recent_click_doc_keyword_30d; shared=true; feature_id=368600 feed_name=user_recent_click_doc_keyword_6h; shared=true; feature_id=368597 feed_name=user_recent_click_doc_keyword_7d; shared=true; feature_id=368599 feed_name=user_recent_click_doc_tags_180d; shared=true; feature_id=368631 feed_name=user_recent_click_doc_tags_1d; shared=true; feature_id=368628 feed_name=user_recent_click_doc_tags_1h; shared=true; feature_id=368626 feed_name=user_recent_click_doc_tags_30d; shared=true; feature_id=368630 feed_name=user_recent_click_doc_tags_6h; shared=true; feature_id=368627 feed_name=user_recent_click_doc_tags_7d; shared=true; feature_id=368629 
feed_name=user_recent_click_doc_topic_tag_180d; shared=true; feature_id=368595 feed_name=user_recent_click_doc_topic_tag_1d; shared=true; feature_id=368592 feed_name=user_recent_click_doc_topic_tag_1h; shared=true; feature_id=368590 feed_name=user_recent_click_doc_topic_tag_30d; shared=true; feature_id=368594 feed_name=user_recent_click_doc_topic_tag_6h; shared=true; feature_id=368591 feed_name=user_recent_click_doc_topic_tag_7d; shared=true; feature_id=368593 feed_name=user_recent_click_doc_type_180d; shared=true; feature_id=368589 feed_name=user_recent_click_doc_type_1d; shared=true; feature_id=368586 feed_name=user_recent_click_doc_type_1h; shared=true; feature_id=368584 feed_name=user_recent_click_doc_type_30d; shared=true; feature_id=368588 feed_name=user_recent_click_doc_type_6h; shared=true; feature_id=368585 feed_name=user_recent_click_doc_type_7d; shared=true; feature_id=368587 feed_name=user_recent_exposure_doc_cate1_180d; shared=true; feature_id=368703 feed_name=user_recent_exposure_doc_cate1_1d; shared=true; feature_id=368700 feed_name=user_recent_exposure_doc_cate1_1h; shared=true; feature_id=368698 feed_name=user_recent_exposure_doc_cate1_30d; shared=true; feature_id=368702 feed_name=user_recent_exposure_doc_cate1_6h; shared=true; feature_id=368699 feed_name=user_recent_exposure_doc_cate1_7d; shared=true; feature_id=368701 feed_name=user_recent_exposure_doc_cate2_180d; shared=true; feature_id=368715 feed_name=user_recent_exposure_doc_cate2_1d; shared=true; feature_id=368712 feed_name=user_recent_exposure_doc_cate2_1h; shared=true; feature_id=368710 feed_name=user_recent_exposure_doc_cate2_30d; shared=true; feature_id=368714 feed_name=user_recent_exposure_doc_cate2_6h; shared=true; feature_id=368711 feed_name=user_recent_exposure_doc_cate2_7d; shared=true; feature_id=368713 feed_name=user_recent_exposure_doc_cate3_180d; shared=true; feature_id=368709 feed_name=user_recent_exposure_doc_cate3_1d; shared=true; feature_id=368706 
feed_name=user_recent_exposure_doc_cate3_1h; shared=true; feature_id=368704 feed_name=user_recent_exposure_doc_cate3_30d; shared=true; feature_id=368708 feed_name=user_recent_exposure_doc_cate3_6h; shared=true; feature_id=368705 feed_name=user_recent_exposure_doc_cate3_7d; shared=true; feature_id=368707 feed_name=user_recent_exposure_doc_id_180d; shared=true; feature_id=368721 feed_name=user_recent_exposure_doc_id_1d; shared=true; feature_id=368718 feed_name=user_recent_exposure_doc_id_1h; shared=true; feature_id=368716 feed_name=user_recent_exposure_doc_id_30d; shared=true; feature_id=368720 feed_name=user_recent_exposure_doc_id_6h; shared=true; feature_id=368717 feed_name=user_recent_exposure_doc_id_7d; shared=true; feature_id=368719 feed_name=user_recent_exposure_doc_keyword_180d; shared=true; feature_id=368697 feed_name=user_recent_exposure_doc_keyword_1d; shared=true; feature_id=368694 feed_name=user_recent_exposure_doc_keyword_1h; shared=true; feature_id=368692 feed_name=user_recent_exposure_doc_keyword_30d; shared=true; feature_id=368696 feed_name=user_recent_exposure_doc_keyword_6h; shared=true; feature_id=368693 feed_name=user_recent_exposure_doc_keyword_7d; shared=true; feature_id=368695 feed_name=user_recent_exposure_doc_tags_180d; shared=true; feature_id=368727 feed_name=user_recent_exposure_doc_tags_1d; shared=true; feature_id=368724 feed_name=user_recent_exposure_doc_tags_1h; shared=true; feature_id=368722 feed_name=user_recent_exposure_doc_tags_30d; shared=true; feature_id=368726 feed_name=user_recent_exposure_doc_tags_6h; shared=true; feature_id=368723 feed_name=user_recent_exposure_doc_tags_7d; shared=true; feature_id=368725 feed_name=user_recent_exposure_doc_topic_tag_180d; shared=true; feature_id=368691 feed_name=user_recent_exposure_doc_topic_tag_1d; shared=true; feature_id=368688 feed_name=user_recent_exposure_doc_topic_tag_1h; shared=true; feature_id=368686 feed_name=user_recent_exposure_doc_topic_tag_30d; shared=true; feature_id=368690 
feed_name=user_recent_exposure_doc_topic_tag_6h; shared=true; feature_id=368687 feed_name=user_recent_exposure_doc_topic_tag_7d; shared=true; feature_id=368689 feed_name=user_recent_exposure_doc_type_180d; shared=true; feature_id=368685 feed_name=user_recent_exposure_doc_type_1d; shared=true; feature_id=368682 feed_name=user_recent_exposure_doc_type_1h; shared=true; feature_id=368680 feed_name=user_recent_exposure_doc_type_30d; shared=true; feature_id=368684 feed_name=user_recent_exposure_doc_type_6h; shared=true; feature_id=368681 feed_name=user_recent_exposure_doc_type_7d; shared=true; feature_id=368683 feed_name=user_recent_favorite_doc_cate1_180d; shared=true; feature_id=368559 feed_name=user_recent_favorite_doc_cate1_1d; shared=true; feature_id=368556 feed_name=user_recent_favorite_doc_cate1_1h; shared=true; feature_id=368554 feed_name=user_recent_favorite_doc_cate1_30d; shared=true; feature_id=368558 feed_name=user_recent_favorite_doc_cate1_6h; shared=true; feature_id=368555 feed_name=user_recent_favorite_doc_cate1_7d; shared=true; feature_id=368557 feed_name=user_recent_favorite_doc_cate2_180d; shared=true; feature_id=368571 feed_name=user_recent_favorite_doc_cate2_1d; shared=true; feature_id=368568 feed_name=user_recent_favorite_doc_cate2_1h; shared=true; feature_id=368566 feed_name=user_recent_favorite_doc_cate2_30d; shared=true; feature_id=368570 feed_name=user_recent_favorite_doc_cate2_6h; shared=true; feature_id=368567 feed_name=user_recent_favorite_doc_cate2_7d; shared=true; feature_id=368569 feed_name=user_recent_favorite_doc_cate3_180d; shared=true; feature_id=368565 feed_name=user_recent_favorite_doc_cate3_1d; shared=true; feature_id=368562 feed_name=user_recent_favorite_doc_cate3_1h; shared=true; feature_id=368560 feed_name=user_recent_favorite_doc_cate3_30d; shared=true; feature_id=368564 feed_name=user_recent_favorite_doc_cate3_6h; shared=true; feature_id=368561 feed_name=user_recent_favorite_doc_cate3_7d; shared=true; feature_id=368563 
feed_name=user_recent_favorite_doc_id_180d; shared=true; feature_id=368577 feed_name=user_recent_favorite_doc_id_1d; shared=true; feature_id=368574 feed_name=user_recent_favorite_doc_id_1h; shared=true; feature_id=368572 feed_name=user_recent_favorite_doc_id_30d; shared=true; feature_id=368576 feed_name=user_recent_favorite_doc_id_6h; shared=true; feature_id=368573 feed_name=user_recent_favorite_doc_id_7d; shared=true; feature_id=368575 feed_name=user_recent_favorite_doc_keyword_180d; shared=true; feature_id=368553 feed_name=user_recent_favorite_doc_keyword_1d; shared=true; feature_id=368550 feed_name=user_recent_favorite_doc_keyword_1h; shared=true; feature_id=368548 feed_name=user_recent_favorite_doc_keyword_30d; shared=true; feature_id=368552 feed_name=user_recent_favorite_doc_keyword_6h; shared=true; feature_id=368549 feed_name=user_recent_favorite_doc_keyword_7d; shared=true; feature_id=368551 feed_name=user_recent_favorite_doc_tags_180d; shared=true; feature_id=368583 feed_name=user_recent_favorite_doc_tags_1d; shared=true; feature_id=368580 feed_name=user_recent_favorite_doc_tags_1h; shared=true; feature_id=368578 feed_name=user_recent_favorite_doc_tags_30d; shared=true; feature_id=368582 feed_name=user_recent_favorite_doc_tags_6h; shared=true; feature_id=368579 feed_name=user_recent_favorite_doc_tags_7d; shared=true; feature_id=368581 feed_name=user_recent_favorite_doc_topic_tag_180d; shared=true; feature_id=368547 feed_name=user_recent_favorite_doc_topic_tag_1d; shared=true; feature_id=368544 feed_name=user_recent_favorite_doc_topic_tag_1h; shared=true; feature_id=368542 feed_name=user_recent_favorite_doc_topic_tag_30d; shared=true; feature_id=368546 feed_name=user_recent_favorite_doc_topic_tag_6h; shared=true; feature_id=368543 feed_name=user_recent_favorite_doc_topic_tag_7d; shared=true; feature_id=368545 feed_name=user_recent_favorite_doc_type_180d; shared=true; feature_id=368541 feed_name=user_recent_favorite_doc_type_1d; shared=true; feature_id=368538 
feed_name=user_recent_favorite_doc_type_1h; shared=true; feature_id=368536 feed_name=user_recent_favorite_doc_type_30d; shared=true; feature_id=368540 feed_name=user_recent_favorite_doc_type_6h; shared=true; feature_id=368537 feed_name=user_recent_favorite_doc_type_7d; shared=true; feature_id=368539 feed_name=user_recent_praise_doc_cate1_180d; shared=true; feature_id=368655 feed_name=user_recent_praise_doc_cate1_1d; shared=true; feature_id=368652 feed_name=user_recent_praise_doc_cate1_1h; shared=true; feature_id=368650 feed_name=user_recent_praise_doc_cate1_30d; shared=true; feature_id=368654 feed_name=user_recent_praise_doc_cate1_6h; shared=true; feature_id=368651 feed_name=user_recent_praise_doc_cate1_7d; shared=true; feature_id=368653 feed_name=user_recent_praise_doc_cate2_180d; shared=true; feature_id=368667 feed_name=user_recent_praise_doc_cate2_1d; shared=true; feature_id=368664 feed_name=user_recent_praise_doc_cate2_1h; shared=true; feature_id=368662 feed_name=user_recent_praise_doc_cate2_30d; shared=true; feature_id=368666 feed_name=user_recent_praise_doc_cate2_6h; shared=true; feature_id=368663 feed_name=user_recent_praise_doc_cate2_7d; shared=true; feature_id=368665 feed_name=user_recent_praise_doc_cate3_180d; shared=true; feature_id=368661 feed_name=user_recent_praise_doc_cate3_1d; shared=true; feature_id=368658 feed_name=user_recent_praise_doc_cate3_1h; shared=true; feature_id=368656 feed_name=user_recent_praise_doc_cate3_30d; shared=true; feature_id=368660 feed_name=user_recent_praise_doc_cate3_6h; shared=true; feature_id=368657 feed_name=user_recent_praise_doc_cate3_7d; shared=true; feature_id=368659 feed_name=user_recent_praise_doc_id_180d; shared=true; feature_id=368673 feed_name=user_recent_praise_doc_id_1d; shared=true; feature_id=368670 feed_name=user_recent_praise_doc_id_1h; shared=true; feature_id=368668 feed_name=user_recent_praise_doc_id_30d; shared=true; feature_id=368672 feed_name=user_recent_praise_doc_id_6h; shared=true; feature_id=368669 
feed_name=user_recent_praise_doc_id_7d; shared=true; feature_id=368671 feed_name=user_recent_praise_doc_keyword_180d; shared=true; feature_id=368649 feed_name=user_recent_praise_doc_keyword_1d; shared=true; feature_id=368646 feed_name=user_recent_praise_doc_keyword_1h; shared=true; feature_id=368644 feed_name=user_recent_praise_doc_keyword_30d; shared=true; feature_id=368648 feed_name=user_recent_praise_doc_keyword_6h; shared=true; feature_id=368645 feed_name=user_recent_praise_doc_keyword_7d; shared=true; feature_id=368647 feed_name=user_recent_praise_doc_tags_180d; shared=true; feature_id=368679 feed_name=user_recent_praise_doc_tags_1d; shared=true; feature_id=368676 feed_name=user_recent_praise_doc_tags_1h; shared=true; feature_id=368674 feed_name=user_recent_praise_doc_tags_30d; shared=true; feature_id=368678 feed_name=user_recent_praise_doc_tags_6h; shared=true; feature_id=368675 feed_name=user_recent_praise_doc_tags_7d; shared=true; feature_id=368677 feed_name=user_recent_praise_doc_topic_tag_180d; shared=true; feature_id=368643 feed_name=user_recent_praise_doc_topic_tag_1d; shared=true; feature_id=368640 feed_name=user_recent_praise_doc_topic_tag_1h; shared=true; feature_id=368638 feed_name=user_recent_praise_doc_topic_tag_30d; shared=true; feature_id=368642 feed_name=user_recent_praise_doc_topic_tag_6h; shared=true; feature_id=368639 feed_name=user_recent_praise_doc_topic_tag_7d; shared=true; feature_id=368641 feed_name=user_recent_praise_doc_type_180d; shared=true; feature_id=368637 feed_name=user_recent_praise_doc_type_1d; shared=true; feature_id=368634 feed_name=user_recent_praise_doc_type_1h; shared=true; feature_id=368632 feed_name=user_recent_praise_doc_type_30d; shared=true; feature_id=368636 feed_name=user_recent_praise_doc_type_6h; shared=true; feature_id=368633 feed_name=user_recent_praise_doc_type_7d; shared=true; feature_id=368635 feed_name=user_recent_share_doc_cate1_180d; shared=true; feature_id=368511 feed_name=user_recent_share_doc_cate1_1d; 
shared=true; feature_id=368508 feed_name=user_recent_share_doc_cate1_1h; shared=true; feature_id=368506 feed_name=user_recent_share_doc_cate1_30d; shared=true; feature_id=368510 feed_name=user_recent_share_doc_cate1_6h; shared=true; feature_id=368507 feed_name=user_recent_share_doc_cate1_7d; shared=true; feature_id=368509 feed_name=user_recent_share_doc_cate2_180d; shared=true; feature_id=368523 feed_name=user_recent_share_doc_cate2_1d; shared=true; feature_id=368520 feed_name=user_recent_share_doc_cate2_1h; shared=true; feature_id=368518 feed_name=user_recent_share_doc_cate2_30d; shared=true; feature_id=368522 feed_name=user_recent_share_doc_cate2_6h; shared=true; feature_id=368519 feed_name=user_recent_share_doc_cate2_7d; shared=true; feature_id=368521 feed_name=user_recent_share_doc_cate3_180d; shared=true; feature_id=368517 feed_name=user_recent_share_doc_cate3_1d; shared=true; feature_id=368514 feed_name=user_recent_share_doc_cate3_1h; shared=true; feature_id=368512 feed_name=user_recent_share_doc_cate3_30d; shared=true; feature_id=368516 feed_name=user_recent_share_doc_cate3_6h; shared=true; feature_id=368513 feed_name=user_recent_share_doc_cate3_7d; shared=true; feature_id=368515 feed_name=user_recent_share_doc_id_180d; shared=true; feature_id=368529 feed_name=user_recent_share_doc_id_1d; shared=true; feature_id=368526 feed_name=user_recent_share_doc_id_1h; shared=true; feature_id=368524 feed_name=user_recent_share_doc_id_30d; shared=true; feature_id=368528 feed_name=user_recent_share_doc_id_6h; shared=true; feature_id=368525 feed_name=user_recent_share_doc_id_7d; shared=true; feature_id=368527 feed_name=user_recent_share_doc_keyword_180d; shared=true; feature_id=368505 feed_name=user_recent_share_doc_keyword_1d; shared=true; feature_id=368502 feed_name=user_recent_share_doc_keyword_1h; shared=true; feature_id=368500 feed_name=user_recent_share_doc_keyword_30d; shared=true; feature_id=368504 feed_name=user_recent_share_doc_keyword_6h; shared=true; 
feature_id=368501 feed_name=user_recent_share_doc_keyword_7d; shared=true; feature_id=368503 feed_name=user_recent_share_doc_tags_180d; shared=true; feature_id=368535 feed_name=user_recent_share_doc_tags_1d; shared=true; feature_id=368532 feed_name=user_recent_share_doc_tags_1h; shared=true; feature_id=368530 feed_name=user_recent_share_doc_tags_30d; shared=true; feature_id=368534 feed_name=user_recent_share_doc_tags_6h; shared=true; feature_id=368531 feed_name=user_recent_share_doc_tags_7d; shared=true; feature_id=368533 feed_name=user_recent_share_doc_topic_tag_180d; shared=true; feature_id=368499 feed_name=user_recent_share_doc_topic_tag_1d; shared=true; feature_id=368496 feed_name=user_recent_share_doc_topic_tag_1h; shared=true; feature_id=368494 feed_name=user_recent_share_doc_topic_tag_30d; shared=true; feature_id=368498 feed_name=user_recent_share_doc_topic_tag_6h; shared=true; feature_id=368495 feed_name=user_recent_share_doc_topic_tag_7d; shared=true; feature_id=368497 feed_name=user_recent_share_doc_type_180d; shared=true; feature_id=368493 feed_name=user_recent_share_doc_type_1d; shared=true; feature_id=368490 feed_name=user_recent_share_doc_type_1h; shared=true; feature_id=368488 feed_name=user_recent_share_doc_type_30d; shared=true; feature_id=368492 feed_name=user_recent_share_doc_type_6h; shared=true; feature_id=368489 feed_name=user_recent_share_doc_type_7d; shared=true; feature_id=368491 feed_name=user_register_time; shared=true; feature_id=368191 feed_name=user_st_1d_doc_author_id_cart_cp; shared=true; feature_id=368359 feed_name=user_st_1d_doc_author_id_cart_recent; shared=true; feature_id=368466 feed_name=user_st_1d_doc_author_id_click_cp; shared=true; feature_id=368268 feed_name=user_st_1d_doc_author_id_click_recent; shared=true; feature_id=368412 feed_name=user_st_1d_doc_author_id_conversion_cp; shared=true; feature_id=368383 feed_name=user_st_1d_doc_author_id_conversion_recent; shared=true; feature_id=368482 
feed_name=user_st_1d_doc_author_id_favorite_cp; shared=true; feature_id=368269 feed_name=user_st_1d_doc_author_id_favorite_recent; shared=true; feature_id=368413 feed_name=user_st_1d_doc_author_id_praise_cp; shared=true; feature_id=368270 feed_name=user_st_1d_doc_author_id_praise_recent; shared=true; feature_id=368414 feed_name=user_st_1d_doc_author_id_query_cp; shared=true; feature_id=368271 feed_name=user_st_1d_doc_author_id_query_recent; shared=true; feature_id=368415 feed_name=user_st_1d_doc_cate1_cart_cp; shared=true; feature_id=368350 feed_name=user_st_1d_doc_cate1_cart_recent; shared=true; feature_id=368460 feed_name=user_st_1d_doc_cate1_click_cp; shared=true; feature_id=368256 feed_name=user_st_1d_doc_cate1_click_recent; shared=true; feature_id=368400 feed_name=user_st_1d_doc_cate1_conversion_cp; shared=true; feature_id=368374 feed_name=user_st_1d_doc_cate1_conversion_recent; shared=true; feature_id=368476 feed_name=user_st_1d_doc_cate1_favorite_cp; shared=true; feature_id=368257 feed_name=user_st_1d_doc_cate1_favorite_recent; shared=true; feature_id=368401 feed_name=user_st_1d_doc_cate1_praise_cp; shared=true; feature_id=368258 feed_name=user_st_1d_doc_cate1_praise_recent; shared=true; feature_id=368402 feed_name=user_st_1d_doc_cate1_query_cp; shared=true; feature_id=368259 feed_name=user_st_1d_doc_cate1_query_recent; shared=true; feature_id=368403 feed_name=user_st_1d_doc_cate2_cart_cp; shared=true; feature_id=368353 feed_name=user_st_1d_doc_cate2_cart_recent; shared=true; feature_id=368462 feed_name=user_st_1d_doc_cate2_click_cp; shared=true; feature_id=368260 feed_name=user_st_1d_doc_cate2_click_recent; shared=true; feature_id=368404 feed_name=user_st_1d_doc_cate2_conversion_cp; shared=true; feature_id=368377 feed_name=user_st_1d_doc_cate2_conversion_recent; shared=true; feature_id=368478 feed_name=user_st_1d_doc_cate2_favorite_cp; shared=true; feature_id=368261 feed_name=user_st_1d_doc_cate2_favorite_recent; shared=true; feature_id=368405 
feed_name=user_st_1d_doc_cate2_praise_cp; shared=true; feature_id=368262 feed_name=user_st_1d_doc_cate2_praise_recent; shared=true; feature_id=368406 feed_name=user_st_1d_doc_cate2_query_cp; shared=true; feature_id=368263 feed_name=user_st_1d_doc_cate2_query_recent; shared=true; feature_id=368407 feed_name=user_st_1d_doc_cate3_cart_cp; shared=true; feature_id=368356 feed_name=user_st_1d_doc_cate3_cart_recent; shared=true; feature_id=368464 feed_name=user_st_1d_doc_cate3_click_cp; shared=true; feature_id=368264 feed_name=user_st_1d_doc_cate3_click_recent; shared=true; feature_id=368408 feed_name=user_st_1d_doc_cate3_conversion_cp; shared=true; feature_id=368380 feed_name=user_st_1d_doc_cate3_conversion_recent; shared=true; feature_id=368480 feed_name=user_st_1d_doc_cate3_favorite_cp; shared=true; feature_id=368265 feed_name=user_st_1d_doc_cate3_favorite_recent; shared=true; feature_id=368409 feed_name=user_st_1d_doc_cate3_praise_cp; shared=true; feature_id=368266 feed_name=user_st_1d_doc_cate3_praise_recent; shared=true; feature_id=368410 feed_name=user_st_1d_doc_cate3_query_cp; shared=true; feature_id=368267 feed_name=user_st_1d_doc_cate3_query_recent; shared=true; feature_id=368411 feed_name=user_st_1d_doc_id_cart_cp; shared=true; feature_id=368344 feed_name=user_st_1d_doc_id_cart_recent; shared=true; feature_id=368456 feed_name=user_st_1d_doc_id_click_cp; shared=true; feature_id=368248 feed_name=user_st_1d_doc_id_click_recent; shared=true; feature_id=368392 feed_name=user_st_1d_doc_id_conversion_cp; shared=true; feature_id=368368 feed_name=user_st_1d_doc_id_conversion_recent; shared=true; feature_id=368472 feed_name=user_st_1d_doc_id_favorite_cp; shared=true; feature_id=368249 feed_name=user_st_1d_doc_id_favorite_recent; shared=true; feature_id=368393 feed_name=user_st_1d_doc_id_praise_cp; shared=true; feature_id=368250 feed_name=user_st_1d_doc_id_praise_recent; shared=true; feature_id=368394 feed_name=user_st_1d_doc_id_query_cp; shared=true; feature_id=368251 
feed_name=user_st_1d_doc_id_query_recent; shared=true; feature_id=368395 feed_name=user_st_1d_doc_keyword_cart_cp; shared=true; feature_id=368365 feed_name=user_st_1d_doc_keyword_cart_recent; shared=true; feature_id=368470 feed_name=user_st_1d_doc_keyword_click_cp; shared=true; feature_id=368276 feed_name=user_st_1d_doc_keyword_click_recent; shared=true; feature_id=368420 feed_name=user_st_1d_doc_keyword_conversion_cp; shared=true; feature_id=368389 feed_name=user_st_1d_doc_keyword_conversion_recent; shared=true; feature_id=368486 feed_name=user_st_1d_doc_keyword_favorite_cp; shared=true; feature_id=368277 feed_name=user_st_1d_doc_keyword_favorite_recent; shared=true; feature_id=368421 feed_name=user_st_1d_doc_keyword_praise_cp; shared=true; feature_id=368278 feed_name=user_st_1d_doc_keyword_praise_recent; shared=true; feature_id=368422 feed_name=user_st_1d_doc_keyword_query_cp; shared=true; feature_id=368279 feed_name=user_st_1d_doc_keyword_query_recent; shared=true; feature_id=368423 feed_name=user_st_1d_doc_tags_cart_cp; shared=true; feature_id=368362 feed_name=user_st_1d_doc_tags_cart_recent; shared=true; feature_id=368468 feed_name=user_st_1d_doc_tags_click_cp; shared=true; feature_id=368272 feed_name=user_st_1d_doc_tags_click_recent; shared=true; feature_id=368416 feed_name=user_st_1d_doc_tags_conversion_cp; shared=true; feature_id=368386 feed_name=user_st_1d_doc_tags_conversion_recent; shared=true; feature_id=368484 feed_name=user_st_1d_doc_tags_favorite_cp; shared=true; feature_id=368273 feed_name=user_st_1d_doc_tags_favorite_recent; shared=true; feature_id=368417 feed_name=user_st_1d_doc_tags_praise_cp; shared=true; feature_id=368274 feed_name=user_st_1d_doc_tags_praise_recent; shared=true; feature_id=368418 feed_name=user_st_1d_doc_tags_query_cp; shared=true; feature_id=368275 feed_name=user_st_1d_doc_tags_query_recent; shared=true; feature_id=368419 feed_name=user_st_1d_doc_title_terms_cart_cp; shared=true; feature_id=368347 
feed_name=user_st_1d_doc_title_terms_cart_recent; shared=true; feature_id=368458 feed_name=user_st_1d_doc_title_terms_click_cp; shared=true; feature_id=368252 feed_name=user_st_1d_doc_title_terms_click_recent; shared=true; feature_id=368396 feed_name=user_st_1d_doc_title_terms_conversion_cp; shared=true; feature_id=368371 feed_name=user_st_1d_doc_title_terms_conversion_recent; shared=true; feature_id=368474 feed_name=user_st_1d_doc_title_terms_favorite_cp; shared=true; feature_id=368253 feed_name=user_st_1d_doc_title_terms_favorite_recent; shared=true; feature_id=368397 feed_name=user_st_1d_doc_title_terms_praise_cp; shared=true; feature_id=368254 feed_name=user_st_1d_doc_title_terms_praise_recent; shared=true; feature_id=368398 feed_name=user_st_1d_doc_title_terms_query_cp; shared=true; feature_id=368255 feed_name=user_st_1d_doc_title_terms_query_recent; shared=true; feature_id=368399 feed_name=user_st_7d_doc_author_id_cart_cp; shared=true; feature_id=368360 feed_name=user_st_7d_doc_author_id_cart_recent; shared=true; feature_id=368467 feed_name=user_st_7d_doc_author_id_click_cp; shared=true; feature_id=368300 feed_name=user_st_7d_doc_author_id_click_recent; shared=true; feature_id=368444 feed_name=user_st_7d_doc_author_id_conversion_cp; shared=true; feature_id=368384 feed_name=user_st_7d_doc_author_id_conversion_recent; shared=true; feature_id=368483 feed_name=user_st_7d_doc_author_id_favorite_cp; shared=true; feature_id=368301 feed_name=user_st_7d_doc_author_id_favorite_recent; shared=true; feature_id=368445 feed_name=user_st_7d_doc_author_id_praise_cp; shared=true; feature_id=368302 feed_name=user_st_7d_doc_author_id_praise_recent; shared=true; feature_id=368446 feed_name=user_st_7d_doc_author_id_query_cp; shared=true; feature_id=368303 feed_name=user_st_7d_doc_author_id_query_recent; shared=true; feature_id=368447 feed_name=user_st_7d_doc_cate1_cart_cp; shared=true; feature_id=368351 feed_name=user_st_7d_doc_cate1_cart_recent; shared=true; feature_id=368461 
feed_name=user_st_7d_doc_cate1_click_cp; shared=true; feature_id=368288 feed_name=user_st_7d_doc_cate1_click_recent; shared=true; feature_id=368432 feed_name=user_st_7d_doc_cate1_conversion_cp; shared=true; feature_id=368375 feed_name=user_st_7d_doc_cate1_conversion_recent; shared=true; feature_id=368477 feed_name=user_st_7d_doc_cate1_favorite_cp; shared=true; feature_id=368289 feed_name=user_st_7d_doc_cate1_favorite_recent; shared=true; feature_id=368433 feed_name=user_st_7d_doc_cate1_praise_cp; shared=true; feature_id=368290 feed_name=user_st_7d_doc_cate1_praise_recent; shared=true; feature_id=368434 feed_name=user_st_7d_doc_cate1_query_cp; shared=true; feature_id=368291 feed_name=user_st_7d_doc_cate1_query_recent; shared=true; feature_id=368435 feed_name=user_st_7d_doc_cate2_cart_cp; shared=true; feature_id=368354 feed_name=user_st_7d_doc_cate2_cart_recent; shared=true; feature_id=368463 feed_name=user_st_7d_doc_cate2_click_cp; shared=true; feature_id=368292 feed_name=user_st_7d_doc_cate2_click_recent; shared=true; feature_id=368436 feed_name=user_st_7d_doc_cate2_conversion_cp; shared=true; feature_id=368378 feed_name=user_st_7d_doc_cate2_conversion_recent; shared=true; feature_id=368479 feed_name=user_st_7d_doc_cate2_favorite_cp; shared=true; feature_id=368293 feed_name=user_st_7d_doc_cate2_favorite_recent; shared=true; feature_id=368437 feed_name=user_st_7d_doc_cate2_praise_cp; shared=true; feature_id=368294 feed_name=user_st_7d_doc_cate2_praise_recent; shared=true; feature_id=368438 feed_name=user_st_7d_doc_cate2_query_cp; shared=true; feature_id=368295 feed_name=user_st_7d_doc_cate2_query_recent; shared=true; feature_id=368439 feed_name=user_st_7d_doc_cate3_cart_cp; shared=true; feature_id=368357 feed_name=user_st_7d_doc_cate3_cart_recent; shared=true; feature_id=368465 feed_name=user_st_7d_doc_cate3_click_cp; shared=true; feature_id=368296 feed_name=user_st_7d_doc_cate3_click_recent; shared=true; feature_id=368440 
feed_name=user_st_7d_doc_cate3_conversion_cp; shared=true; feature_id=368381 feed_name=user_st_7d_doc_cate3_conversion_recent; shared=true; feature_id=368481 feed_name=user_st_7d_doc_cate3_favorite_cp; shared=true; feature_id=368297 feed_name=user_st_7d_doc_cate3_favorite_recent; shared=true; feature_id=368441 feed_name=user_st_7d_doc_cate3_praise_cp; shared=true; feature_id=368298 feed_name=user_st_7d_doc_cate3_praise_recent; shared=true; feature_id=368442 feed_name=user_st_7d_doc_cate3_query_cp; shared=true; feature_id=368299 feed_name=user_st_7d_doc_cate3_query_recent; shared=true; feature_id=368443 feed_name=user_st_7d_doc_id_cart_cp; shared=true; feature_id=368345 feed_name=user_st_7d_doc_id_cart_recent; shared=true; feature_id=368457 feed_name=user_st_7d_doc_id_click_cp; shared=true; feature_id=368280 feed_name=user_st_7d_doc_id_click_recent; shared=true; feature_id=368424 feed_name=user_st_7d_doc_id_conversion_cp; shared=true; feature_id=368369 feed_name=user_st_7d_doc_id_conversion_recent; shared=true; feature_id=368473 feed_name=user_st_7d_doc_id_favorite_cp; shared=true; feature_id=368281 feed_name=user_st_7d_doc_id_favorite_recent; shared=true; feature_id=368425 feed_name=user_st_7d_doc_id_praise_cp; shared=true; feature_id=368282 feed_name=user_st_7d_doc_id_praise_recent; shared=true; feature_id=368426 feed_name=user_st_7d_doc_id_query_cp; shared=true; feature_id=368283 feed_name=user_st_7d_doc_id_query_recent; shared=true; feature_id=368427 feed_name=user_st_7d_doc_keyword_cart_cp; shared=true; feature_id=368366 feed_name=user_st_7d_doc_keyword_cart_recent; shared=true; feature_id=368471 feed_name=user_st_7d_doc_keyword_click_cp; shared=true; feature_id=368308 feed_name=user_st_7d_doc_keyword_click_recent; shared=true; feature_id=368452 feed_name=user_st_7d_doc_keyword_conversion_cp; shared=true; feature_id=368390 feed_name=user_st_7d_doc_keyword_conversion_recent; shared=true; feature_id=368487 feed_name=user_st_7d_doc_keyword_favorite_cp; 
shared=true; feature_id=368309 feed_name=user_st_7d_doc_keyword_favorite_recent; shared=true; feature_id=368453 feed_name=user_st_7d_doc_keyword_praise_cp; shared=true; feature_id=368310 feed_name=user_st_7d_doc_keyword_praise_recent; shared=true; feature_id=368454 feed_name=user_st_7d_doc_keyword_query_cp; shared=true; feature_id=368311 feed_name=user_st_7d_doc_keyword_query_recent; shared=true; feature_id=368455 feed_name=user_st_7d_doc_tags_cart_cp; shared=true; feature_id=368363 feed_name=user_st_7d_doc_tags_cart_recent; shared=true; feature_id=368469 feed_name=user_st_7d_doc_tags_click_cp; shared=true; feature_id=368304 feed_name=user_st_7d_doc_tags_click_recent; shared=true; feature_id=368448 feed_name=user_st_7d_doc_tags_conversion_cp; shared=true; feature_id=368387 feed_name=user_st_7d_doc_tags_conversion_recent; shared=true; feature_id=368485 feed_name=user_st_7d_doc_tags_favorite_cp; shared=true; feature_id=368305 feed_name=user_st_7d_doc_tags_favorite_recent; shared=true; feature_id=368449 feed_name=user_st_7d_doc_tags_praise_cp; shared=true; feature_id=368306 feed_name=user_st_7d_doc_tags_praise_recent; shared=true; feature_id=368450 feed_name=user_st_7d_doc_tags_query_cp; shared=true; feature_id=368307 feed_name=user_st_7d_doc_tags_query_recent; shared=true; feature_id=368451 feed_name=user_st_7d_doc_title_terms_cart_cp; shared=true; feature_id=368348 feed_name=user_st_7d_doc_title_terms_cart_recent; shared=true; feature_id=368459 feed_name=user_st_7d_doc_title_terms_click_cp; shared=true; feature_id=368284 feed_name=user_st_7d_doc_title_terms_click_recent; shared=true; feature_id=368428 feed_name=user_st_7d_doc_title_terms_conversion_cp; shared=true; feature_id=368372 feed_name=user_st_7d_doc_title_terms_conversion_recent; shared=true; feature_id=368475 feed_name=user_st_7d_doc_title_terms_favorite_cp; shared=true; feature_id=368285 feed_name=user_st_7d_doc_title_terms_favorite_recent; shared=true; feature_id=368429 
feed_name=user_st_7d_doc_title_terms_praise_cp; shared=true; feature_id=368286 feed_name=user_st_7d_doc_title_terms_praise_recent; shared=true; feature_id=368430 feed_name=user_st_7d_doc_title_terms_query_cp; shared=true; feature_id=368287 feed_name=user_st_7d_doc_title_terms_query_recent; shared=true; feature_id=368431 feed_name=user_tags_list; shared=true; feature_id=368190 column_name : area, att_traced, bhv_scm, bhv_spm, bhv_spm_1, bhv_spm_2, bhv_spm_3, bhv_spm_4, bhv_time_hour, bhv_time_monthday, bhv_time_weekday, city, client_version, country, device_model, district, doc_author_fans_10, doc_author_id, doc_author_level, doc_author_name, doc_cate1, doc_cate2, doc_cate3, doc_collect_cnt_10, doc_collection, doc_comment_cnt_10, doc_content_length_2, doc_create_time, doc_detail_pic_num, doc_expire_time, doc_id, doc_id_post_click_180d, doc_id_post_click_1d, doc_id_post_click_1h, doc_id_post_click_30d, doc_id_post_click_6h, doc_id_post_click_7d, doc_id_post_favorite_180d, doc_id_post_favorite_1d, doc_id_post_favorite_1h, doc_id_post_favorite_30d, doc_id_post_favorite_6h, doc_id_post_favorite_7d, doc_id_post_praise_180d, doc_id_post_praise_1d, doc_id_post_praise_1h, doc_id_post_praise_30d, doc_id_post_praise_6h, doc_id_post_praise_7d, doc_id_post_share_180d, doc_id_post_share_1d, doc_id_post_share_1h, doc_id_post_share_30d, doc_id_post_share_6h, doc_id_post_share_7d, doc_keyword, doc_location_tag, doc_pic_url, doc_praise_cnt_10, doc_pub_time, doc_rating, doc_related_goods_ids, doc_share_cnt_10, doc_source_id, doc_tags, doc_title_length, doc_title_terms, doc_topic_tag, doc_type, doc_video_duration_10, doc_video_url, fake_context_id, goods_exposure_cnt_lt, goods_is_prepublic, goods_op_rec_status, goods_quality_score, goods_rec_scene_id, network, os, os_version, page, platform, province, time, user_age, user_area, user_city, user_country, user_device_id, user_district, user_gender, user_id, user_is_prepublic, user_lt_doc_author_id_cart_cp, user_lt_doc_author_id_click_cp, 
user_lt_doc_author_id_conversion_cp, user_lt_doc_author_id_favorite_cp, user_lt_doc_author_id_praise_cp, user_lt_doc_author_id_query_cp, user_lt_doc_cate1_cart_cp, user_lt_doc_cate1_click_cp, user_lt_doc_cate1_conversion_cp, user_lt_doc_cate1_favorite_cp, user_lt_doc_cate1_praise_cp, user_lt_doc_cate1_query_cp, user_lt_doc_cate2_cart_cp, user_lt_doc_cate2_click_cp, user_lt_doc_cate2_conversion_cp, user_lt_doc_cate2_favorite_cp, user_lt_doc_cate2_praise_cp, user_lt_doc_cate2_query_cp, user_lt_doc_cate3_cart_cp, user_lt_doc_cate3_click_cp, user_lt_doc_cate3_conversion_cp, user_lt_doc_cate3_favorite_cp, user_lt_doc_cate3_praise_cp, user_lt_doc_cate3_query_cp, user_lt_doc_id_cart_cp, user_lt_doc_id_click_cp, user_lt_doc_id_conversion_cp, user_lt_doc_id_favorite_cp, user_lt_doc_id_praise_cp, user_lt_doc_id_query_cp, user_lt_doc_keyword_cart_cp, user_lt_doc_keyword_click_cp, user_lt_doc_keyword_conversion_cp, user_lt_doc_keyword_favorite_cp, user_lt_doc_keyword_praise_cp, user_lt_doc_keyword_query_cp, user_lt_doc_tags_cart_cp, user_lt_doc_tags_click_cp, user_lt_doc_tags_conversion_cp, user_lt_doc_tags_favorite_cp, user_lt_doc_tags_praise_cp, user_lt_doc_tags_query_cp, user_lt_doc_title_terms_cart_cp, user_lt_doc_title_terms_click_cp, user_lt_doc_title_terms_conversion_cp, user_lt_doc_title_terms_favorite_cp, user_lt_doc_title_terms_praise_cp, user_lt_doc_title_terms_query_cp, user_membership_level, user_province, user_quality_score, user_recent_click_doc_cate1_180d, user_recent_click_doc_cate1_1d, user_recent_click_doc_cate1_1h, user_recent_click_doc_cate1_30d, user_recent_click_doc_cate1_6h, user_recent_click_doc_cate1_7d, user_recent_click_doc_cate2_180d, user_recent_click_doc_cate2_1d, user_recent_click_doc_cate2_1h, user_recent_click_doc_cate2_30d, user_recent_click_doc_cate2_6h, user_recent_click_doc_cate2_7d, user_recent_click_doc_cate3_180d, user_recent_click_doc_cate3_1d, user_recent_click_doc_cate3_1h, user_recent_click_doc_cate3_30d, 
user_recent_click_doc_cate3_6h, user_recent_click_doc_cate3_7d, user_recent_click_doc_id_180d, user_recent_click_doc_id_1d, user_recent_click_doc_id_1h, user_recent_click_doc_id_30d, user_recent_click_doc_id_6h, user_recent_click_doc_id_7d, user_recent_click_doc_keyword_180d, user_recent_click_doc_keyword_1d, user_recent_click_doc_keyword_1h, user_recent_click_doc_keyword_30d, user_recent_click_doc_keyword_6h, user_recent_click_doc_keyword_7d, user_recent_click_doc_tags_180d, user_recent_click_doc_tags_1d, user_recent_click_doc_tags_1h, user_recent_click_doc_tags_30d, user_recent_click_doc_tags_6h, user_recent_click_doc_tags_7d, user_recent_click_doc_topic_tag_180d, user_recent_click_doc_topic_tag_1d, user_recent_click_doc_topic_tag_1h, user_recent_click_doc_topic_tag_30d, user_recent_click_doc_topic_tag_6h, user_recent_click_doc_topic_tag_7d, user_recent_click_doc_type_180d, user_recent_click_doc_type_1d, user_recent_click_doc_type_1h, user_recent_click_doc_type_30d, user_recent_click_doc_type_6h, user_recent_click_doc_type_7d, user_recent_exposure_doc_cate1_180d, user_recent_exposure_doc_cate1_1d, user_recent_exposure_doc_cate1_1h, user_recent_exposure_doc_cate1_30d, user_recent_exposure_doc_cate1_6h, user_recent_exposure_doc_cate1_7d, user_recent_exposure_doc_cate2_180d, user_recent_exposure_doc_cate2_1d, user_recent_exposure_doc_cate2_1h, user_recent_exposure_doc_cate2_30d, user_recent_exposure_doc_cate2_6h, user_recent_exposure_doc_cate2_7d, user_recent_exposure_doc_cate3_180d, user_recent_exposure_doc_cate3_1d, user_recent_exposure_doc_cate3_1h, user_recent_exposure_doc_cate3_30d, user_recent_exposure_doc_cate3_6h, user_recent_exposure_doc_cate3_7d, user_recent_exposure_doc_id_180d, user_recent_exposure_doc_id_1d, user_recent_exposure_doc_id_1h, user_recent_exposure_doc_id_30d, user_recent_exposure_doc_id_6h, user_recent_exposure_doc_id_7d, user_recent_exposure_doc_keyword_180d, user_recent_exposure_doc_keyword_1d, user_recent_exposure_doc_keyword_1h, 
user_recent_exposure_doc_keyword_30d, user_recent_exposure_doc_keyword_6h, user_recent_exposure_doc_keyword_7d, user_recent_exposure_doc_tags_180d, user_recent_exposure_doc_tags_1d, user_recent_exposure_doc_tags_1h, user_recent_exposure_doc_tags_30d, user_recent_exposure_doc_tags_6h, user_recent_exposure_doc_tags_7d, user_recent_exposure_doc_topic_tag_180d, user_recent_exposure_doc_topic_tag_1d, user_recent_exposure_doc_topic_tag_1h, user_recent_exposure_doc_topic_tag_30d, user_recent_exposure_doc_topic_tag_6h, user_recent_exposure_doc_topic_tag_7d, user_recent_exposure_doc_type_180d, user_recent_exposure_doc_type_1d, user_recent_exposure_doc_type_1h, user_recent_exposure_doc_type_30d, user_recent_exposure_doc_type_6h, user_recent_exposure_doc_type_7d, user_recent_favorite_doc_cate1_180d, user_recent_favorite_doc_cate1_1d, user_recent_favorite_doc_cate1_1h, user_recent_favorite_doc_cate1_30d, user_recent_favorite_doc_cate1_6h, user_recent_favorite_doc_cate1_7d, user_recent_favorite_doc_cate2_180d, user_recent_favorite_doc_cate2_1d, user_recent_favorite_doc_cate2_1h, user_recent_favorite_doc_cate2_30d, user_recent_favorite_doc_cate2_6h, user_recent_favorite_doc_cate2_7d, user_recent_favorite_doc_cate3_180d, user_recent_favorite_doc_cate3_1d, user_recent_favorite_doc_cate3_1h, user_recent_favorite_doc_cate3_30d, user_recent_favorite_doc_cate3_6h, user_recent_favorite_doc_cate3_7d, user_recent_favorite_doc_id_180d, user_recent_favorite_doc_id_1d, user_recent_favorite_doc_id_1h, user_recent_favorite_doc_id_30d, user_recent_favorite_doc_id_6h, user_recent_favorite_doc_id_7d, user_recent_favorite_doc_keyword_180d, user_recent_favorite_doc_keyword_1d, user_recent_favorite_doc_keyword_1h, user_recent_favorite_doc_keyword_30d, user_recent_favorite_doc_keyword_6h, user_recent_favorite_doc_keyword_7d, user_recent_favorite_doc_tags_180d, user_recent_favorite_doc_tags_1d, user_recent_favorite_doc_tags_1h, user_recent_favorite_doc_tags_30d, user_recent_favorite_doc_tags_6h, 
user_recent_favorite_doc_tags_7d, user_recent_favorite_doc_topic_tag_180d, user_recent_favorite_doc_topic_tag_1d, user_recent_favorite_doc_topic_tag_1h, user_recent_favorite_doc_topic_tag_30d, user_recent_favorite_doc_topic_tag_6h, user_recent_favorite_doc_topic_tag_7d, user_recent_favorite_doc_type_180d, user_recent_favorite_doc_type_1d, user_recent_favorite_doc_type_1h, user_recent_favorite_doc_type_30d, user_recent_favorite_doc_type_6h, user_recent_favorite_doc_type_7d, user_recent_praise_doc_cate1_180d, user_recent_praise_doc_cate1_1d, user_recent_praise_doc_cate1_1h, user_recent_praise_doc_cate1_30d, user_recent_praise_doc_cate1_6h, user_recent_praise_doc_cate1_7d, user_recent_praise_doc_cate2_180d, user_recent_praise_doc_cate2_1d, user_recent_praise_doc_cate2_1h, user_recent_praise_doc_cate2_30d, user_recent_praise_doc_cate2_6h, user_recent_praise_doc_cate2_7d, user_recent_praise_doc_cate3_180d, user_recent_praise_doc_cate3_1d, user_recent_praise_doc_cate3_1h, user_recent_praise_doc_cate3_30d, user_recent_praise_doc_cate3_6h, user_recent_praise_doc_cate3_7d, user_recent_praise_doc_id_180d, user_recent_praise_doc_id_1d, user_recent_praise_doc_id_1h, user_recent_praise_doc_id_30d, user_recent_praise_doc_id_6h, user_recent_praise_doc_id_7d, user_recent_praise_doc_keyword_180d, user_recent_praise_doc_keyword_1d, user_recent_praise_doc_keyword_1h, user_recent_praise_doc_keyword_30d, user_recent_praise_doc_keyword_6h, user_recent_praise_doc_keyword_7d, user_recent_praise_doc_tags_180d, user_recent_praise_doc_tags_1d, user_recent_praise_doc_tags_1h, user_recent_praise_doc_tags_30d, user_recent_praise_doc_tags_6h, user_recent_praise_doc_tags_7d, user_recent_praise_doc_topic_tag_180d, user_recent_praise_doc_topic_tag_1d, user_recent_praise_doc_topic_tag_1h, user_recent_praise_doc_topic_tag_30d, user_recent_praise_doc_topic_tag_6h, user_recent_praise_doc_topic_tag_7d, user_recent_praise_doc_type_180d, user_recent_praise_doc_type_1d, user_recent_praise_doc_type_1h, 
user_recent_praise_doc_type_30d, user_recent_praise_doc_type_6h, user_recent_praise_doc_type_7d, user_recent_share_doc_cate1_180d, user_recent_share_doc_cate1_1d, user_recent_share_doc_cate1_1h, user_recent_share_doc_cate1_30d, user_recent_share_doc_cate1_6h, user_recent_share_doc_cate1_7d, user_recent_share_doc_cate2_180d, user_recent_share_doc_cate2_1d, user_recent_share_doc_cate2_1h, user_recent_share_doc_cate2_30d, user_recent_share_doc_cate2_6h, user_recent_share_doc_cate2_7d, user_recent_share_doc_cate3_180d, user_recent_share_doc_cate3_1d, user_recent_share_doc_cate3_1h, user_recent_share_doc_cate3_30d, user_recent_share_doc_cate3_6h, user_recent_share_doc_cate3_7d, user_recent_share_doc_id_180d, user_recent_share_doc_id_1d, user_recent_share_doc_id_1h, user_recent_share_doc_id_30d, user_recent_share_doc_id_6h, user_recent_share_doc_id_7d, user_recent_share_doc_keyword_180d, user_recent_share_doc_keyword_1d, user_recent_share_doc_keyword_1h, user_recent_share_doc_keyword_30d, user_recent_share_doc_keyword_6h, user_recent_share_doc_keyword_7d, user_recent_share_doc_tags_180d, user_recent_share_doc_tags_1d, user_recent_share_doc_tags_1h, user_recent_share_doc_tags_30d, user_recent_share_doc_tags_6h, user_recent_share_doc_tags_7d, user_recent_share_doc_topic_tag_180d, user_recent_share_doc_topic_tag_1d, user_recent_share_doc_topic_tag_1h, user_recent_share_doc_topic_tag_30d, user_recent_share_doc_topic_tag_6h, user_recent_share_doc_topic_tag_7d, user_recent_share_doc_type_180d, user_recent_share_doc_type_1d, user_recent_share_doc_type_1h, user_recent_share_doc_type_30d, user_recent_share_doc_type_6h, user_recent_share_doc_type_7d, user_register_time, user_st_1d_doc_author_id_cart_cp, user_st_1d_doc_author_id_cart_recent, user_st_1d_doc_author_id_click_cp, user_st_1d_doc_author_id_click_recent, user_st_1d_doc_author_id_conversion_cp, user_st_1d_doc_author_id_conversion_recent, user_st_1d_doc_author_id_favorite_cp, user_st_1d_doc_author_id_favorite_recent, 
user_st_1d_doc_author_id_praise_cp, user_st_1d_doc_author_id_praise_recent, user_st_1d_doc_author_id_query_cp, user_st_1d_doc_author_id_query_recent, user_st_1d_doc_cate1_cart_cp, user_st_1d_doc_cate1_cart_recent, user_st_1d_doc_cate1_click_cp, user_st_1d_doc_cate1_click_recent, user_st_1d_doc_cate1_conversion_cp, user_st_1d_doc_cate1_conversion_recent, user_st_1d_doc_cate1_favorite_cp, user_st_1d_doc_cate1_favorite_recent, user_st_1d_doc_cate1_praise_cp, user_st_1d_doc_cate1_praise_recent, user_st_1d_doc_cate1_query_cp, user_st_1d_doc_cate1_query_recent, user_st_1d_doc_cate2_cart_cp, user_st_1d_doc_cate2_cart_recent, user_st_1d_doc_cate2_click_cp, user_st_1d_doc_cate2_click_recent, user_st_1d_doc_cate2_conversion_cp, user_st_1d_doc_cate2_conversion_recent, user_st_1d_doc_cate2_favorite_cp, user_st_1d_doc_cate2_favorite_recent, user_st_1d_doc_cate2_praise_cp, user_st_1d_doc_cate2_praise_recent, user_st_1d_doc_cate2_query_cp, user_st_1d_doc_cate2_query_recent, user_st_1d_doc_cate3_cart_cp, user_st_1d_doc_cate3_cart_recent, user_st_1d_doc_cate3_click_cp, user_st_1d_doc_cate3_click_recent, user_st_1d_doc_cate3_conversion_cp, user_st_1d_doc_cate3_conversion_recent, user_st_1d_doc_cate3_favorite_cp, user_st_1d_doc_cate3_favorite_recent, user_st_1d_doc_cate3_praise_cp, user_st_1d_doc_cate3_praise_recent, user_st_1d_doc_cate3_query_cp, user_st_1d_doc_cate3_query_recent, user_st_1d_doc_id_cart_cp, user_st_1d_doc_id_cart_recent, user_st_1d_doc_id_click_cp, user_st_1d_doc_id_click_recent, user_st_1d_doc_id_conversion_cp, user_st_1d_doc_id_conversion_recent, user_st_1d_doc_id_favorite_cp, user_st_1d_doc_id_favorite_recent, user_st_1d_doc_id_praise_cp, user_st_1d_doc_id_praise_recent, user_st_1d_doc_id_query_cp, user_st_1d_doc_id_query_recent, user_st_1d_doc_keyword_cart_cp, user_st_1d_doc_keyword_cart_recent, user_st_1d_doc_keyword_click_cp, user_st_1d_doc_keyword_click_recent, user_st_1d_doc_keyword_conversion_cp, user_st_1d_doc_keyword_conversion_recent, 
user_st_1d_doc_keyword_favorite_cp, user_st_1d_doc_keyword_favorite_recent, user_st_1d_doc_keyword_praise_cp, user_st_1d_doc_keyword_praise_recent, user_st_1d_doc_keyword_query_cp, user_st_1d_doc_keyword_query_recent, user_st_1d_doc_tags_cart_cp, user_st_1d_doc_tags_cart_recent, user_st_1d_doc_tags_click_cp, user_st_1d_doc_tags_click_recent, user_st_1d_doc_tags_conversion_cp, user_st_1d_doc_tags_conversion_recent, user_st_1d_doc_tags_favorite_cp, user_st_1d_doc_tags_favorite_recent, user_st_1d_doc_tags_praise_cp, user_st_1d_doc_tags_praise_recent, user_st_1d_doc_tags_query_cp, user_st_1d_doc_tags_query_recent, user_st_1d_doc_title_terms_cart_cp, user_st_1d_doc_title_terms_cart_recent, user_st_1d_doc_title_terms_click_cp, user_st_1d_doc_title_terms_click_recent, user_st_1d_doc_title_terms_conversion_cp, user_st_1d_doc_title_terms_conversion_recent, user_st_1d_doc_title_terms_favorite_cp, user_st_1d_doc_title_terms_favorite_recent, user_st_1d_doc_title_terms_praise_cp, user_st_1d_doc_title_terms_praise_recent, user_st_1d_doc_title_terms_query_cp, user_st_1d_doc_title_terms_query_recent, user_st_7d_doc_author_id_cart_cp, user_st_7d_doc_author_id_cart_recent, user_st_7d_doc_author_id_click_cp, user_st_7d_doc_author_id_click_recent, user_st_7d_doc_author_id_conversion_cp, user_st_7d_doc_author_id_conversion_recent, user_st_7d_doc_author_id_favorite_cp, user_st_7d_doc_author_id_favorite_recent, user_st_7d_doc_author_id_praise_cp, user_st_7d_doc_author_id_praise_recent, user_st_7d_doc_author_id_query_cp, user_st_7d_doc_author_id_query_recent, user_st_7d_doc_cate1_cart_cp, user_st_7d_doc_cate1_cart_recent, user_st_7d_doc_cate1_click_cp, user_st_7d_doc_cate1_click_recent, user_st_7d_doc_cate1_conversion_cp, user_st_7d_doc_cate1_conversion_recent, user_st_7d_doc_cate1_favorite_cp, user_st_7d_doc_cate1_favorite_recent, user_st_7d_doc_cate1_praise_cp, user_st_7d_doc_cate1_praise_recent, user_st_7d_doc_cate1_query_cp, user_st_7d_doc_cate1_query_recent, 
user_st_7d_doc_cate2_cart_cp, user_st_7d_doc_cate2_cart_recent, user_st_7d_doc_cate2_click_cp, user_st_7d_doc_cate2_click_recent, user_st_7d_doc_cate2_conversion_cp, user_st_7d_doc_cate2_conversion_recent, user_st_7d_doc_cate2_favorite_cp, user_st_7d_doc_cate2_favorite_recent, user_st_7d_doc_cate2_praise_cp, user_st_7d_doc_cate2_praise_recent, user_st_7d_doc_cate2_query_cp, user_st_7d_doc_cate2_query_recent, user_st_7d_doc_cate3_cart_cp, user_st_7d_doc_cate3_cart_recent, user_st_7d_doc_cate3_click_cp, user_st_7d_doc_cate3_click_recent, user_st_7d_doc_cate3_conversion_cp, user_st_7d_doc_cate3_conversion_recent, user_st_7d_doc_cate3_favorite_cp, user_st_7d_doc_cate3_favorite_recent, user_st_7d_doc_cate3_praise_cp, user_st_7d_doc_cate3_praise_recent, user_st_7d_doc_cate3_query_cp, user_st_7d_doc_cate3_query_recent, user_st_7d_doc_id_cart_cp, user_st_7d_doc_id_cart_recent, user_st_7d_doc_id_click_cp, user_st_7d_doc_id_click_recent, user_st_7d_doc_id_conversion_cp, user_st_7d_doc_id_conversion_recent, user_st_7d_doc_id_favorite_cp, user_st_7d_doc_id_favorite_recent, user_st_7d_doc_id_praise_cp, user_st_7d_doc_id_praise_recent, user_st_7d_doc_id_query_cp, user_st_7d_doc_id_query_recent, user_st_7d_doc_keyword_cart_cp, user_st_7d_doc_keyword_cart_recent, user_st_7d_doc_keyword_click_cp, user_st_7d_doc_keyword_click_recent, user_st_7d_doc_keyword_conversion_cp, user_st_7d_doc_keyword_conversion_recent, user_st_7d_doc_keyword_favorite_cp, user_st_7d_doc_keyword_favorite_recent, user_st_7d_doc_keyword_praise_cp, user_st_7d_doc_keyword_praise_recent, user_st_7d_doc_keyword_query_cp, user_st_7d_doc_keyword_query_recent, user_st_7d_doc_tags_cart_cp, user_st_7d_doc_tags_cart_recent, user_st_7d_doc_tags_click_cp, user_st_7d_doc_tags_click_recent, user_st_7d_doc_tags_conversion_cp, user_st_7d_doc_tags_conversion_recent, user_st_7d_doc_tags_favorite_cp, user_st_7d_doc_tags_favorite_recent, user_st_7d_doc_tags_praise_cp, user_st_7d_doc_tags_praise_recent, 
user_st_7d_doc_tags_query_cp, user_st_7d_doc_tags_query_recent, user_st_7d_doc_title_terms_cart_cp, user_st_7d_doc_title_terms_cart_recent, user_st_7d_doc_title_terms_click_cp, user_st_7d_doc_title_terms_click_recent, user_st_7d_doc_title_terms_conversion_cp, user_st_7d_doc_title_terms_conversion_recent, user_st_7d_doc_title_terms_favorite_cp, user_st_7d_doc_title_terms_favorite_recent, user_st_7d_doc_title_terms_praise_cp, user_st_7d_doc_title_terms_praise_recent, user_st_7d_doc_title_terms_query_cp, user_st_7d_doc_title_terms_query_recent, user_tags_list feature_name=f_area; depend=area; method=DirectString; slot=312; shared=true; feature_id=368807 feature_name=f_att_traced; depend=att_traced; method=DirectInt32; slot=322; shared=true; feature_id=368817 feature_name=f_bhv_scm; depend=bhv_scm; method=DirectString; slot=316; shared=true; feature_id=368811 feature_name=f_bhv_spm; depend=bhv_spm; method=DirectString; slot=317; shared=true; feature_id=368812 feature_name=f_bhv_spm_1; depend=bhv_spm_1; method=DirectString; slot=318; shared=true; feature_id=368813 feature_name=f_bhv_spm_2; depend=bhv_spm_2; method=DirectString; slot=319; shared=true; feature_id=368814 feature_name=f_bhv_spm_3; depend=bhv_spm_3; method=DirectString; slot=320; shared=true; feature_id=368815 feature_name=f_bhv_spm_4; depend=bhv_spm_4; method=DirectString; slot=321; shared=true; feature_id=368816 feature_name=f_bhv_time_hour; depend=bhv_time_hour; method=DirectString; slot=314; shared=true; feature_id=368809 feature_name=f_bhv_time_monthday; depend=bhv_time_monthday; method=DirectString; slot=323; shared=true; feature_id=368818 feature_name=f_bhv_time_weekday; depend=bhv_time_weekday; method=DirectString; slot=315; shared=true; feature_id=368810 feature_name=f_city; depend=city; method=DirectString; slot=310; shared=true; feature_id=368805 feature_name=f_client_version; depend=client_version; method=DirectString; slot=304; shared=true; feature_id=368799 feature_name=f_country; 
depend=country; method=DirectString; slot=308; shared=true; feature_id=368803 feature_name=f_device_model; depend=device_model; method=DirectString; slot=306; shared=true; feature_id=368801 feature_name=f_district; depend=district; method=DirectString; slot=311; shared=true; feature_id=368806 feature_name=f_doc_author_fans_10; depend=doc_author_fans_10; method=DirectString; slot=222; feature_id=368786 feature_name=f_doc_author_id; depend=doc_author_id; method=DirectString; slot=206; feature_id=368770 feature_name=f_doc_author_level; depend=doc_author_level; method=DirectString; slot=221; feature_id=368785 feature_name=f_doc_author_name; depend=doc_author_name; method=DirectString; slot=220; feature_id=368784 feature_name=f_doc_cate1; depend=doc_cate1; method=DirectString; slot=203; feature_id=368767 feature_name=f_doc_cate2; depend=doc_cate2; method=DirectString; slot=204; feature_id=368768 feature_name=f_doc_cate3; depend=doc_cate3; method=DirectString; slot=205; feature_id=368769 feature_name=f_doc_collect_cnt_10; depend=doc_collect_cnt_10; method=DirectInt32; slot=209; feature_id=368773 feature_name=f_doc_collection; depend=doc_collection; method=DirectString; slot=223; feature_id=368787 feature_name=f_doc_comment_cnt_10; depend=doc_comment_cnt_10; method=DirectInt32; slot=211; feature_id=368775 feature_name=f_doc_content_length_2; depend=doc_content_length_2; method=DirectInt32; slot=239; feature_id=368791 feature_name=f_doc_create_time; depend=doc_create_time; method=DirectInt64; slot=240; feature_id=368792 feature_name=f_doc_detail_pic_num; depend=doc_detail_pic_num; method=DirectInt32; slot=217; feature_id=368781 feature_name=f_doc_expire_time; depend=doc_expire_time; method=DirectInt64; slot=242; feature_id=368794 feature_name=f_doc_id; depend=doc_id; method=DirectString; slot=200; feature_id=368764 feature_name=f_doc_keyword; depend=doc_keyword; method=VectorTopString; slot=219; feature_id=368783 feature_name=f_doc_location_tag; depend=doc_location_tag; 
method=DirectString; slot=225; feature_id=368789 feature_name=f_doc_pic_url; depend=doc_pic_url; method=VectorTopString; slot=214; feature_id=368778 feature_name=f_doc_praise_cnt_10; depend=doc_praise_cnt_10; method=DirectInt32; slot=210; feature_id=368774 feature_name=f_doc_pub_time; depend=doc_pub_time; method=DirectInt64; slot=241; feature_id=368793 feature_name=f_doc_rating; depend=doc_rating; method=VectorTopString; slot=213; feature_id=368777 feature_name=f_doc_related_goods_ids; depend=doc_related_goods_ids; method=VectorTopString; slot=218; feature_id=368782 feature_name=f_doc_share_cnt_10; depend=doc_share_cnt_10; method=DirectInt32; slot=208; feature_id=368772 feature_name=f_doc_source_id; depend=doc_source_id; method=DirectString; slot=212; feature_id=368776 feature_name=f_doc_tags; depend=doc_tags; method=VectorTopString; slot=207; feature_id=368771 feature_name=f_doc_title_length; depend=doc_title_length; method=DirectInt32; slot=238; feature_id=368790 feature_name=f_doc_title_terms; depend=doc_title_terms; method=VectorTopString; slot=201; feature_id=368765 feature_name=f_doc_topic_tag; depend=doc_topic_tag; method=DirectString; slot=224; feature_id=368788 feature_name=f_doc_type; depend=doc_type; method=DirectString; slot=202; feature_id=368766 feature_name=f_doc_video_duration_10; depend=doc_video_duration_10; method=DirectInt32; slot=216; feature_id=368780 feature_name=f_doc_video_url; depend=doc_video_url; method=VectorTopString; slot=215; feature_id=368779 feature_name=f_fake_context_id; depend=fake_context_id; method=DirectString; slot=307; shared=true; feature_id=368802 feature_name=f_goods_exposure_cnt_lt; depend=goods_exposure_cnt_lt; method=DirectInt64; slot=515; feature_id=449546 feature_name=f_goods_is_prepublic; depend=goods_is_prepublic; method=DirectInt32; slot=514; feature_id=431594; feature_version=2 feature_name=f_goods_op_rec_status; depend=goods_op_rec_status; method=DirectInt32; slot=512; feature_id=368819 
feature_name=f_goods_quality_score; depend=goods_quality_score; method=DirectInt32; slot=512; feature_id=431593 feature_name=f_goods_rec_scene_id; depend=goods_rec_scene_id; method=DirectString; slot=512; feature_id=408248 feature_name=f_network; depend=network; method=DirectString; slot=305; shared=true; feature_id=368800 feature_name=f_os; depend=os; method=DirectString; slot=302; shared=true; feature_id=368797 feature_name=f_os_version; depend=os_version; method=DirectString; slot=303; shared=true; feature_id=368798 feature_name=f_page; depend=page; method=DirectString; slot=300; shared=true; feature_id=368795 feature_name=f_platform; depend=platform; method=DirectString; slot=301; shared=true; feature_id=368796 feature_name=f_province; depend=province; method=DirectString; slot=309; shared=true; feature_id=368804 feature_name=f_time; depend=time; method=DirectInt32; slot=313; shared=true; feature_id=368808 feature_name=f_user_age; depend=user_age; method=DirectString; slot=2; shared=true; feature_id=368753 feature_name=f_user_area; depend=user_area; method=DirectString; slot=10; shared=true; feature_id=368761 feature_name=f_user_city; depend=user_city; method=DirectString; slot=8; shared=true; feature_id=368759 feature_name=f_user_country; depend=user_country; method=DirectString; slot=6; shared=true; feature_id=368757 feature_name=f_user_device_id; depend=user_device_id; method=DirectString; slot=4; shared=true; feature_id=368755 feature_name=f_user_district; depend=user_district; method=DirectString; slot=9; shared=true; feature_id=368760 feature_name=f_user_gender; depend=user_gender; method=DirectString; slot=3; shared=true; feature_id=368754 feature_name=f_user_id; depend=user_id; method=DirectString; slot=1; shared=true; feature_id=368752 feature_name=f_user_is_prepublic; depend=user_is_prepublic; method=DirectInt32; slot=514; shared=true; feature_id=431589; feature_version=2 feature_name=f_user_membership_level; depend=user_membership_level; 
method=DirectString; slot=5; shared=true; feature_id=368756 feature_name=f_user_province; depend=user_province; method=DirectString; slot=7; shared=true; feature_id=368758 feature_name=f_user_quality_score; depend=user_quality_score; method=DirectInt32; slot=515; shared=true; feature_id=431590 feature_name=f_user_register_time; depend=user_register_time; method=DirectInt64; slot=12; shared=true; feature_id=368763 feature_name=f_user_tags_list; depend=user_tags_list; method=VectorTopString; slot=11; shared=true; feature_id=368762 feature_name=fc_doc_id_post_click_180d_cnt; depend=doc_id_post_click_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4735; shared=true; feature_id=369795 feature_name=fc_doc_id_post_click_180d_concat; depend=doc_id_post_click_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4737; shared=true; feature_id=369797 feature_name=fc_doc_id_post_click_180d_smooth; depend=doc_id_post_click_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4736; shared=true; feature_id=369796 feature_name=fc_doc_id_post_click_1d_cnt; depend=doc_id_post_click_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4726; shared=true; feature_id=369786 feature_name=fc_doc_id_post_click_1d_concat; depend=doc_id_post_click_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4728; shared=true; feature_id=369788 feature_name=fc_doc_id_post_click_1d_smooth; depend=doc_id_post_click_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4727; shared=true; feature_id=369787 feature_name=fc_doc_id_post_click_1h_cnt; depend=doc_id_post_click_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4720; shared=true; feature_id=369780 feature_name=fc_doc_id_post_click_1h_concat; depend=doc_id_post_click_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; 
slot=4722; shared=true; feature_id=369782 feature_name=fc_doc_id_post_click_1h_smooth; depend=doc_id_post_click_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4721; shared=true; feature_id=369781 feature_name=fc_doc_id_post_click_30d_cnt; depend=doc_id_post_click_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4732; shared=true; feature_id=369792 feature_name=fc_doc_id_post_click_30d_concat; depend=doc_id_post_click_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4734; shared=true; feature_id=369794 feature_name=fc_doc_id_post_click_30d_smooth; depend=doc_id_post_click_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4733; shared=true; feature_id=369793 feature_name=fc_doc_id_post_click_6h_cnt; depend=doc_id_post_click_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4723; shared=true; feature_id=369783 feature_name=fc_doc_id_post_click_6h_concat; depend=doc_id_post_click_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4725; shared=true; feature_id=369785 feature_name=fc_doc_id_post_click_6h_smooth; depend=doc_id_post_click_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4724; shared=true; feature_id=369784 feature_name=fc_doc_id_post_click_7d_cnt; depend=doc_id_post_click_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4729; shared=true; feature_id=369789 feature_name=fc_doc_id_post_click_7d_concat; depend=doc_id_post_click_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4731; shared=true; feature_id=369791 feature_name=fc_doc_id_post_click_7d_smooth; depend=doc_id_post_click_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4730; shared=true; feature_id=369790 feature_name=fc_doc_id_post_favorite_180d_cnt; 
depend=doc_id_post_favorite_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4753; shared=true; feature_id=369813 feature_name=fc_doc_id_post_favorite_180d_concat; depend=doc_id_post_favorite_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4755; shared=true; feature_id=369815 feature_name=fc_doc_id_post_favorite_180d_smooth; depend=doc_id_post_favorite_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4754; shared=true; feature_id=369814 feature_name=fc_doc_id_post_favorite_1d_cnt; depend=doc_id_post_favorite_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4744; shared=true; feature_id=369804 feature_name=fc_doc_id_post_favorite_1d_concat; depend=doc_id_post_favorite_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4746; shared=true; feature_id=369806 feature_name=fc_doc_id_post_favorite_1d_smooth; depend=doc_id_post_favorite_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4745; shared=true; feature_id=369805 feature_name=fc_doc_id_post_favorite_1h_cnt; depend=doc_id_post_favorite_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4738; shared=true; feature_id=369798 feature_name=fc_doc_id_post_favorite_1h_concat; depend=doc_id_post_favorite_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4740; shared=true; feature_id=369800 feature_name=fc_doc_id_post_favorite_1h_smooth; depend=doc_id_post_favorite_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4739; shared=true; feature_id=369799 feature_name=fc_doc_id_post_favorite_30d_cnt; depend=doc_id_post_favorite_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4750; shared=true; feature_id=369810 feature_name=fc_doc_id_post_favorite_30d_concat; depend=doc_id_post_favorite_30d; 
method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4752; shared=true; feature_id=369812 feature_name=fc_doc_id_post_favorite_30d_smooth; depend=doc_id_post_favorite_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4751; shared=true; feature_id=369811 feature_name=fc_doc_id_post_favorite_6h_cnt; depend=doc_id_post_favorite_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4741; shared=true; feature_id=369801 feature_name=fc_doc_id_post_favorite_6h_concat; depend=doc_id_post_favorite_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4743; shared=true; feature_id=369803 feature_name=fc_doc_id_post_favorite_6h_smooth; depend=doc_id_post_favorite_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4742; shared=true; feature_id=369802 feature_name=fc_doc_id_post_favorite_7d_cnt; depend=doc_id_post_favorite_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4747; shared=true; feature_id=369807 feature_name=fc_doc_id_post_favorite_7d_concat; depend=doc_id_post_favorite_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4749; shared=true; feature_id=369809 feature_name=fc_doc_id_post_favorite_7d_smooth; depend=doc_id_post_favorite_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4748; shared=true; feature_id=369808 feature_name=fc_doc_id_post_praise_180d_cnt; depend=doc_id_post_praise_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4789; shared=true; feature_id=369849 feature_name=fc_doc_id_post_praise_180d_concat; depend=doc_id_post_praise_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4791; shared=true; feature_id=369851 feature_name=fc_doc_id_post_praise_180d_smooth; depend=doc_id_post_praise_180d; method=TobInstanceProfilePairList; feature_version=2; 
args=1,7,0,1,4,1000,10000; slot=4790; shared=true; feature_id=369850 feature_name=fc_doc_id_post_praise_1d_cnt; depend=doc_id_post_praise_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4780; shared=true; feature_id=369840 feature_name=fc_doc_id_post_praise_1d_concat; depend=doc_id_post_praise_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4782; shared=true; feature_id=369842 feature_name=fc_doc_id_post_praise_1d_smooth; depend=doc_id_post_praise_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4781; shared=true; feature_id=369841 feature_name=fc_doc_id_post_praise_1h_cnt; depend=doc_id_post_praise_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4774; shared=true; feature_id=369834 feature_name=fc_doc_id_post_praise_1h_concat; depend=doc_id_post_praise_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4776; shared=true; feature_id=369836 feature_name=fc_doc_id_post_praise_1h_smooth; depend=doc_id_post_praise_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4775; shared=true; feature_id=369835 feature_name=fc_doc_id_post_praise_30d_cnt; depend=doc_id_post_praise_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4786; shared=true; feature_id=369846 feature_name=fc_doc_id_post_praise_30d_concat; depend=doc_id_post_praise_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4788; shared=true; feature_id=369848 feature_name=fc_doc_id_post_praise_30d_smooth; depend=doc_id_post_praise_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4787; shared=true; feature_id=369847 feature_name=fc_doc_id_post_praise_6h_cnt; depend=doc_id_post_praise_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4777; shared=true; feature_id=369837 
feature_name=fc_doc_id_post_praise_6h_concat; depend=doc_id_post_praise_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4779; shared=true; feature_id=369839 feature_name=fc_doc_id_post_praise_6h_smooth; depend=doc_id_post_praise_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4778; shared=true; feature_id=369838 feature_name=fc_doc_id_post_praise_7d_cnt; depend=doc_id_post_praise_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4783; shared=true; feature_id=369843 feature_name=fc_doc_id_post_praise_7d_concat; depend=doc_id_post_praise_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4785; shared=true; feature_id=369845 feature_name=fc_doc_id_post_praise_7d_smooth; depend=doc_id_post_praise_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4784; shared=true; feature_id=369844 feature_name=fc_doc_id_post_share_180d_cnt; depend=doc_id_post_share_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4771; shared=true; feature_id=369831 feature_name=fc_doc_id_post_share_180d_concat; depend=doc_id_post_share_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4773; shared=true; feature_id=369833 feature_name=fc_doc_id_post_share_180d_smooth; depend=doc_id_post_share_180d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4772; shared=true; feature_id=369832 feature_name=fc_doc_id_post_share_1d_cnt; depend=doc_id_post_share_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4762; shared=true; feature_id=369822 feature_name=fc_doc_id_post_share_1d_concat; depend=doc_id_post_share_1d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4764; shared=true; feature_id=369824 feature_name=fc_doc_id_post_share_1d_smooth; depend=doc_id_post_share_1d; 
method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4763; shared=true; feature_id=369823 feature_name=fc_doc_id_post_share_1h_cnt; depend=doc_id_post_share_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4756; shared=true; feature_id=369816 feature_name=fc_doc_id_post_share_1h_concat; depend=doc_id_post_share_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4758; shared=true; feature_id=369818 feature_name=fc_doc_id_post_share_1h_smooth; depend=doc_id_post_share_1h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4757; shared=true; feature_id=369817 feature_name=fc_doc_id_post_share_30d_cnt; depend=doc_id_post_share_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4768; shared=true; feature_id=369828 feature_name=fc_doc_id_post_share_30d_concat; depend=doc_id_post_share_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4770; shared=true; feature_id=369830 feature_name=fc_doc_id_post_share_30d_smooth; depend=doc_id_post_share_30d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4769; shared=true; feature_id=369829 feature_name=fc_doc_id_post_share_6h_cnt; depend=doc_id_post_share_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4759; shared=true; feature_id=369819 feature_name=fc_doc_id_post_share_6h_concat; depend=doc_id_post_share_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4761; shared=true; feature_id=369821 feature_name=fc_doc_id_post_share_6h_smooth; depend=doc_id_post_share_6h; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4760; shared=true; feature_id=369820 feature_name=fc_doc_id_post_share_7d_cnt; depend=doc_id_post_share_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,4,0,1; slot=4765; shared=true; feature_id=369825 
feature_name=fc_doc_id_post_share_7d_concat; depend=doc_id_post_share_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,6,0,1; slot=4767; shared=true; feature_id=369827 feature_name=fc_doc_id_post_share_7d_smooth; depend=doc_id_post_share_7d; method=TobInstanceProfilePairList; feature_version=2; args=1,7,0,1,4,1000,10000; slot=4766; shared=true; feature_id=369826 feature_name=fc_user_lt_doc_author_id_cart_cp; depend=user_lt_doc_author_id_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1361; shared=true; feature_id=368933 feature_name=fc_user_lt_doc_author_id_click_cp; depend=user_lt_doc_author_id_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1332; shared=true; feature_id=368904 feature_name=fc_user_lt_doc_author_id_conversion_cp; depend=user_lt_doc_author_id_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1385; shared=true; feature_id=368957 feature_name=fc_user_lt_doc_author_id_favorite_cp; depend=user_lt_doc_author_id_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1333; shared=true; feature_id=368905 feature_name=fc_user_lt_doc_author_id_praise_cp; depend=user_lt_doc_author_id_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1334; shared=true; feature_id=368906 feature_name=fc_user_lt_doc_author_id_query_cp; depend=user_lt_doc_author_id_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1335; shared=true; feature_id=368907 feature_name=fc_user_lt_doc_cate1_cart_cp; depend=user_lt_doc_cate1_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1352; shared=true; feature_id=368924 feature_name=fc_user_lt_doc_cate1_click_cp; depend=user_lt_doc_cate1_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1320; shared=true; feature_id=368892 feature_name=fc_user_lt_doc_cate1_conversion_cp; depend=user_lt_doc_cate1_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1376; shared=true; 
feature_id=368948 feature_name=fc_user_lt_doc_cate1_favorite_cp; depend=user_lt_doc_cate1_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1321; shared=true; feature_id=368893 feature_name=fc_user_lt_doc_cate1_praise_cp; depend=user_lt_doc_cate1_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1322; shared=true; feature_id=368894 feature_name=fc_user_lt_doc_cate1_query_cp; depend=user_lt_doc_cate1_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1323; shared=true; feature_id=368895 feature_name=fc_user_lt_doc_cate2_cart_cp; depend=user_lt_doc_cate2_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1355; shared=true; feature_id=368927 feature_name=fc_user_lt_doc_cate2_click_cp; depend=user_lt_doc_cate2_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1324; shared=true; feature_id=368896 feature_name=fc_user_lt_doc_cate2_conversion_cp; depend=user_lt_doc_cate2_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1379; shared=true; feature_id=368951 feature_name=fc_user_lt_doc_cate2_favorite_cp; depend=user_lt_doc_cate2_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1325; shared=true; feature_id=368897 feature_name=fc_user_lt_doc_cate2_praise_cp; depend=user_lt_doc_cate2_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1326; shared=true; feature_id=368898 feature_name=fc_user_lt_doc_cate2_query_cp; depend=user_lt_doc_cate2_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1327; shared=true; feature_id=368899 feature_name=fc_user_lt_doc_cate3_cart_cp; depend=user_lt_doc_cate3_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1358; shared=true; feature_id=368930 feature_name=fc_user_lt_doc_cate3_click_cp; depend=user_lt_doc_cate3_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1328; shared=true; feature_id=368900 feature_name=fc_user_lt_doc_cate3_conversion_cp; 
depend=user_lt_doc_cate3_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1382; shared=true; feature_id=368954 feature_name=fc_user_lt_doc_cate3_favorite_cp; depend=user_lt_doc_cate3_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1329; shared=true; feature_id=368901 feature_name=fc_user_lt_doc_cate3_praise_cp; depend=user_lt_doc_cate3_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1330; shared=true; feature_id=368902 feature_name=fc_user_lt_doc_cate3_query_cp; depend=user_lt_doc_cate3_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1331; shared=true; feature_id=368903 feature_name=fc_user_lt_doc_id_cart_cp; depend=user_lt_doc_id_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1346; shared=true; feature_id=368918 feature_name=fc_user_lt_doc_id_click_cp; depend=user_lt_doc_id_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1312; shared=true; feature_id=368884 feature_name=fc_user_lt_doc_id_conversion_cp; depend=user_lt_doc_id_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1370; shared=true; feature_id=368942 feature_name=fc_user_lt_doc_id_favorite_cp; depend=user_lt_doc_id_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1313; shared=true; feature_id=368885 feature_name=fc_user_lt_doc_id_praise_cp; depend=user_lt_doc_id_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1314; shared=true; feature_id=368886 feature_name=fc_user_lt_doc_id_query_cp; depend=user_lt_doc_id_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1315; shared=true; feature_id=368887 feature_name=fc_user_lt_doc_keyword_cart_cp; depend=user_lt_doc_keyword_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1367; shared=true; feature_id=368939 feature_name=fc_user_lt_doc_keyword_click_cp; depend=user_lt_doc_keyword_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1340; 
shared=true; feature_id=368912 feature_name=fc_user_lt_doc_keyword_conversion_cp; depend=user_lt_doc_keyword_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1391; shared=true; feature_id=368963 feature_name=fc_user_lt_doc_keyword_favorite_cp; depend=user_lt_doc_keyword_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1341; shared=true; feature_id=368913 feature_name=fc_user_lt_doc_keyword_praise_cp; depend=user_lt_doc_keyword_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1342; shared=true; feature_id=368914 feature_name=fc_user_lt_doc_keyword_query_cp; depend=user_lt_doc_keyword_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1343; shared=true; feature_id=368915 feature_name=fc_user_lt_doc_tags_cart_cp; depend=user_lt_doc_tags_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1364; shared=true; feature_id=368936 feature_name=fc_user_lt_doc_tags_click_cp; depend=user_lt_doc_tags_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1336; shared=true; feature_id=368908 feature_name=fc_user_lt_doc_tags_conversion_cp; depend=user_lt_doc_tags_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1388; shared=true; feature_id=368960 feature_name=fc_user_lt_doc_tags_favorite_cp; depend=user_lt_doc_tags_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1337; shared=true; feature_id=368909 feature_name=fc_user_lt_doc_tags_praise_cp; depend=user_lt_doc_tags_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1338; shared=true; feature_id=368910 feature_name=fc_user_lt_doc_tags_query_cp; depend=user_lt_doc_tags_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1339; shared=true; feature_id=368911 feature_name=fc_user_lt_doc_title_terms_cart_cp; depend=user_lt_doc_title_terms_cart_cp; method=VectorTopString; feature_version=2; args=50; slot=1349; shared=true; feature_id=368921 
feature_name=fc_user_lt_doc_title_terms_click_cp; depend=user_lt_doc_title_terms_click_cp; method=VectorTopString; feature_version=2; args=50; slot=1316; shared=true; feature_id=368888 feature_name=fc_user_lt_doc_title_terms_conversion_cp; depend=user_lt_doc_title_terms_conversion_cp; method=VectorTopString; feature_version=2; args=50; slot=1373; shared=true; feature_id=368945 feature_name=fc_user_lt_doc_title_terms_favorite_cp; depend=user_lt_doc_title_terms_favorite_cp; method=VectorTopString; feature_version=2; args=50; slot=1317; shared=true; feature_id=368889 feature_name=fc_user_lt_doc_title_terms_praise_cp; depend=user_lt_doc_title_terms_praise_cp; method=VectorTopString; feature_version=2; args=50; slot=1318; shared=true; feature_id=368890 feature_name=fc_user_lt_doc_title_terms_query_cp; depend=user_lt_doc_title_terms_query_cp; method=VectorTopString; feature_version=2; args=50; slot=1319; shared=true; feature_id=368891 feature_name=fc_user_recent_click_doc_cate1_180d; depend=user_recent_click_doc_cate1_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4357; shared=true; feature_id=369417 feature_name=fc_user_recent_click_doc_cate1_180d_has_match; depend=fc_user_recent_click_doc_cate1_180d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4359; feature_id=369419 feature_name=fc_user_recent_click_doc_cate1_180d_tob_profile_match; depend=user_recent_click_doc_cate1_180d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4358; feature_id=369418 feature_name=fc_user_recent_click_doc_cate1_1d; depend=user_recent_click_doc_cate1_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4348; shared=true; feature_id=369408 feature_name=fc_user_recent_click_doc_cate1_1d_has_match; depend=fc_user_recent_click_doc_cate1_1d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4350; feature_id=369410 feature_name=fc_user_recent_click_doc_cate1_1d_tob_profile_match; 
depend=user_recent_click_doc_cate1_1d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4349; feature_id=369409 feature_name=fc_user_recent_click_doc_cate1_1h; depend=user_recent_click_doc_cate1_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4342; shared=true; feature_id=369402 feature_name=fc_user_recent_click_doc_cate1_1h_has_match; depend=fc_user_recent_click_doc_cate1_1h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4344; feature_id=369404 feature_name=fc_user_recent_click_doc_cate1_1h_tob_profile_match; depend=user_recent_click_doc_cate1_1h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4343; feature_id=369403 feature_name=fc_user_recent_click_doc_cate1_30d; depend=user_recent_click_doc_cate1_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4354; shared=true; feature_id=369414 feature_name=fc_user_recent_click_doc_cate1_30d_has_match; depend=fc_user_recent_click_doc_cate1_30d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4356; feature_id=369416 feature_name=fc_user_recent_click_doc_cate1_30d_tob_profile_match; depend=user_recent_click_doc_cate1_30d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4355; feature_id=369415 feature_name=fc_user_recent_click_doc_cate1_6h; depend=user_recent_click_doc_cate1_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4345; shared=true; feature_id=369405 feature_name=fc_user_recent_click_doc_cate1_6h_has_match; depend=fc_user_recent_click_doc_cate1_6h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4347; feature_id=369407 feature_name=fc_user_recent_click_doc_cate1_6h_tob_profile_match; depend=user_recent_click_doc_cate1_6h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4346; feature_id=369406 feature_name=fc_user_recent_click_doc_cate1_7d; depend=user_recent_click_doc_cate1_7d; 
method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4351; shared=true; feature_id=369411 feature_name=fc_user_recent_click_doc_cate1_7d_has_match; depend=fc_user_recent_click_doc_cate1_7d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4353; feature_id=369413 feature_name=fc_user_recent_click_doc_cate1_7d_tob_profile_match; depend=user_recent_click_doc_cate1_7d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4352; feature_id=369412 feature_name=fc_user_recent_click_doc_cate2_180d; depend=user_recent_click_doc_cate2_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4393; shared=true; feature_id=369453 feature_name=fc_user_recent_click_doc_cate2_180d_has_match; depend=fc_user_recent_click_doc_cate2_180d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4395; feature_id=369455 feature_name=fc_user_recent_click_doc_cate2_180d_tob_profile_match; depend=user_recent_click_doc_cate2_180d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4394; feature_id=369454 feature_name=fc_user_recent_click_doc_cate2_1d; depend=user_recent_click_doc_cate2_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4384; shared=true; feature_id=369444 feature_name=fc_user_recent_click_doc_cate2_1d_has_match; depend=fc_user_recent_click_doc_cate2_1d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4386; feature_id=369446 feature_name=fc_user_recent_click_doc_cate2_1d_tob_profile_match; depend=user_recent_click_doc_cate2_1d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4385; feature_id=369445 feature_name=fc_user_recent_click_doc_cate2_1h; depend=user_recent_click_doc_cate2_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4378; shared=true; feature_id=369438 feature_name=fc_user_recent_click_doc_cate2_1h_has_match; depend=fc_user_recent_click_doc_cate2_1h,f_doc_cate2; method=HasMatch; 
feature_version=2; slot=4380; feature_id=369440 feature_name=fc_user_recent_click_doc_cate2_1h_tob_profile_match; depend=user_recent_click_doc_cate2_1h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4379; feature_id=369439 feature_name=fc_user_recent_click_doc_cate2_30d; depend=user_recent_click_doc_cate2_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4390; shared=true; feature_id=369450 feature_name=fc_user_recent_click_doc_cate2_30d_has_match; depend=fc_user_recent_click_doc_cate2_30d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4392; feature_id=369452 feature_name=fc_user_recent_click_doc_cate2_30d_tob_profile_match; depend=user_recent_click_doc_cate2_30d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4391; feature_id=369451 feature_name=fc_user_recent_click_doc_cate2_6h; depend=user_recent_click_doc_cate2_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4381; shared=true; feature_id=369441 feature_name=fc_user_recent_click_doc_cate2_6h_has_match; depend=fc_user_recent_click_doc_cate2_6h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4383; feature_id=369443 feature_name=fc_user_recent_click_doc_cate2_6h_tob_profile_match; depend=user_recent_click_doc_cate2_6h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4382; feature_id=369442 feature_name=fc_user_recent_click_doc_cate2_7d; depend=user_recent_click_doc_cate2_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4387; shared=true; feature_id=369447 feature_name=fc_user_recent_click_doc_cate2_7d_has_match; depend=fc_user_recent_click_doc_cate2_7d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4389; feature_id=369449 feature_name=fc_user_recent_click_doc_cate2_7d_tob_profile_match; depend=user_recent_click_doc_cate2_7d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4388; feature_id=369448 
feature_name=fc_user_recent_click_doc_cate3_180d; depend=user_recent_click_doc_cate3_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4375; shared=true; feature_id=369435 feature_name=fc_user_recent_click_doc_cate3_180d_has_match; depend=fc_user_recent_click_doc_cate3_180d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4377; feature_id=369437 feature_name=fc_user_recent_click_doc_cate3_180d_tob_profile_match; depend=user_recent_click_doc_cate3_180d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4376; feature_id=369436 feature_name=fc_user_recent_click_doc_cate3_1d; depend=user_recent_click_doc_cate3_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4366; shared=true; feature_id=369426 feature_name=fc_user_recent_click_doc_cate3_1d_has_match; depend=fc_user_recent_click_doc_cate3_1d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4368; feature_id=369428 feature_name=fc_user_recent_click_doc_cate3_1d_tob_profile_match; depend=user_recent_click_doc_cate3_1d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4367; feature_id=369427 feature_name=fc_user_recent_click_doc_cate3_1h; depend=user_recent_click_doc_cate3_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4360; shared=true; feature_id=369420 feature_name=fc_user_recent_click_doc_cate3_1h_has_match; depend=fc_user_recent_click_doc_cate3_1h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4362; feature_id=369422 feature_name=fc_user_recent_click_doc_cate3_1h_tob_profile_match; depend=user_recent_click_doc_cate3_1h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4361; feature_id=369421 feature_name=fc_user_recent_click_doc_cate3_30d; depend=user_recent_click_doc_cate3_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4372; shared=true; feature_id=369432 
feature_name=fc_user_recent_click_doc_cate3_30d_has_match; depend=fc_user_recent_click_doc_cate3_30d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4374; feature_id=369434 feature_name=fc_user_recent_click_doc_cate3_30d_tob_profile_match; depend=user_recent_click_doc_cate3_30d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4373; feature_id=369433 feature_name=fc_user_recent_click_doc_cate3_6h; depend=user_recent_click_doc_cate3_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4363; shared=true; feature_id=369423 feature_name=fc_user_recent_click_doc_cate3_6h_has_match; depend=fc_user_recent_click_doc_cate3_6h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4365; feature_id=369425 feature_name=fc_user_recent_click_doc_cate3_6h_tob_profile_match; depend=user_recent_click_doc_cate3_6h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4364; feature_id=369424 feature_name=fc_user_recent_click_doc_cate3_7d; depend=user_recent_click_doc_cate3_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4369; shared=true; feature_id=369429 feature_name=fc_user_recent_click_doc_cate3_7d_has_match; depend=fc_user_recent_click_doc_cate3_7d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4371; feature_id=369431 feature_name=fc_user_recent_click_doc_cate3_7d_tob_profile_match; depend=user_recent_click_doc_cate3_7d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4370; feature_id=369430 feature_name=fc_user_recent_click_doc_id_180d; depend=user_recent_click_doc_id_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4411; shared=true; feature_id=369471 feature_name=fc_user_recent_click_doc_id_180d_has_match; depend=fc_user_recent_click_doc_id_180d,f_doc_id; method=HasMatch; feature_version=2; slot=4413; feature_id=369473 feature_name=fc_user_recent_click_doc_id_180d_tob_profile_match; 
depend=user_recent_click_doc_id_180d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4412; feature_id=369472 feature_name=fc_user_recent_click_doc_id_1d; depend=user_recent_click_doc_id_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4402; shared=true; feature_id=369462 feature_name=fc_user_recent_click_doc_id_1d_has_match; depend=fc_user_recent_click_doc_id_1d,f_doc_id; method=HasMatch; feature_version=2; slot=4404; feature_id=369464 feature_name=fc_user_recent_click_doc_id_1d_tob_profile_match; depend=user_recent_click_doc_id_1d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4403; feature_id=369463 feature_name=fc_user_recent_click_doc_id_1h; depend=user_recent_click_doc_id_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4396; shared=true; feature_id=369456 feature_name=fc_user_recent_click_doc_id_1h_has_match; depend=fc_user_recent_click_doc_id_1h,f_doc_id; method=HasMatch; feature_version=2; slot=4398; feature_id=369458 feature_name=fc_user_recent_click_doc_id_1h_tob_profile_match; depend=user_recent_click_doc_id_1h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4397; feature_id=369457 feature_name=fc_user_recent_click_doc_id_30d; depend=user_recent_click_doc_id_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4408; shared=true; feature_id=369468 feature_name=fc_user_recent_click_doc_id_30d_has_match; depend=fc_user_recent_click_doc_id_30d,f_doc_id; method=HasMatch; feature_version=2; slot=4410; feature_id=369470 feature_name=fc_user_recent_click_doc_id_30d_tob_profile_match; depend=user_recent_click_doc_id_30d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4409; feature_id=369469 feature_name=fc_user_recent_click_doc_id_6h; depend=user_recent_click_doc_id_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4399; shared=true; feature_id=369459 
feature_name=fc_user_recent_click_doc_id_6h_has_match; depend=fc_user_recent_click_doc_id_6h,f_doc_id; method=HasMatch; feature_version=2; slot=4401; feature_id=369461 feature_name=fc_user_recent_click_doc_id_6h_tob_profile_match; depend=user_recent_click_doc_id_6h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4400; feature_id=369460 feature_name=fc_user_recent_click_doc_id_7d; depend=user_recent_click_doc_id_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4405; shared=true; feature_id=369465 feature_name=fc_user_recent_click_doc_id_7d_has_match; depend=fc_user_recent_click_doc_id_7d,f_doc_id; method=HasMatch; feature_version=2; slot=4407; feature_id=369467 feature_name=fc_user_recent_click_doc_id_7d_tob_profile_match; depend=user_recent_click_doc_id_7d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4406; feature_id=369466 feature_name=fc_user_recent_click_doc_keyword_180d; depend=user_recent_click_doc_keyword_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4339; shared=true; feature_id=369399 feature_name=fc_user_recent_click_doc_keyword_180d_has_match; depend=fc_user_recent_click_doc_keyword_180d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4341; feature_id=369401 feature_name=fc_user_recent_click_doc_keyword_180d_tob_profile_match; depend=user_recent_click_doc_keyword_180d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4340; feature_id=369400 feature_name=fc_user_recent_click_doc_keyword_1d; depend=user_recent_click_doc_keyword_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4330; shared=true; feature_id=369390 feature_name=fc_user_recent_click_doc_keyword_1d_has_match; depend=fc_user_recent_click_doc_keyword_1d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4332; feature_id=369392 feature_name=fc_user_recent_click_doc_keyword_1d_tob_profile_match; 
depend=user_recent_click_doc_keyword_1d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4331; feature_id=369391 feature_name=fc_user_recent_click_doc_keyword_1h; depend=user_recent_click_doc_keyword_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4324; shared=true; feature_id=369384 feature_name=fc_user_recent_click_doc_keyword_1h_has_match; depend=fc_user_recent_click_doc_keyword_1h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4326; feature_id=369386 feature_name=fc_user_recent_click_doc_keyword_1h_tob_profile_match; depend=user_recent_click_doc_keyword_1h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4325; feature_id=369385 feature_name=fc_user_recent_click_doc_keyword_30d; depend=user_recent_click_doc_keyword_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4336; shared=true; feature_id=369396 feature_name=fc_user_recent_click_doc_keyword_30d_has_match; depend=fc_user_recent_click_doc_keyword_30d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4338; feature_id=369398 feature_name=fc_user_recent_click_doc_keyword_30d_tob_profile_match; depend=user_recent_click_doc_keyword_30d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4337; feature_id=369397 feature_name=fc_user_recent_click_doc_keyword_6h; depend=user_recent_click_doc_keyword_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4327; shared=true; feature_id=369387 feature_name=fc_user_recent_click_doc_keyword_6h_has_match; depend=fc_user_recent_click_doc_keyword_6h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4329; feature_id=369389 feature_name=fc_user_recent_click_doc_keyword_6h_tob_profile_match; depend=user_recent_click_doc_keyword_6h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4328; feature_id=369388 feature_name=fc_user_recent_click_doc_keyword_7d; 
depend=user_recent_click_doc_keyword_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4333; shared=true; feature_id=369393 feature_name=fc_user_recent_click_doc_keyword_7d_has_match; depend=fc_user_recent_click_doc_keyword_7d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4335; feature_id=369395 feature_name=fc_user_recent_click_doc_keyword_7d_tob_profile_match; depend=user_recent_click_doc_keyword_7d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4334; feature_id=369394 feature_name=fc_user_recent_click_doc_tags_180d; depend=user_recent_click_doc_tags_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4429; shared=true; feature_id=369489 feature_name=fc_user_recent_click_doc_tags_180d_has_match; depend=fc_user_recent_click_doc_tags_180d,f_doc_tags; method=HasMatch; feature_version=2; slot=4431; feature_id=369491 feature_name=fc_user_recent_click_doc_tags_180d_tob_profile_match; depend=user_recent_click_doc_tags_180d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4430; feature_id=369490 feature_name=fc_user_recent_click_doc_tags_1d; depend=user_recent_click_doc_tags_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4420; shared=true; feature_id=369480 feature_name=fc_user_recent_click_doc_tags_1d_has_match; depend=fc_user_recent_click_doc_tags_1d,f_doc_tags; method=HasMatch; feature_version=2; slot=4422; feature_id=369482 feature_name=fc_user_recent_click_doc_tags_1d_tob_profile_match; depend=user_recent_click_doc_tags_1d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4421; feature_id=369481 feature_name=fc_user_recent_click_doc_tags_1h; depend=user_recent_click_doc_tags_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4414; shared=true; feature_id=369474 feature_name=fc_user_recent_click_doc_tags_1h_has_match; depend=fc_user_recent_click_doc_tags_1h,f_doc_tags; 
method=HasMatch; feature_version=2; slot=4416; feature_id=369476 feature_name=fc_user_recent_click_doc_tags_1h_tob_profile_match; depend=user_recent_click_doc_tags_1h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4415; feature_id=369475 feature_name=fc_user_recent_click_doc_tags_30d; depend=user_recent_click_doc_tags_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4426; shared=true; feature_id=369486 feature_name=fc_user_recent_click_doc_tags_30d_has_match; depend=fc_user_recent_click_doc_tags_30d,f_doc_tags; method=HasMatch; feature_version=2; slot=4428; feature_id=369488 feature_name=fc_user_recent_click_doc_tags_30d_tob_profile_match; depend=user_recent_click_doc_tags_30d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4427; feature_id=369487 feature_name=fc_user_recent_click_doc_tags_6h; depend=user_recent_click_doc_tags_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4417; shared=true; feature_id=369477 feature_name=fc_user_recent_click_doc_tags_6h_has_match; depend=fc_user_recent_click_doc_tags_6h,f_doc_tags; method=HasMatch; feature_version=2; slot=4419; feature_id=369479 feature_name=fc_user_recent_click_doc_tags_6h_tob_profile_match; depend=user_recent_click_doc_tags_6h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4418; feature_id=369478 feature_name=fc_user_recent_click_doc_tags_7d; depend=user_recent_click_doc_tags_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4423; shared=true; feature_id=369483 feature_name=fc_user_recent_click_doc_tags_7d_has_match; depend=fc_user_recent_click_doc_tags_7d,f_doc_tags; method=HasMatch; feature_version=2; slot=4425; feature_id=369485 feature_name=fc_user_recent_click_doc_tags_7d_tob_profile_match; depend=user_recent_click_doc_tags_7d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4424; feature_id=369484 
feature_name=fc_user_recent_click_doc_topic_tag_180d; depend=user_recent_click_doc_topic_tag_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4321; shared=true; feature_id=369381 feature_name=fc_user_recent_click_doc_topic_tag_180d_has_match; depend=fc_user_recent_click_doc_topic_tag_180d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4323; feature_id=369383 feature_name=fc_user_recent_click_doc_topic_tag_180d_tob_profile_match; depend=user_recent_click_doc_topic_tag_180d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4322; feature_id=369382 feature_name=fc_user_recent_click_doc_topic_tag_1d; depend=user_recent_click_doc_topic_tag_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4312; shared=true; feature_id=369372 feature_name=fc_user_recent_click_doc_topic_tag_1d_has_match; depend=fc_user_recent_click_doc_topic_tag_1d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4314; feature_id=369374 feature_name=fc_user_recent_click_doc_topic_tag_1d_tob_profile_match; depend=user_recent_click_doc_topic_tag_1d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4313; feature_id=369373 feature_name=fc_user_recent_click_doc_topic_tag_1h; depend=user_recent_click_doc_topic_tag_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4306; shared=true; feature_id=369366 feature_name=fc_user_recent_click_doc_topic_tag_1h_has_match; depend=fc_user_recent_click_doc_topic_tag_1h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4308; feature_id=369368 feature_name=fc_user_recent_click_doc_topic_tag_1h_tob_profile_match; depend=user_recent_click_doc_topic_tag_1h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4307; feature_id=369367 feature_name=fc_user_recent_click_doc_topic_tag_30d; depend=user_recent_click_doc_topic_tag_30d; method=TobInstanceProfilePairList; feature_version=2; 
args=10,2; slot=4318; shared=true; feature_id=369378 feature_name=fc_user_recent_click_doc_topic_tag_30d_has_match; depend=fc_user_recent_click_doc_topic_tag_30d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4320; feature_id=369380 feature_name=fc_user_recent_click_doc_topic_tag_30d_tob_profile_match; depend=user_recent_click_doc_topic_tag_30d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4319; feature_id=369379 feature_name=fc_user_recent_click_doc_topic_tag_6h; depend=user_recent_click_doc_topic_tag_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4309; shared=true; feature_id=369369 feature_name=fc_user_recent_click_doc_topic_tag_6h_has_match; depend=fc_user_recent_click_doc_topic_tag_6h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4311; feature_id=369371 feature_name=fc_user_recent_click_doc_topic_tag_6h_tob_profile_match; depend=user_recent_click_doc_topic_tag_6h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4310; feature_id=369370 feature_name=fc_user_recent_click_doc_topic_tag_7d; depend=user_recent_click_doc_topic_tag_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4315; shared=true; feature_id=369375 feature_name=fc_user_recent_click_doc_topic_tag_7d_has_match; depend=fc_user_recent_click_doc_topic_tag_7d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4317; feature_id=369377 feature_name=fc_user_recent_click_doc_topic_tag_7d_tob_profile_match; depend=user_recent_click_doc_topic_tag_7d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4316; feature_id=369376 feature_name=fc_user_recent_click_doc_type_180d; depend=user_recent_click_doc_type_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4303; shared=true; feature_id=369363 feature_name=fc_user_recent_click_doc_type_180d_has_match; depend=fc_user_recent_click_doc_type_180d,f_doc_type; 
method=HasMatch; feature_version=2; slot=4305; feature_id=369365 feature_name=fc_user_recent_click_doc_type_180d_tob_profile_match; depend=user_recent_click_doc_type_180d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4304; feature_id=369364 feature_name=fc_user_recent_click_doc_type_1d; depend=user_recent_click_doc_type_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4294; shared=true; feature_id=369354 feature_name=fc_user_recent_click_doc_type_1d_has_match; depend=fc_user_recent_click_doc_type_1d,f_doc_type; method=HasMatch; feature_version=2; slot=4296; feature_id=369356 feature_name=fc_user_recent_click_doc_type_1d_tob_profile_match; depend=user_recent_click_doc_type_1d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4295; feature_id=369355 feature_name=fc_user_recent_click_doc_type_1h; depend=user_recent_click_doc_type_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4288; shared=true; feature_id=369348 feature_name=fc_user_recent_click_doc_type_1h_has_match; depend=fc_user_recent_click_doc_type_1h,f_doc_type; method=HasMatch; feature_version=2; slot=4290; feature_id=369350 feature_name=fc_user_recent_click_doc_type_1h_tob_profile_match; depend=user_recent_click_doc_type_1h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4289; feature_id=369349 feature_name=fc_user_recent_click_doc_type_30d; depend=user_recent_click_doc_type_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4300; shared=true; feature_id=369360 feature_name=fc_user_recent_click_doc_type_30d_has_match; depend=fc_user_recent_click_doc_type_30d,f_doc_type; method=HasMatch; feature_version=2; slot=4302; feature_id=369362 feature_name=fc_user_recent_click_doc_type_30d_tob_profile_match; depend=user_recent_click_doc_type_30d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4301; feature_id=369361 
feature_name=fc_user_recent_click_doc_type_6h; depend=user_recent_click_doc_type_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4291; shared=true; feature_id=369351 feature_name=fc_user_recent_click_doc_type_6h_has_match; depend=fc_user_recent_click_doc_type_6h,f_doc_type; method=HasMatch; feature_version=2; slot=4293; feature_id=369353 feature_name=fc_user_recent_click_doc_type_6h_tob_profile_match; depend=user_recent_click_doc_type_6h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4292; feature_id=369352 feature_name=fc_user_recent_click_doc_type_7d; depend=user_recent_click_doc_type_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4297; shared=true; feature_id=369357 feature_name=fc_user_recent_click_doc_type_7d_has_match; depend=fc_user_recent_click_doc_type_7d,f_doc_type; method=HasMatch; feature_version=2; slot=4299; feature_id=369359 feature_name=fc_user_recent_click_doc_type_7d_tob_profile_match; depend=user_recent_click_doc_type_7d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4298; feature_id=369358 feature_name=fc_user_recent_exposure_doc_cate1_180d; depend=user_recent_exposure_doc_cate1_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4645; shared=true; feature_id=369705 feature_name=fc_user_recent_exposure_doc_cate1_180d_has_match; depend=fc_user_recent_exposure_doc_cate1_180d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4647; feature_id=369707 feature_name=fc_user_recent_exposure_doc_cate1_180d_tob_profile_match; depend=user_recent_exposure_doc_cate1_180d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4646; feature_id=369706 feature_name=fc_user_recent_exposure_doc_cate1_1d; depend=user_recent_exposure_doc_cate1_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4636; shared=true; feature_id=369696 
feature_name=fc_user_recent_exposure_doc_cate1_1d_has_match; depend=fc_user_recent_exposure_doc_cate1_1d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4638; feature_id=369698 feature_name=fc_user_recent_exposure_doc_cate1_1d_tob_profile_match; depend=user_recent_exposure_doc_cate1_1d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4637; feature_id=369697 feature_name=fc_user_recent_exposure_doc_cate1_1h; depend=user_recent_exposure_doc_cate1_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4630; shared=true; feature_id=369690 feature_name=fc_user_recent_exposure_doc_cate1_1h_has_match; depend=fc_user_recent_exposure_doc_cate1_1h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4632; feature_id=369692 feature_name=fc_user_recent_exposure_doc_cate1_1h_tob_profile_match; depend=user_recent_exposure_doc_cate1_1h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4631; feature_id=369691 feature_name=fc_user_recent_exposure_doc_cate1_30d; depend=user_recent_exposure_doc_cate1_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4642; shared=true; feature_id=369702 feature_name=fc_user_recent_exposure_doc_cate1_30d_has_match; depend=fc_user_recent_exposure_doc_cate1_30d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4644; feature_id=369704 feature_name=fc_user_recent_exposure_doc_cate1_30d_tob_profile_match; depend=user_recent_exposure_doc_cate1_30d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4643; feature_id=369703 feature_name=fc_user_recent_exposure_doc_cate1_6h; depend=user_recent_exposure_doc_cate1_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4633; shared=true; feature_id=369693 feature_name=fc_user_recent_exposure_doc_cate1_6h_has_match; depend=fc_user_recent_exposure_doc_cate1_6h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4635; feature_id=369695 
feature_name=fc_user_recent_exposure_doc_cate1_6h_tob_profile_match; depend=user_recent_exposure_doc_cate1_6h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4634; feature_id=369694 feature_name=fc_user_recent_exposure_doc_cate1_7d; depend=user_recent_exposure_doc_cate1_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4639; shared=true; feature_id=369699 feature_name=fc_user_recent_exposure_doc_cate1_7d_has_match; depend=fc_user_recent_exposure_doc_cate1_7d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4641; feature_id=369701 feature_name=fc_user_recent_exposure_doc_cate1_7d_tob_profile_match; depend=user_recent_exposure_doc_cate1_7d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4640; feature_id=369700 feature_name=fc_user_recent_exposure_doc_cate2_180d; depend=user_recent_exposure_doc_cate2_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4681; shared=true; feature_id=369741 feature_name=fc_user_recent_exposure_doc_cate2_180d_has_match; depend=fc_user_recent_exposure_doc_cate2_180d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4683; feature_id=369743 feature_name=fc_user_recent_exposure_doc_cate2_180d_tob_profile_match; depend=user_recent_exposure_doc_cate2_180d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4682; feature_id=369742 feature_name=fc_user_recent_exposure_doc_cate2_1d; depend=user_recent_exposure_doc_cate2_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4672; shared=true; feature_id=369732 feature_name=fc_user_recent_exposure_doc_cate2_1d_has_match; depend=fc_user_recent_exposure_doc_cate2_1d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4674; feature_id=369734 feature_name=fc_user_recent_exposure_doc_cate2_1d_tob_profile_match; depend=user_recent_exposure_doc_cate2_1d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4673; 
feature_id=369733 feature_name=fc_user_recent_exposure_doc_cate2_1h; depend=user_recent_exposure_doc_cate2_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4666; shared=true; feature_id=369726 feature_name=fc_user_recent_exposure_doc_cate2_1h_has_match; depend=fc_user_recent_exposure_doc_cate2_1h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4668; feature_id=369728 feature_name=fc_user_recent_exposure_doc_cate2_1h_tob_profile_match; depend=user_recent_exposure_doc_cate2_1h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4667; feature_id=369727 feature_name=fc_user_recent_exposure_doc_cate2_30d; depend=user_recent_exposure_doc_cate2_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4678; shared=true; feature_id=369738 feature_name=fc_user_recent_exposure_doc_cate2_30d_has_match; depend=fc_user_recent_exposure_doc_cate2_30d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4680; feature_id=369740 feature_name=fc_user_recent_exposure_doc_cate2_30d_tob_profile_match; depend=user_recent_exposure_doc_cate2_30d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4679; feature_id=369739 feature_name=fc_user_recent_exposure_doc_cate2_6h; depend=user_recent_exposure_doc_cate2_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4669; shared=true; feature_id=369729 feature_name=fc_user_recent_exposure_doc_cate2_6h_has_match; depend=fc_user_recent_exposure_doc_cate2_6h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4671; feature_id=369731 feature_name=fc_user_recent_exposure_doc_cate2_6h_tob_profile_match; depend=user_recent_exposure_doc_cate2_6h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4670; feature_id=369730 feature_name=fc_user_recent_exposure_doc_cate2_7d; depend=user_recent_exposure_doc_cate2_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4675; shared=true; 
feature_id=369735 feature_name=fc_user_recent_exposure_doc_cate2_7d_has_match; depend=fc_user_recent_exposure_doc_cate2_7d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4677; feature_id=369737 feature_name=fc_user_recent_exposure_doc_cate2_7d_tob_profile_match; depend=user_recent_exposure_doc_cate2_7d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4676; feature_id=369736 feature_name=fc_user_recent_exposure_doc_cate3_180d; depend=user_recent_exposure_doc_cate3_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4663; shared=true; feature_id=369723 feature_name=fc_user_recent_exposure_doc_cate3_180d_has_match; depend=fc_user_recent_exposure_doc_cate3_180d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4665; feature_id=369725 feature_name=fc_user_recent_exposure_doc_cate3_180d_tob_profile_match; depend=user_recent_exposure_doc_cate3_180d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4664; feature_id=369724 feature_name=fc_user_recent_exposure_doc_cate3_1d; depend=user_recent_exposure_doc_cate3_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4654; shared=true; feature_id=369714 feature_name=fc_user_recent_exposure_doc_cate3_1d_has_match; depend=fc_user_recent_exposure_doc_cate3_1d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4656; feature_id=369716 feature_name=fc_user_recent_exposure_doc_cate3_1d_tob_profile_match; depend=user_recent_exposure_doc_cate3_1d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4655; feature_id=369715 feature_name=fc_user_recent_exposure_doc_cate3_1h; depend=user_recent_exposure_doc_cate3_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4648; shared=true; feature_id=369708 feature_name=fc_user_recent_exposure_doc_cate3_1h_has_match; depend=fc_user_recent_exposure_doc_cate3_1h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4650; 
feature_id=369710 feature_name=fc_user_recent_exposure_doc_cate3_1h_tob_profile_match; depend=user_recent_exposure_doc_cate3_1h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4649; feature_id=369709 feature_name=fc_user_recent_exposure_doc_cate3_30d; depend=user_recent_exposure_doc_cate3_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4660; shared=true; feature_id=369720 feature_name=fc_user_recent_exposure_doc_cate3_30d_has_match; depend=fc_user_recent_exposure_doc_cate3_30d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4662; feature_id=369722 feature_name=fc_user_recent_exposure_doc_cate3_30d_tob_profile_match; depend=user_recent_exposure_doc_cate3_30d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4661; feature_id=369721 feature_name=fc_user_recent_exposure_doc_cate3_6h; depend=user_recent_exposure_doc_cate3_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4651; shared=true; feature_id=369711 feature_name=fc_user_recent_exposure_doc_cate3_6h_has_match; depend=fc_user_recent_exposure_doc_cate3_6h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4653; feature_id=369713 feature_name=fc_user_recent_exposure_doc_cate3_6h_tob_profile_match; depend=user_recent_exposure_doc_cate3_6h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4652; feature_id=369712 feature_name=fc_user_recent_exposure_doc_cate3_7d; depend=user_recent_exposure_doc_cate3_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4657; shared=true; feature_id=369717 feature_name=fc_user_recent_exposure_doc_cate3_7d_has_match; depend=fc_user_recent_exposure_doc_cate3_7d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4659; feature_id=369719 feature_name=fc_user_recent_exposure_doc_cate3_7d_tob_profile_match; depend=user_recent_exposure_doc_cate3_7d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; 
slot=4658; feature_id=369718 feature_name=fc_user_recent_exposure_doc_id_180d; depend=user_recent_exposure_doc_id_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4699; shared=true; feature_id=369759 feature_name=fc_user_recent_exposure_doc_id_180d_has_match; depend=fc_user_recent_exposure_doc_id_180d,f_doc_id; method=HasMatch; feature_version=2; slot=4701; feature_id=369761 feature_name=fc_user_recent_exposure_doc_id_180d_tob_profile_match; depend=user_recent_exposure_doc_id_180d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4700; feature_id=369760 feature_name=fc_user_recent_exposure_doc_id_1d; depend=user_recent_exposure_doc_id_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4690; shared=true; feature_id=369750 feature_name=fc_user_recent_exposure_doc_id_1d_has_match; depend=fc_user_recent_exposure_doc_id_1d,f_doc_id; method=HasMatch; feature_version=2; slot=4692; feature_id=369752 feature_name=fc_user_recent_exposure_doc_id_1d_tob_profile_match; depend=user_recent_exposure_doc_id_1d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4691; feature_id=369751 feature_name=fc_user_recent_exposure_doc_id_1h; depend=user_recent_exposure_doc_id_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4684; shared=true; feature_id=369744 feature_name=fc_user_recent_exposure_doc_id_1h_has_match; depend=fc_user_recent_exposure_doc_id_1h,f_doc_id; method=HasMatch; feature_version=2; slot=4686; feature_id=369746 feature_name=fc_user_recent_exposure_doc_id_1h_tob_profile_match; depend=user_recent_exposure_doc_id_1h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4685; feature_id=369745 feature_name=fc_user_recent_exposure_doc_id_30d; depend=user_recent_exposure_doc_id_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4696; shared=true; feature_id=369756 
feature_name=fc_user_recent_exposure_doc_id_30d_has_match; depend=fc_user_recent_exposure_doc_id_30d,f_doc_id; method=HasMatch; feature_version=2; slot=4698; feature_id=369758 feature_name=fc_user_recent_exposure_doc_id_30d_tob_profile_match; depend=user_recent_exposure_doc_id_30d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4697; feature_id=369757 feature_name=fc_user_recent_exposure_doc_id_6h; depend=user_recent_exposure_doc_id_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4687; shared=true; feature_id=369747 feature_name=fc_user_recent_exposure_doc_id_6h_has_match; depend=fc_user_recent_exposure_doc_id_6h,f_doc_id; method=HasMatch; feature_version=2; slot=4689; feature_id=369749 feature_name=fc_user_recent_exposure_doc_id_6h_tob_profile_match; depend=user_recent_exposure_doc_id_6h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4688; feature_id=369748 feature_name=fc_user_recent_exposure_doc_id_7d; depend=user_recent_exposure_doc_id_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4693; shared=true; feature_id=369753 feature_name=fc_user_recent_exposure_doc_id_7d_has_match; depend=fc_user_recent_exposure_doc_id_7d,f_doc_id; method=HasMatch; feature_version=2; slot=4695; feature_id=369755 feature_name=fc_user_recent_exposure_doc_id_7d_tob_profile_match; depend=user_recent_exposure_doc_id_7d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4694; feature_id=369754 feature_name=fc_user_recent_exposure_doc_keyword_180d; depend=user_recent_exposure_doc_keyword_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4627; shared=true; feature_id=369687 feature_name=fc_user_recent_exposure_doc_keyword_180d_has_match; depend=fc_user_recent_exposure_doc_keyword_180d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4629; feature_id=369689 feature_name=fc_user_recent_exposure_doc_keyword_180d_tob_profile_match; 
depend=user_recent_exposure_doc_keyword_180d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4628; feature_id=369688 feature_name=fc_user_recent_exposure_doc_keyword_1d; depend=user_recent_exposure_doc_keyword_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4618; shared=true; feature_id=369678 feature_name=fc_user_recent_exposure_doc_keyword_1d_has_match; depend=fc_user_recent_exposure_doc_keyword_1d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4620; feature_id=369680 feature_name=fc_user_recent_exposure_doc_keyword_1d_tob_profile_match; depend=user_recent_exposure_doc_keyword_1d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4619; feature_id=369679 feature_name=fc_user_recent_exposure_doc_keyword_1h; depend=user_recent_exposure_doc_keyword_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4612; shared=true; feature_id=369672 feature_name=fc_user_recent_exposure_doc_keyword_1h_has_match; depend=fc_user_recent_exposure_doc_keyword_1h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4614; feature_id=369674 feature_name=fc_user_recent_exposure_doc_keyword_1h_tob_profile_match; depend=user_recent_exposure_doc_keyword_1h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4613; feature_id=369673 feature_name=fc_user_recent_exposure_doc_keyword_30d; depend=user_recent_exposure_doc_keyword_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4624; shared=true; feature_id=369684 feature_name=fc_user_recent_exposure_doc_keyword_30d_has_match; depend=fc_user_recent_exposure_doc_keyword_30d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4626; feature_id=369686 feature_name=fc_user_recent_exposure_doc_keyword_30d_tob_profile_match; depend=user_recent_exposure_doc_keyword_30d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4625; feature_id=369685 
feature_name=fc_user_recent_exposure_doc_keyword_6h; depend=user_recent_exposure_doc_keyword_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4615; shared=true; feature_id=369675 feature_name=fc_user_recent_exposure_doc_keyword_6h_has_match; depend=fc_user_recent_exposure_doc_keyword_6h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4617; feature_id=369677 feature_name=fc_user_recent_exposure_doc_keyword_6h_tob_profile_match; depend=user_recent_exposure_doc_keyword_6h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4616; feature_id=369676 feature_name=fc_user_recent_exposure_doc_keyword_7d; depend=user_recent_exposure_doc_keyword_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4621; shared=true; feature_id=369681 feature_name=fc_user_recent_exposure_doc_keyword_7d_has_match; depend=fc_user_recent_exposure_doc_keyword_7d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4623; feature_id=369683 feature_name=fc_user_recent_exposure_doc_keyword_7d_tob_profile_match; depend=user_recent_exposure_doc_keyword_7d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4622; feature_id=369682 feature_name=fc_user_recent_exposure_doc_tags_180d; depend=user_recent_exposure_doc_tags_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4717; shared=true; feature_id=369777 feature_name=fc_user_recent_exposure_doc_tags_180d_has_match; depend=fc_user_recent_exposure_doc_tags_180d,f_doc_tags; method=HasMatch; feature_version=2; slot=4719; feature_id=369779 feature_name=fc_user_recent_exposure_doc_tags_180d_tob_profile_match; depend=user_recent_exposure_doc_tags_180d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4718; feature_id=369778 feature_name=fc_user_recent_exposure_doc_tags_1d; depend=user_recent_exposure_doc_tags_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4708; 
shared=true; feature_id=369768 feature_name=fc_user_recent_exposure_doc_tags_1d_has_match; depend=fc_user_recent_exposure_doc_tags_1d,f_doc_tags; method=HasMatch; feature_version=2; slot=4710; feature_id=369770 feature_name=fc_user_recent_exposure_doc_tags_1d_tob_profile_match; depend=user_recent_exposure_doc_tags_1d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4709; feature_id=369769 feature_name=fc_user_recent_exposure_doc_tags_1h; depend=user_recent_exposure_doc_tags_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4702; shared=true; feature_id=369762 feature_name=fc_user_recent_exposure_doc_tags_1h_has_match; depend=fc_user_recent_exposure_doc_tags_1h,f_doc_tags; method=HasMatch; feature_version=2; slot=4704; feature_id=369764 feature_name=fc_user_recent_exposure_doc_tags_1h_tob_profile_match; depend=user_recent_exposure_doc_tags_1h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4703; feature_id=369763 feature_name=fc_user_recent_exposure_doc_tags_30d; depend=user_recent_exposure_doc_tags_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4714; shared=true; feature_id=369774 feature_name=fc_user_recent_exposure_doc_tags_30d_has_match; depend=fc_user_recent_exposure_doc_tags_30d,f_doc_tags; method=HasMatch; feature_version=2; slot=4716; feature_id=369776 feature_name=fc_user_recent_exposure_doc_tags_30d_tob_profile_match; depend=user_recent_exposure_doc_tags_30d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4715; feature_id=369775 feature_name=fc_user_recent_exposure_doc_tags_6h; depend=user_recent_exposure_doc_tags_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4705; shared=true; feature_id=369765 feature_name=fc_user_recent_exposure_doc_tags_6h_has_match; depend=fc_user_recent_exposure_doc_tags_6h,f_doc_tags; method=HasMatch; feature_version=2; slot=4707; feature_id=369767 
feature_name=fc_user_recent_exposure_doc_tags_6h_tob_profile_match; depend=user_recent_exposure_doc_tags_6h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4706; feature_id=369766 feature_name=fc_user_recent_exposure_doc_tags_7d; depend=user_recent_exposure_doc_tags_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4711; shared=true; feature_id=369771 feature_name=fc_user_recent_exposure_doc_tags_7d_has_match; depend=fc_user_recent_exposure_doc_tags_7d,f_doc_tags; method=HasMatch; feature_version=2; slot=4713; feature_id=369773 feature_name=fc_user_recent_exposure_doc_tags_7d_tob_profile_match; depend=user_recent_exposure_doc_tags_7d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4712; feature_id=369772 feature_name=fc_user_recent_exposure_doc_topic_tag_180d; depend=user_recent_exposure_doc_topic_tag_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4609; shared=true; feature_id=369669 feature_name=fc_user_recent_exposure_doc_topic_tag_180d_has_match; depend=fc_user_recent_exposure_doc_topic_tag_180d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4611; feature_id=369671 feature_name=fc_user_recent_exposure_doc_topic_tag_180d_tob_profile_match; depend=user_recent_exposure_doc_topic_tag_180d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4610; feature_id=369670 feature_name=fc_user_recent_exposure_doc_topic_tag_1d; depend=user_recent_exposure_doc_topic_tag_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4600; shared=true; feature_id=369660 feature_name=fc_user_recent_exposure_doc_topic_tag_1d_has_match; depend=fc_user_recent_exposure_doc_topic_tag_1d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4602; feature_id=369662 feature_name=fc_user_recent_exposure_doc_topic_tag_1d_tob_profile_match; depend=user_recent_exposure_doc_topic_tag_1d,f_doc_topic_tag; 
method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4601; feature_id=369661 feature_name=fc_user_recent_exposure_doc_topic_tag_1h; depend=user_recent_exposure_doc_topic_tag_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4594; shared=true; feature_id=369654 feature_name=fc_user_recent_exposure_doc_topic_tag_1h_has_match; depend=fc_user_recent_exposure_doc_topic_tag_1h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4596; feature_id=369656 feature_name=fc_user_recent_exposure_doc_topic_tag_1h_tob_profile_match; depend=user_recent_exposure_doc_topic_tag_1h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4595; feature_id=369655 feature_name=fc_user_recent_exposure_doc_topic_tag_30d; depend=user_recent_exposure_doc_topic_tag_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4606; shared=true; feature_id=369666 feature_name=fc_user_recent_exposure_doc_topic_tag_30d_has_match; depend=fc_user_recent_exposure_doc_topic_tag_30d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4608; feature_id=369668 feature_name=fc_user_recent_exposure_doc_topic_tag_30d_tob_profile_match; depend=user_recent_exposure_doc_topic_tag_30d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4607; feature_id=369667 feature_name=fc_user_recent_exposure_doc_topic_tag_6h; depend=user_recent_exposure_doc_topic_tag_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4597; shared=true; feature_id=369657 feature_name=fc_user_recent_exposure_doc_topic_tag_6h_has_match; depend=fc_user_recent_exposure_doc_topic_tag_6h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4599; feature_id=369659 feature_name=fc_user_recent_exposure_doc_topic_tag_6h_tob_profile_match; depend=user_recent_exposure_doc_topic_tag_6h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4598; feature_id=369658 
feature_name=fc_user_recent_exposure_doc_topic_tag_7d; depend=user_recent_exposure_doc_topic_tag_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4603; shared=true; feature_id=369663 feature_name=fc_user_recent_exposure_doc_topic_tag_7d_has_match; depend=fc_user_recent_exposure_doc_topic_tag_7d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4605; feature_id=369665 feature_name=fc_user_recent_exposure_doc_topic_tag_7d_tob_profile_match; depend=user_recent_exposure_doc_topic_tag_7d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4604; feature_id=369664 feature_name=fc_user_recent_exposure_doc_type_180d; depend=user_recent_exposure_doc_type_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4591; shared=true; feature_id=369651 feature_name=fc_user_recent_exposure_doc_type_180d_has_match; depend=fc_user_recent_exposure_doc_type_180d,f_doc_type; method=HasMatch; feature_version=2; slot=4593; feature_id=369653 feature_name=fc_user_recent_exposure_doc_type_180d_tob_profile_match; depend=user_recent_exposure_doc_type_180d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4592; feature_id=369652 feature_name=fc_user_recent_exposure_doc_type_1d; depend=user_recent_exposure_doc_type_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4582; shared=true; feature_id=369642 feature_name=fc_user_recent_exposure_doc_type_1d_has_match; depend=fc_user_recent_exposure_doc_type_1d,f_doc_type; method=HasMatch; feature_version=2; slot=4584; feature_id=369644 feature_name=fc_user_recent_exposure_doc_type_1d_tob_profile_match; depend=user_recent_exposure_doc_type_1d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4583; feature_id=369643 feature_name=fc_user_recent_exposure_doc_type_1h; depend=user_recent_exposure_doc_type_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4576; shared=true; 
feature_id=369636 feature_name=fc_user_recent_exposure_doc_type_1h_has_match; depend=fc_user_recent_exposure_doc_type_1h,f_doc_type; method=HasMatch; feature_version=2; slot=4578; feature_id=369638 feature_name=fc_user_recent_exposure_doc_type_1h_tob_profile_match; depend=user_recent_exposure_doc_type_1h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4577; feature_id=369637 feature_name=fc_user_recent_exposure_doc_type_30d; depend=user_recent_exposure_doc_type_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4588; shared=true; feature_id=369648 feature_name=fc_user_recent_exposure_doc_type_30d_has_match; depend=fc_user_recent_exposure_doc_type_30d,f_doc_type; method=HasMatch; feature_version=2; slot=4590; feature_id=369650 feature_name=fc_user_recent_exposure_doc_type_30d_tob_profile_match; depend=user_recent_exposure_doc_type_30d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4589; feature_id=369649 feature_name=fc_user_recent_exposure_doc_type_6h; depend=user_recent_exposure_doc_type_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4579; shared=true; feature_id=369639 feature_name=fc_user_recent_exposure_doc_type_6h_has_match; depend=fc_user_recent_exposure_doc_type_6h,f_doc_type; method=HasMatch; feature_version=2; slot=4581; feature_id=369641 feature_name=fc_user_recent_exposure_doc_type_6h_tob_profile_match; depend=user_recent_exposure_doc_type_6h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4580; feature_id=369640 feature_name=fc_user_recent_exposure_doc_type_7d; depend=user_recent_exposure_doc_type_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4585; shared=true; feature_id=369645 feature_name=fc_user_recent_exposure_doc_type_7d_has_match; depend=fc_user_recent_exposure_doc_type_7d,f_doc_type; method=HasMatch; feature_version=2; slot=4587; feature_id=369647 
feature_name=fc_user_recent_exposure_doc_type_7d_tob_profile_match; depend=user_recent_exposure_doc_type_7d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4586; feature_id=369646 feature_name=fc_user_recent_favorite_doc_cate1_180d; depend=user_recent_favorite_doc_cate1_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4213; shared=true; feature_id=369273 feature_name=fc_user_recent_favorite_doc_cate1_180d_has_match; depend=fc_user_recent_favorite_doc_cate1_180d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4215; feature_id=369275 feature_name=fc_user_recent_favorite_doc_cate1_180d_tob_profile_match; depend=user_recent_favorite_doc_cate1_180d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4214; feature_id=369274 feature_name=fc_user_recent_favorite_doc_cate1_1d; depend=user_recent_favorite_doc_cate1_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4204; shared=true; feature_id=369264 feature_name=fc_user_recent_favorite_doc_cate1_1d_has_match; depend=fc_user_recent_favorite_doc_cate1_1d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4206; feature_id=369266 feature_name=fc_user_recent_favorite_doc_cate1_1d_tob_profile_match; depend=user_recent_favorite_doc_cate1_1d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4205; feature_id=369265 feature_name=fc_user_recent_favorite_doc_cate1_1h; depend=user_recent_favorite_doc_cate1_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4198; shared=true; feature_id=369258 feature_name=fc_user_recent_favorite_doc_cate1_1h_has_match; depend=fc_user_recent_favorite_doc_cate1_1h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4200; feature_id=369260 feature_name=fc_user_recent_favorite_doc_cate1_1h_tob_profile_match; depend=user_recent_favorite_doc_cate1_1h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4199; 
feature_id=369259 feature_name=fc_user_recent_favorite_doc_cate1_30d; depend=user_recent_favorite_doc_cate1_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4210; shared=true; feature_id=369270 feature_name=fc_user_recent_favorite_doc_cate1_30d_has_match; depend=fc_user_recent_favorite_doc_cate1_30d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4212; feature_id=369272 feature_name=fc_user_recent_favorite_doc_cate1_30d_tob_profile_match; depend=user_recent_favorite_doc_cate1_30d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4211; feature_id=369271 feature_name=fc_user_recent_favorite_doc_cate1_6h; depend=user_recent_favorite_doc_cate1_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4201; shared=true; feature_id=369261 feature_name=fc_user_recent_favorite_doc_cate1_6h_has_match; depend=fc_user_recent_favorite_doc_cate1_6h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4203; feature_id=369263 feature_name=fc_user_recent_favorite_doc_cate1_6h_tob_profile_match; depend=user_recent_favorite_doc_cate1_6h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4202; feature_id=369262 feature_name=fc_user_recent_favorite_doc_cate1_7d; depend=user_recent_favorite_doc_cate1_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4207; shared=true; feature_id=369267 feature_name=fc_user_recent_favorite_doc_cate1_7d_has_match; depend=fc_user_recent_favorite_doc_cate1_7d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4209; feature_id=369269 feature_name=fc_user_recent_favorite_doc_cate1_7d_tob_profile_match; depend=user_recent_favorite_doc_cate1_7d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4208; feature_id=369268 feature_name=fc_user_recent_favorite_doc_cate2_180d; depend=user_recent_favorite_doc_cate2_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4249; shared=true; 
feature_id=369309 feature_name=fc_user_recent_favorite_doc_cate2_180d_has_match; depend=fc_user_recent_favorite_doc_cate2_180d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4251; feature_id=369311 feature_name=fc_user_recent_favorite_doc_cate2_180d_tob_profile_match; depend=user_recent_favorite_doc_cate2_180d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4250; feature_id=369310 feature_name=fc_user_recent_favorite_doc_cate2_1d; depend=user_recent_favorite_doc_cate2_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4240; shared=true; feature_id=369300 feature_name=fc_user_recent_favorite_doc_cate2_1d_has_match; depend=fc_user_recent_favorite_doc_cate2_1d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4242; feature_id=369302 feature_name=fc_user_recent_favorite_doc_cate2_1d_tob_profile_match; depend=user_recent_favorite_doc_cate2_1d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4241; feature_id=369301 feature_name=fc_user_recent_favorite_doc_cate2_1h; depend=user_recent_favorite_doc_cate2_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4234; shared=true; feature_id=369294 feature_name=fc_user_recent_favorite_doc_cate2_1h_has_match; depend=fc_user_recent_favorite_doc_cate2_1h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4236; feature_id=369296 feature_name=fc_user_recent_favorite_doc_cate2_1h_tob_profile_match; depend=user_recent_favorite_doc_cate2_1h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4235; feature_id=369295 feature_name=fc_user_recent_favorite_doc_cate2_30d; depend=user_recent_favorite_doc_cate2_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4246; shared=true; feature_id=369306 feature_name=fc_user_recent_favorite_doc_cate2_30d_has_match; depend=fc_user_recent_favorite_doc_cate2_30d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4248; 
feature_id=369308 feature_name=fc_user_recent_favorite_doc_cate2_30d_tob_profile_match; depend=user_recent_favorite_doc_cate2_30d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4247; feature_id=369307 feature_name=fc_user_recent_favorite_doc_cate2_6h; depend=user_recent_favorite_doc_cate2_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4237; shared=true; feature_id=369297 feature_name=fc_user_recent_favorite_doc_cate2_6h_has_match; depend=fc_user_recent_favorite_doc_cate2_6h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4239; feature_id=369299 feature_name=fc_user_recent_favorite_doc_cate2_6h_tob_profile_match; depend=user_recent_favorite_doc_cate2_6h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4238; feature_id=369298 feature_name=fc_user_recent_favorite_doc_cate2_7d; depend=user_recent_favorite_doc_cate2_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4243; shared=true; feature_id=369303 feature_name=fc_user_recent_favorite_doc_cate2_7d_has_match; depend=fc_user_recent_favorite_doc_cate2_7d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4245; feature_id=369305 feature_name=fc_user_recent_favorite_doc_cate2_7d_tob_profile_match; depend=user_recent_favorite_doc_cate2_7d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4244; feature_id=369304 feature_name=fc_user_recent_favorite_doc_cate3_180d; depend=user_recent_favorite_doc_cate3_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4231; shared=true; feature_id=369291 feature_name=fc_user_recent_favorite_doc_cate3_180d_has_match; depend=fc_user_recent_favorite_doc_cate3_180d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4233; feature_id=369293 feature_name=fc_user_recent_favorite_doc_cate3_180d_tob_profile_match; depend=user_recent_favorite_doc_cate3_180d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; 
args=1,3; slot=4232; feature_id=369292 feature_name=fc_user_recent_favorite_doc_cate3_1d; depend=user_recent_favorite_doc_cate3_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4222; shared=true; feature_id=369282 feature_name=fc_user_recent_favorite_doc_cate3_1d_has_match; depend=fc_user_recent_favorite_doc_cate3_1d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4224; feature_id=369284 feature_name=fc_user_recent_favorite_doc_cate3_1d_tob_profile_match; depend=user_recent_favorite_doc_cate3_1d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4223; feature_id=369283 feature_name=fc_user_recent_favorite_doc_cate3_1h; depend=user_recent_favorite_doc_cate3_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4216; shared=true; feature_id=369276 feature_name=fc_user_recent_favorite_doc_cate3_1h_has_match; depend=fc_user_recent_favorite_doc_cate3_1h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4218; feature_id=369278 feature_name=fc_user_recent_favorite_doc_cate3_1h_tob_profile_match; depend=user_recent_favorite_doc_cate3_1h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4217; feature_id=369277 feature_name=fc_user_recent_favorite_doc_cate3_30d; depend=user_recent_favorite_doc_cate3_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4228; shared=true; feature_id=369288 feature_name=fc_user_recent_favorite_doc_cate3_30d_has_match; depend=fc_user_recent_favorite_doc_cate3_30d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4230; feature_id=369290 feature_name=fc_user_recent_favorite_doc_cate3_30d_tob_profile_match; depend=user_recent_favorite_doc_cate3_30d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4229; feature_id=369289 feature_name=fc_user_recent_favorite_doc_cate3_6h; depend=user_recent_favorite_doc_cate3_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; 
slot=4219; shared=true; feature_id=369279 feature_name=fc_user_recent_favorite_doc_cate3_6h_has_match; depend=fc_user_recent_favorite_doc_cate3_6h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4221; feature_id=369281 feature_name=fc_user_recent_favorite_doc_cate3_6h_tob_profile_match; depend=user_recent_favorite_doc_cate3_6h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4220; feature_id=369280 feature_name=fc_user_recent_favorite_doc_cate3_7d; depend=user_recent_favorite_doc_cate3_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4225; shared=true; feature_id=369285 feature_name=fc_user_recent_favorite_doc_cate3_7d_has_match; depend=fc_user_recent_favorite_doc_cate3_7d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4227; feature_id=369287 feature_name=fc_user_recent_favorite_doc_cate3_7d_tob_profile_match; depend=user_recent_favorite_doc_cate3_7d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4226; feature_id=369286 feature_name=fc_user_recent_favorite_doc_id_180d; depend=user_recent_favorite_doc_id_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4267; shared=true; feature_id=369327 feature_name=fc_user_recent_favorite_doc_id_180d_has_match; depend=fc_user_recent_favorite_doc_id_180d,f_doc_id; method=HasMatch; feature_version=2; slot=4269; feature_id=369329 feature_name=fc_user_recent_favorite_doc_id_180d_tob_profile_match; depend=user_recent_favorite_doc_id_180d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4268; feature_id=369328 feature_name=fc_user_recent_favorite_doc_id_1d; depend=user_recent_favorite_doc_id_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4258; shared=true; feature_id=369318 feature_name=fc_user_recent_favorite_doc_id_1d_has_match; depend=fc_user_recent_favorite_doc_id_1d,f_doc_id; method=HasMatch; feature_version=2; slot=4260; feature_id=369320 
feature_name=fc_user_recent_favorite_doc_id_1d_tob_profile_match; depend=user_recent_favorite_doc_id_1d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4259; feature_id=369319 feature_name=fc_user_recent_favorite_doc_id_1h; depend=user_recent_favorite_doc_id_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4252; shared=true; feature_id=369312 feature_name=fc_user_recent_favorite_doc_id_1h_has_match; depend=fc_user_recent_favorite_doc_id_1h,f_doc_id; method=HasMatch; feature_version=2; slot=4254; feature_id=369314 feature_name=fc_user_recent_favorite_doc_id_1h_tob_profile_match; depend=user_recent_favorite_doc_id_1h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4253; feature_id=369313 feature_name=fc_user_recent_favorite_doc_id_30d; depend=user_recent_favorite_doc_id_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4264; shared=true; feature_id=369324 feature_name=fc_user_recent_favorite_doc_id_30d_has_match; depend=fc_user_recent_favorite_doc_id_30d,f_doc_id; method=HasMatch; feature_version=2; slot=4266; feature_id=369326 feature_name=fc_user_recent_favorite_doc_id_30d_tob_profile_match; depend=user_recent_favorite_doc_id_30d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4265; feature_id=369325 feature_name=fc_user_recent_favorite_doc_id_6h; depend=user_recent_favorite_doc_id_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4255; shared=true; feature_id=369315 feature_name=fc_user_recent_favorite_doc_id_6h_has_match; depend=fc_user_recent_favorite_doc_id_6h,f_doc_id; method=HasMatch; feature_version=2; slot=4257; feature_id=369317 feature_name=fc_user_recent_favorite_doc_id_6h_tob_profile_match; depend=user_recent_favorite_doc_id_6h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4256; feature_id=369316 feature_name=fc_user_recent_favorite_doc_id_7d; 
depend=user_recent_favorite_doc_id_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4261; shared=true; feature_id=369321 feature_name=fc_user_recent_favorite_doc_id_7d_has_match; depend=fc_user_recent_favorite_doc_id_7d,f_doc_id; method=HasMatch; feature_version=2; slot=4263; feature_id=369323 feature_name=fc_user_recent_favorite_doc_id_7d_tob_profile_match; depend=user_recent_favorite_doc_id_7d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4262; feature_id=369322 feature_name=fc_user_recent_favorite_doc_keyword_180d; depend=user_recent_favorite_doc_keyword_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4195; shared=true; feature_id=369255 feature_name=fc_user_recent_favorite_doc_keyword_180d_has_match; depend=fc_user_recent_favorite_doc_keyword_180d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4197; feature_id=369257 feature_name=fc_user_recent_favorite_doc_keyword_180d_tob_profile_match; depend=user_recent_favorite_doc_keyword_180d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4196; feature_id=369256 feature_name=fc_user_recent_favorite_doc_keyword_1d; depend=user_recent_favorite_doc_keyword_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4186; shared=true; feature_id=369246 feature_name=fc_user_recent_favorite_doc_keyword_1d_has_match; depend=fc_user_recent_favorite_doc_keyword_1d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4188; feature_id=369248 feature_name=fc_user_recent_favorite_doc_keyword_1d_tob_profile_match; depend=user_recent_favorite_doc_keyword_1d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4187; feature_id=369247 feature_name=fc_user_recent_favorite_doc_keyword_1h; depend=user_recent_favorite_doc_keyword_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4180; shared=true; feature_id=369240 
feature_name=fc_user_recent_favorite_doc_keyword_1h_has_match; depend=fc_user_recent_favorite_doc_keyword_1h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4182; feature_id=369242 feature_name=fc_user_recent_favorite_doc_keyword_1h_tob_profile_match; depend=user_recent_favorite_doc_keyword_1h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4181; feature_id=369241 feature_name=fc_user_recent_favorite_doc_keyword_30d; depend=user_recent_favorite_doc_keyword_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4192; shared=true; feature_id=369252 feature_name=fc_user_recent_favorite_doc_keyword_30d_has_match; depend=fc_user_recent_favorite_doc_keyword_30d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4194; feature_id=369254 feature_name=fc_user_recent_favorite_doc_keyword_30d_tob_profile_match; depend=user_recent_favorite_doc_keyword_30d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4193; feature_id=369253 feature_name=fc_user_recent_favorite_doc_keyword_6h; depend=user_recent_favorite_doc_keyword_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4183; shared=true; feature_id=369243 feature_name=fc_user_recent_favorite_doc_keyword_6h_has_match; depend=fc_user_recent_favorite_doc_keyword_6h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4185; feature_id=369245 feature_name=fc_user_recent_favorite_doc_keyword_6h_tob_profile_match; depend=user_recent_favorite_doc_keyword_6h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4184; feature_id=369244 feature_name=fc_user_recent_favorite_doc_keyword_7d; depend=user_recent_favorite_doc_keyword_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4189; shared=true; feature_id=369249 feature_name=fc_user_recent_favorite_doc_keyword_7d_has_match; depend=fc_user_recent_favorite_doc_keyword_7d,f_doc_keyword; method=HasMatch; 
feature_version=2; slot=4191; feature_id=369251 feature_name=fc_user_recent_favorite_doc_keyword_7d_tob_profile_match; depend=user_recent_favorite_doc_keyword_7d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4190; feature_id=369250 feature_name=fc_user_recent_favorite_doc_tags_180d; depend=user_recent_favorite_doc_tags_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4285; shared=true; feature_id=369345 feature_name=fc_user_recent_favorite_doc_tags_180d_has_match; depend=fc_user_recent_favorite_doc_tags_180d,f_doc_tags; method=HasMatch; feature_version=2; slot=4287; feature_id=369347 feature_name=fc_user_recent_favorite_doc_tags_180d_tob_profile_match; depend=user_recent_favorite_doc_tags_180d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4286; feature_id=369346 feature_name=fc_user_recent_favorite_doc_tags_1d; depend=user_recent_favorite_doc_tags_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4276; shared=true; feature_id=369336 feature_name=fc_user_recent_favorite_doc_tags_1d_has_match; depend=fc_user_recent_favorite_doc_tags_1d,f_doc_tags; method=HasMatch; feature_version=2; slot=4278; feature_id=369338 feature_name=fc_user_recent_favorite_doc_tags_1d_tob_profile_match; depend=user_recent_favorite_doc_tags_1d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4277; feature_id=369337 feature_name=fc_user_recent_favorite_doc_tags_1h; depend=user_recent_favorite_doc_tags_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4270; shared=true; feature_id=369330 feature_name=fc_user_recent_favorite_doc_tags_1h_has_match; depend=fc_user_recent_favorite_doc_tags_1h,f_doc_tags; method=HasMatch; feature_version=2; slot=4272; feature_id=369332 feature_name=fc_user_recent_favorite_doc_tags_1h_tob_profile_match; depend=user_recent_favorite_doc_tags_1h,f_doc_tags; method=TobInstanceProfileMatch; 
feature_version=2; args=1,3; slot=4271; feature_id=369331 feature_name=fc_user_recent_favorite_doc_tags_30d; depend=user_recent_favorite_doc_tags_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4282; shared=true; feature_id=369342 feature_name=fc_user_recent_favorite_doc_tags_30d_has_match; depend=fc_user_recent_favorite_doc_tags_30d,f_doc_tags; method=HasMatch; feature_version=2; slot=4284; feature_id=369344 feature_name=fc_user_recent_favorite_doc_tags_30d_tob_profile_match; depend=user_recent_favorite_doc_tags_30d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4283; feature_id=369343 feature_name=fc_user_recent_favorite_doc_tags_6h; depend=user_recent_favorite_doc_tags_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4273; shared=true; feature_id=369333 feature_name=fc_user_recent_favorite_doc_tags_6h_has_match; depend=fc_user_recent_favorite_doc_tags_6h,f_doc_tags; method=HasMatch; feature_version=2; slot=4275; feature_id=369335 feature_name=fc_user_recent_favorite_doc_tags_6h_tob_profile_match; depend=user_recent_favorite_doc_tags_6h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4274; feature_id=369334 feature_name=fc_user_recent_favorite_doc_tags_7d; depend=user_recent_favorite_doc_tags_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4279; shared=true; feature_id=369339 feature_name=fc_user_recent_favorite_doc_tags_7d_has_match; depend=fc_user_recent_favorite_doc_tags_7d,f_doc_tags; method=HasMatch; feature_version=2; slot=4281; feature_id=369341 feature_name=fc_user_recent_favorite_doc_tags_7d_tob_profile_match; depend=user_recent_favorite_doc_tags_7d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4280; feature_id=369340 feature_name=fc_user_recent_favorite_doc_topic_tag_180d; depend=user_recent_favorite_doc_topic_tag_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; 
slot=4177; shared=true; feature_id=369237 feature_name=fc_user_recent_favorite_doc_topic_tag_180d_has_match; depend=fc_user_recent_favorite_doc_topic_tag_180d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4179; feature_id=369239 feature_name=fc_user_recent_favorite_doc_topic_tag_180d_tob_profile_match; depend=user_recent_favorite_doc_topic_tag_180d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4178; feature_id=369238 feature_name=fc_user_recent_favorite_doc_topic_tag_1d; depend=user_recent_favorite_doc_topic_tag_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4168; shared=true; feature_id=369228 feature_name=fc_user_recent_favorite_doc_topic_tag_1d_has_match; depend=fc_user_recent_favorite_doc_topic_tag_1d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4170; feature_id=369230 feature_name=fc_user_recent_favorite_doc_topic_tag_1d_tob_profile_match; depend=user_recent_favorite_doc_topic_tag_1d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4169; feature_id=369229 feature_name=fc_user_recent_favorite_doc_topic_tag_1h; depend=user_recent_favorite_doc_topic_tag_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4162; shared=true; feature_id=369222 feature_name=fc_user_recent_favorite_doc_topic_tag_1h_has_match; depend=fc_user_recent_favorite_doc_topic_tag_1h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4164; feature_id=369224 feature_name=fc_user_recent_favorite_doc_topic_tag_1h_tob_profile_match; depend=user_recent_favorite_doc_topic_tag_1h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4163; feature_id=369223 feature_name=fc_user_recent_favorite_doc_topic_tag_30d; depend=user_recent_favorite_doc_topic_tag_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4174; shared=true; feature_id=369234 
feature_name=fc_user_recent_favorite_doc_topic_tag_30d_has_match; depend=fc_user_recent_favorite_doc_topic_tag_30d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4176; feature_id=369236 feature_name=fc_user_recent_favorite_doc_topic_tag_30d_tob_profile_match; depend=user_recent_favorite_doc_topic_tag_30d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4175; feature_id=369235 feature_name=fc_user_recent_favorite_doc_topic_tag_6h; depend=user_recent_favorite_doc_topic_tag_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4165; shared=true; feature_id=369225 feature_name=fc_user_recent_favorite_doc_topic_tag_6h_has_match; depend=fc_user_recent_favorite_doc_topic_tag_6h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4167; feature_id=369227 feature_name=fc_user_recent_favorite_doc_topic_tag_6h_tob_profile_match; depend=user_recent_favorite_doc_topic_tag_6h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4166; feature_id=369226 feature_name=fc_user_recent_favorite_doc_topic_tag_7d; depend=user_recent_favorite_doc_topic_tag_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4171; shared=true; feature_id=369231 feature_name=fc_user_recent_favorite_doc_topic_tag_7d_has_match; depend=fc_user_recent_favorite_doc_topic_tag_7d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4173; feature_id=369233 feature_name=fc_user_recent_favorite_doc_topic_tag_7d_tob_profile_match; depend=user_recent_favorite_doc_topic_tag_7d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4172; feature_id=369232 feature_name=fc_user_recent_favorite_doc_type_180d; depend=user_recent_favorite_doc_type_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4159; shared=true; feature_id=369219 feature_name=fc_user_recent_favorite_doc_type_180d_has_match; 
depend=fc_user_recent_favorite_doc_type_180d,f_doc_type; method=HasMatch; feature_version=2; slot=4161; feature_id=369221 feature_name=fc_user_recent_favorite_doc_type_180d_tob_profile_match; depend=user_recent_favorite_doc_type_180d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4160; feature_id=369220 feature_name=fc_user_recent_favorite_doc_type_1d; depend=user_recent_favorite_doc_type_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4150; shared=true; feature_id=369210 feature_name=fc_user_recent_favorite_doc_type_1d_has_match; depend=fc_user_recent_favorite_doc_type_1d,f_doc_type; method=HasMatch; feature_version=2; slot=4152; feature_id=369212 feature_name=fc_user_recent_favorite_doc_type_1d_tob_profile_match; depend=user_recent_favorite_doc_type_1d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4151; feature_id=369211 feature_name=fc_user_recent_favorite_doc_type_1h; depend=user_recent_favorite_doc_type_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4144; shared=true; feature_id=369204 feature_name=fc_user_recent_favorite_doc_type_1h_has_match; depend=fc_user_recent_favorite_doc_type_1h,f_doc_type; method=HasMatch; feature_version=2; slot=4146; feature_id=369206 feature_name=fc_user_recent_favorite_doc_type_1h_tob_profile_match; depend=user_recent_favorite_doc_type_1h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4145; feature_id=369205 feature_name=fc_user_recent_favorite_doc_type_30d; depend=user_recent_favorite_doc_type_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4156; shared=true; feature_id=369216 feature_name=fc_user_recent_favorite_doc_type_30d_has_match; depend=fc_user_recent_favorite_doc_type_30d,f_doc_type; method=HasMatch; feature_version=2; slot=4158; feature_id=369218 feature_name=fc_user_recent_favorite_doc_type_30d_tob_profile_match; 
depend=user_recent_favorite_doc_type_30d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4157; feature_id=369217 feature_name=fc_user_recent_favorite_doc_type_6h; depend=user_recent_favorite_doc_type_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4147; shared=true; feature_id=369207 feature_name=fc_user_recent_favorite_doc_type_6h_has_match; depend=fc_user_recent_favorite_doc_type_6h,f_doc_type; method=HasMatch; feature_version=2; slot=4149; feature_id=369209 feature_name=fc_user_recent_favorite_doc_type_6h_tob_profile_match; depend=user_recent_favorite_doc_type_6h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4148; feature_id=369208 feature_name=fc_user_recent_favorite_doc_type_7d; depend=user_recent_favorite_doc_type_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4153; shared=true; feature_id=369213 feature_name=fc_user_recent_favorite_doc_type_7d_has_match; depend=fc_user_recent_favorite_doc_type_7d,f_doc_type; method=HasMatch; feature_version=2; slot=4155; feature_id=369215 feature_name=fc_user_recent_favorite_doc_type_7d_tob_profile_match; depend=user_recent_favorite_doc_type_7d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4154; feature_id=369214 feature_name=fc_user_recent_praise_doc_cate1_180d; depend=user_recent_praise_doc_cate1_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4501; shared=true; feature_id=369561 feature_name=fc_user_recent_praise_doc_cate1_180d_has_match; depend=fc_user_recent_praise_doc_cate1_180d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4503; feature_id=369563 feature_name=fc_user_recent_praise_doc_cate1_180d_tob_profile_match; depend=user_recent_praise_doc_cate1_180d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4502; feature_id=369562 feature_name=fc_user_recent_praise_doc_cate1_1d; 
depend=user_recent_praise_doc_cate1_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4492; shared=true; feature_id=369552 feature_name=fc_user_recent_praise_doc_cate1_1d_has_match; depend=fc_user_recent_praise_doc_cate1_1d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4494; feature_id=369554 feature_name=fc_user_recent_praise_doc_cate1_1d_tob_profile_match; depend=user_recent_praise_doc_cate1_1d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4493; feature_id=369553 feature_name=fc_user_recent_praise_doc_cate1_1h; depend=user_recent_praise_doc_cate1_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4486; shared=true; feature_id=369546 feature_name=fc_user_recent_praise_doc_cate1_1h_has_match; depend=fc_user_recent_praise_doc_cate1_1h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4488; feature_id=369548 feature_name=fc_user_recent_praise_doc_cate1_1h_tob_profile_match; depend=user_recent_praise_doc_cate1_1h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4487; feature_id=369547 feature_name=fc_user_recent_praise_doc_cate1_30d; depend=user_recent_praise_doc_cate1_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4498; shared=true; feature_id=369558 feature_name=fc_user_recent_praise_doc_cate1_30d_has_match; depend=fc_user_recent_praise_doc_cate1_30d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4500; feature_id=369560 feature_name=fc_user_recent_praise_doc_cate1_30d_tob_profile_match; depend=user_recent_praise_doc_cate1_30d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4499; feature_id=369559 feature_name=fc_user_recent_praise_doc_cate1_6h; depend=user_recent_praise_doc_cate1_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4489; shared=true; feature_id=369549 feature_name=fc_user_recent_praise_doc_cate1_6h_has_match; 
depend=fc_user_recent_praise_doc_cate1_6h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4491; feature_id=369551 feature_name=fc_user_recent_praise_doc_cate1_6h_tob_profile_match; depend=user_recent_praise_doc_cate1_6h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4490; feature_id=369550 feature_name=fc_user_recent_praise_doc_cate1_7d; depend=user_recent_praise_doc_cate1_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4495; shared=true; feature_id=369555 feature_name=fc_user_recent_praise_doc_cate1_7d_has_match; depend=fc_user_recent_praise_doc_cate1_7d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4497; feature_id=369557 feature_name=fc_user_recent_praise_doc_cate1_7d_tob_profile_match; depend=user_recent_praise_doc_cate1_7d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4496; feature_id=369556 feature_name=fc_user_recent_praise_doc_cate2_180d; depend=user_recent_praise_doc_cate2_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4537; shared=true; feature_id=369597 feature_name=fc_user_recent_praise_doc_cate2_180d_has_match; depend=fc_user_recent_praise_doc_cate2_180d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4539; feature_id=369599 feature_name=fc_user_recent_praise_doc_cate2_180d_tob_profile_match; depend=user_recent_praise_doc_cate2_180d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4538; feature_id=369598 feature_name=fc_user_recent_praise_doc_cate2_1d; depend=user_recent_praise_doc_cate2_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4528; shared=true; feature_id=369588 feature_name=fc_user_recent_praise_doc_cate2_1d_has_match; depend=fc_user_recent_praise_doc_cate2_1d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4530; feature_id=369590 feature_name=fc_user_recent_praise_doc_cate2_1d_tob_profile_match; 
depend=user_recent_praise_doc_cate2_1d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4529; feature_id=369589 feature_name=fc_user_recent_praise_doc_cate2_1h; depend=user_recent_praise_doc_cate2_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4522; shared=true; feature_id=369582 feature_name=fc_user_recent_praise_doc_cate2_1h_has_match; depend=fc_user_recent_praise_doc_cate2_1h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4524; feature_id=369584 feature_name=fc_user_recent_praise_doc_cate2_1h_tob_profile_match; depend=user_recent_praise_doc_cate2_1h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4523; feature_id=369583 feature_name=fc_user_recent_praise_doc_cate2_30d; depend=user_recent_praise_doc_cate2_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4534; shared=true; feature_id=369594 feature_name=fc_user_recent_praise_doc_cate2_30d_has_match; depend=fc_user_recent_praise_doc_cate2_30d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4536; feature_id=369596 feature_name=fc_user_recent_praise_doc_cate2_30d_tob_profile_match; depend=user_recent_praise_doc_cate2_30d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4535; feature_id=369595 feature_name=fc_user_recent_praise_doc_cate2_6h; depend=user_recent_praise_doc_cate2_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4525; shared=true; feature_id=369585 feature_name=fc_user_recent_praise_doc_cate2_6h_has_match; depend=fc_user_recent_praise_doc_cate2_6h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4527; feature_id=369587 feature_name=fc_user_recent_praise_doc_cate2_6h_tob_profile_match; depend=user_recent_praise_doc_cate2_6h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4526; feature_id=369586 feature_name=fc_user_recent_praise_doc_cate2_7d; depend=user_recent_praise_doc_cate2_7d; 
method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4531; shared=true; feature_id=369591 feature_name=fc_user_recent_praise_doc_cate2_7d_has_match; depend=fc_user_recent_praise_doc_cate2_7d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4533; feature_id=369593 feature_name=fc_user_recent_praise_doc_cate2_7d_tob_profile_match; depend=user_recent_praise_doc_cate2_7d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4532; feature_id=369592 feature_name=fc_user_recent_praise_doc_cate3_180d; depend=user_recent_praise_doc_cate3_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4519; shared=true; feature_id=369579 feature_name=fc_user_recent_praise_doc_cate3_180d_has_match; depend=fc_user_recent_praise_doc_cate3_180d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4521; feature_id=369581 feature_name=fc_user_recent_praise_doc_cate3_180d_tob_profile_match; depend=user_recent_praise_doc_cate3_180d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4520; feature_id=369580 feature_name=fc_user_recent_praise_doc_cate3_1d; depend=user_recent_praise_doc_cate3_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4510; shared=true; feature_id=369570 feature_name=fc_user_recent_praise_doc_cate3_1d_has_match; depend=fc_user_recent_praise_doc_cate3_1d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4512; feature_id=369572 feature_name=fc_user_recent_praise_doc_cate3_1d_tob_profile_match; depend=user_recent_praise_doc_cate3_1d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4511; feature_id=369571 feature_name=fc_user_recent_praise_doc_cate3_1h; depend=user_recent_praise_doc_cate3_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4504; shared=true; feature_id=369564 feature_name=fc_user_recent_praise_doc_cate3_1h_has_match; depend=fc_user_recent_praise_doc_cate3_1h,f_doc_cate3; 
method=HasMatch; feature_version=2; slot=4506; feature_id=369566 feature_name=fc_user_recent_praise_doc_cate3_1h_tob_profile_match; depend=user_recent_praise_doc_cate3_1h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4505; feature_id=369565 feature_name=fc_user_recent_praise_doc_cate3_30d; depend=user_recent_praise_doc_cate3_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4516; shared=true; feature_id=369576 feature_name=fc_user_recent_praise_doc_cate3_30d_has_match; depend=fc_user_recent_praise_doc_cate3_30d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4518; feature_id=369578 feature_name=fc_user_recent_praise_doc_cate3_30d_tob_profile_match; depend=user_recent_praise_doc_cate3_30d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4517; feature_id=369577 feature_name=fc_user_recent_praise_doc_cate3_6h; depend=user_recent_praise_doc_cate3_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4507; shared=true; feature_id=369567 feature_name=fc_user_recent_praise_doc_cate3_6h_has_match; depend=fc_user_recent_praise_doc_cate3_6h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4509; feature_id=369569 feature_name=fc_user_recent_praise_doc_cate3_6h_tob_profile_match; depend=user_recent_praise_doc_cate3_6h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4508; feature_id=369568 feature_name=fc_user_recent_praise_doc_cate3_7d; depend=user_recent_praise_doc_cate3_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4513; shared=true; feature_id=369573 feature_name=fc_user_recent_praise_doc_cate3_7d_has_match; depend=fc_user_recent_praise_doc_cate3_7d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4515; feature_id=369575 feature_name=fc_user_recent_praise_doc_cate3_7d_tob_profile_match; depend=user_recent_praise_doc_cate3_7d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; 
args=1,3; slot=4514; feature_id=369574 feature_name=fc_user_recent_praise_doc_id_180d; depend=user_recent_praise_doc_id_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4555; shared=true; feature_id=369615 feature_name=fc_user_recent_praise_doc_id_180d_has_match; depend=fc_user_recent_praise_doc_id_180d,f_doc_id; method=HasMatch; feature_version=2; slot=4557; feature_id=369617 feature_name=fc_user_recent_praise_doc_id_180d_tob_profile_match; depend=user_recent_praise_doc_id_180d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4556; feature_id=369616 feature_name=fc_user_recent_praise_doc_id_1d; depend=user_recent_praise_doc_id_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4546; shared=true; feature_id=369606 feature_name=fc_user_recent_praise_doc_id_1d_has_match; depend=fc_user_recent_praise_doc_id_1d,f_doc_id; method=HasMatch; feature_version=2; slot=4548; feature_id=369608 feature_name=fc_user_recent_praise_doc_id_1d_tob_profile_match; depend=user_recent_praise_doc_id_1d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4547; feature_id=369607 feature_name=fc_user_recent_praise_doc_id_1h; depend=user_recent_praise_doc_id_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4540; shared=true; feature_id=369600 feature_name=fc_user_recent_praise_doc_id_1h_has_match; depend=fc_user_recent_praise_doc_id_1h,f_doc_id; method=HasMatch; feature_version=2; slot=4542; feature_id=369602 feature_name=fc_user_recent_praise_doc_id_1h_tob_profile_match; depend=user_recent_praise_doc_id_1h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4541; feature_id=369601 feature_name=fc_user_recent_praise_doc_id_30d; depend=user_recent_praise_doc_id_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4552; shared=true; feature_id=369612 feature_name=fc_user_recent_praise_doc_id_30d_has_match; 
depend=fc_user_recent_praise_doc_id_30d,f_doc_id; method=HasMatch; feature_version=2; slot=4554; feature_id=369614 feature_name=fc_user_recent_praise_doc_id_30d_tob_profile_match; depend=user_recent_praise_doc_id_30d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4553; feature_id=369613 feature_name=fc_user_recent_praise_doc_id_6h; depend=user_recent_praise_doc_id_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4543; shared=true; feature_id=369603 feature_name=fc_user_recent_praise_doc_id_6h_has_match; depend=fc_user_recent_praise_doc_id_6h,f_doc_id; method=HasMatch; feature_version=2; slot=4545; feature_id=369605 feature_name=fc_user_recent_praise_doc_id_6h_tob_profile_match; depend=user_recent_praise_doc_id_6h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4544; feature_id=369604 feature_name=fc_user_recent_praise_doc_id_7d; depend=user_recent_praise_doc_id_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4549; shared=true; feature_id=369609 feature_name=fc_user_recent_praise_doc_id_7d_has_match; depend=fc_user_recent_praise_doc_id_7d,f_doc_id; method=HasMatch; feature_version=2; slot=4551; feature_id=369611 feature_name=fc_user_recent_praise_doc_id_7d_tob_profile_match; depend=user_recent_praise_doc_id_7d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4550; feature_id=369610 feature_name=fc_user_recent_praise_doc_keyword_180d; depend=user_recent_praise_doc_keyword_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4483; shared=true; feature_id=369543 feature_name=fc_user_recent_praise_doc_keyword_180d_has_match; depend=fc_user_recent_praise_doc_keyword_180d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4485; feature_id=369545 feature_name=fc_user_recent_praise_doc_keyword_180d_tob_profile_match; depend=user_recent_praise_doc_keyword_180d,f_doc_keyword; method=TobInstanceProfileMatch; 
feature_version=2; args=1,3; slot=4484; feature_id=369544 feature_name=fc_user_recent_praise_doc_keyword_1d; depend=user_recent_praise_doc_keyword_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4474; shared=true; feature_id=369534 feature_name=fc_user_recent_praise_doc_keyword_1d_has_match; depend=fc_user_recent_praise_doc_keyword_1d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4476; feature_id=369536 feature_name=fc_user_recent_praise_doc_keyword_1d_tob_profile_match; depend=user_recent_praise_doc_keyword_1d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4475; feature_id=369535 feature_name=fc_user_recent_praise_doc_keyword_1h; depend=user_recent_praise_doc_keyword_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4468; shared=true; feature_id=369528 feature_name=fc_user_recent_praise_doc_keyword_1h_has_match; depend=fc_user_recent_praise_doc_keyword_1h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4470; feature_id=369530 feature_name=fc_user_recent_praise_doc_keyword_1h_tob_profile_match; depend=user_recent_praise_doc_keyword_1h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4469; feature_id=369529 feature_name=fc_user_recent_praise_doc_keyword_30d; depend=user_recent_praise_doc_keyword_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4480; shared=true; feature_id=369540 feature_name=fc_user_recent_praise_doc_keyword_30d_has_match; depend=fc_user_recent_praise_doc_keyword_30d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4482; feature_id=369542 feature_name=fc_user_recent_praise_doc_keyword_30d_tob_profile_match; depend=user_recent_praise_doc_keyword_30d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4481; feature_id=369541 feature_name=fc_user_recent_praise_doc_keyword_6h; depend=user_recent_praise_doc_keyword_6h; method=TobInstanceProfilePairList; 
feature_version=2; args=10,2; slot=4471; shared=true; feature_id=369531 feature_name=fc_user_recent_praise_doc_keyword_6h_has_match; depend=fc_user_recent_praise_doc_keyword_6h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4473; feature_id=369533 feature_name=fc_user_recent_praise_doc_keyword_6h_tob_profile_match; depend=user_recent_praise_doc_keyword_6h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4472; feature_id=369532 feature_name=fc_user_recent_praise_doc_keyword_7d; depend=user_recent_praise_doc_keyword_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4477; shared=true; feature_id=369537 feature_name=fc_user_recent_praise_doc_keyword_7d_has_match; depend=fc_user_recent_praise_doc_keyword_7d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4479; feature_id=369539 feature_name=fc_user_recent_praise_doc_keyword_7d_tob_profile_match; depend=user_recent_praise_doc_keyword_7d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4478; feature_id=369538 feature_name=fc_user_recent_praise_doc_tags_180d; depend=user_recent_praise_doc_tags_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4573; shared=true; feature_id=369633 feature_name=fc_user_recent_praise_doc_tags_180d_has_match; depend=fc_user_recent_praise_doc_tags_180d,f_doc_tags; method=HasMatch; feature_version=2; slot=4575; feature_id=369635 feature_name=fc_user_recent_praise_doc_tags_180d_tob_profile_match; depend=user_recent_praise_doc_tags_180d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4574; feature_id=369634 feature_name=fc_user_recent_praise_doc_tags_1d; depend=user_recent_praise_doc_tags_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4564; shared=true; feature_id=369624 feature_name=fc_user_recent_praise_doc_tags_1d_has_match; depend=fc_user_recent_praise_doc_tags_1d,f_doc_tags; method=HasMatch; 
feature_version=2; slot=4566; feature_id=369626 feature_name=fc_user_recent_praise_doc_tags_1d_tob_profile_match; depend=user_recent_praise_doc_tags_1d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4565; feature_id=369625 feature_name=fc_user_recent_praise_doc_tags_1h; depend=user_recent_praise_doc_tags_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4558; shared=true; feature_id=369618 feature_name=fc_user_recent_praise_doc_tags_1h_has_match; depend=fc_user_recent_praise_doc_tags_1h,f_doc_tags; method=HasMatch; feature_version=2; slot=4560; feature_id=369620 feature_name=fc_user_recent_praise_doc_tags_1h_tob_profile_match; depend=user_recent_praise_doc_tags_1h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4559; feature_id=369619 feature_name=fc_user_recent_praise_doc_tags_30d; depend=user_recent_praise_doc_tags_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4570; shared=true; feature_id=369630 feature_name=fc_user_recent_praise_doc_tags_30d_has_match; depend=fc_user_recent_praise_doc_tags_30d,f_doc_tags; method=HasMatch; feature_version=2; slot=4572; feature_id=369632 feature_name=fc_user_recent_praise_doc_tags_30d_tob_profile_match; depend=user_recent_praise_doc_tags_30d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4571; feature_id=369631 feature_name=fc_user_recent_praise_doc_tags_6h; depend=user_recent_praise_doc_tags_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4561; shared=true; feature_id=369621 feature_name=fc_user_recent_praise_doc_tags_6h_has_match; depend=fc_user_recent_praise_doc_tags_6h,f_doc_tags; method=HasMatch; feature_version=2; slot=4563; feature_id=369623 feature_name=fc_user_recent_praise_doc_tags_6h_tob_profile_match; depend=user_recent_praise_doc_tags_6h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4562; feature_id=369622 
feature_name=fc_user_recent_praise_doc_tags_7d; depend=user_recent_praise_doc_tags_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4567; shared=true; feature_id=369627 feature_name=fc_user_recent_praise_doc_tags_7d_has_match; depend=fc_user_recent_praise_doc_tags_7d,f_doc_tags; method=HasMatch; feature_version=2; slot=4569; feature_id=369629 feature_name=fc_user_recent_praise_doc_tags_7d_tob_profile_match; depend=user_recent_praise_doc_tags_7d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4568; feature_id=369628 feature_name=fc_user_recent_praise_doc_topic_tag_180d; depend=user_recent_praise_doc_topic_tag_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4465; shared=true; feature_id=369525 feature_name=fc_user_recent_praise_doc_topic_tag_180d_has_match; depend=fc_user_recent_praise_doc_topic_tag_180d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4467; feature_id=369527 feature_name=fc_user_recent_praise_doc_topic_tag_180d_tob_profile_match; depend=user_recent_praise_doc_topic_tag_180d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4466; feature_id=369526 feature_name=fc_user_recent_praise_doc_topic_tag_1d; depend=user_recent_praise_doc_topic_tag_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4456; shared=true; feature_id=369516 feature_name=fc_user_recent_praise_doc_topic_tag_1d_has_match; depend=fc_user_recent_praise_doc_topic_tag_1d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4458; feature_id=369518 feature_name=fc_user_recent_praise_doc_topic_tag_1d_tob_profile_match; depend=user_recent_praise_doc_topic_tag_1d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4457; feature_id=369517 feature_name=fc_user_recent_praise_doc_topic_tag_1h; depend=user_recent_praise_doc_topic_tag_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4450; 
shared=true; feature_id=369510 feature_name=fc_user_recent_praise_doc_topic_tag_1h_has_match; depend=fc_user_recent_praise_doc_topic_tag_1h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4452; feature_id=369512 feature_name=fc_user_recent_praise_doc_topic_tag_1h_tob_profile_match; depend=user_recent_praise_doc_topic_tag_1h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4451; feature_id=369511 feature_name=fc_user_recent_praise_doc_topic_tag_30d; depend=user_recent_praise_doc_topic_tag_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4462; shared=true; feature_id=369522 feature_name=fc_user_recent_praise_doc_topic_tag_30d_has_match; depend=fc_user_recent_praise_doc_topic_tag_30d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4464; feature_id=369524 feature_name=fc_user_recent_praise_doc_topic_tag_30d_tob_profile_match; depend=user_recent_praise_doc_topic_tag_30d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4463; feature_id=369523 feature_name=fc_user_recent_praise_doc_topic_tag_6h; depend=user_recent_praise_doc_topic_tag_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4453; shared=true; feature_id=369513 feature_name=fc_user_recent_praise_doc_topic_tag_6h_has_match; depend=fc_user_recent_praise_doc_topic_tag_6h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4455; feature_id=369515 feature_name=fc_user_recent_praise_doc_topic_tag_6h_tob_profile_match; depend=user_recent_praise_doc_topic_tag_6h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4454; feature_id=369514 feature_name=fc_user_recent_praise_doc_topic_tag_7d; depend=user_recent_praise_doc_topic_tag_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4459; shared=true; feature_id=369519 feature_name=fc_user_recent_praise_doc_topic_tag_7d_has_match; 
depend=fc_user_recent_praise_doc_topic_tag_7d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4461; feature_id=369521 feature_name=fc_user_recent_praise_doc_topic_tag_7d_tob_profile_match; depend=user_recent_praise_doc_topic_tag_7d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4460; feature_id=369520 feature_name=fc_user_recent_praise_doc_type_180d; depend=user_recent_praise_doc_type_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4447; shared=true; feature_id=369507 feature_name=fc_user_recent_praise_doc_type_180d_has_match; depend=fc_user_recent_praise_doc_type_180d,f_doc_type; method=HasMatch; feature_version=2; slot=4449; feature_id=369509 feature_name=fc_user_recent_praise_doc_type_180d_tob_profile_match; depend=user_recent_praise_doc_type_180d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4448; feature_id=369508 feature_name=fc_user_recent_praise_doc_type_1d; depend=user_recent_praise_doc_type_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4438; shared=true; feature_id=369498 feature_name=fc_user_recent_praise_doc_type_1d_has_match; depend=fc_user_recent_praise_doc_type_1d,f_doc_type; method=HasMatch; feature_version=2; slot=4440; feature_id=369500 feature_name=fc_user_recent_praise_doc_type_1d_tob_profile_match; depend=user_recent_praise_doc_type_1d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4439; feature_id=369499 feature_name=fc_user_recent_praise_doc_type_1h; depend=user_recent_praise_doc_type_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4432; shared=true; feature_id=369492 feature_name=fc_user_recent_praise_doc_type_1h_has_match; depend=fc_user_recent_praise_doc_type_1h,f_doc_type; method=HasMatch; feature_version=2; slot=4434; feature_id=369494 feature_name=fc_user_recent_praise_doc_type_1h_tob_profile_match; depend=user_recent_praise_doc_type_1h,f_doc_type; 
method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4433; feature_id=369493 feature_name=fc_user_recent_praise_doc_type_30d; depend=user_recent_praise_doc_type_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4444; shared=true; feature_id=369504 feature_name=fc_user_recent_praise_doc_type_30d_has_match; depend=fc_user_recent_praise_doc_type_30d,f_doc_type; method=HasMatch; feature_version=2; slot=4446; feature_id=369506 feature_name=fc_user_recent_praise_doc_type_30d_tob_profile_match; depend=user_recent_praise_doc_type_30d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4445; feature_id=369505 feature_name=fc_user_recent_praise_doc_type_6h; depend=user_recent_praise_doc_type_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4435; shared=true; feature_id=369495 feature_name=fc_user_recent_praise_doc_type_6h_has_match; depend=fc_user_recent_praise_doc_type_6h,f_doc_type; method=HasMatch; feature_version=2; slot=4437; feature_id=369497 feature_name=fc_user_recent_praise_doc_type_6h_tob_profile_match; depend=user_recent_praise_doc_type_6h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4436; feature_id=369496 feature_name=fc_user_recent_praise_doc_type_7d; depend=user_recent_praise_doc_type_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4441; shared=true; feature_id=369501 feature_name=fc_user_recent_praise_doc_type_7d_has_match; depend=fc_user_recent_praise_doc_type_7d,f_doc_type; method=HasMatch; feature_version=2; slot=4443; feature_id=369503 feature_name=fc_user_recent_praise_doc_type_7d_tob_profile_match; depend=user_recent_praise_doc_type_7d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4442; feature_id=369502 feature_name=fc_user_recent_share_doc_cate1_180d; depend=user_recent_share_doc_cate1_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4069; 
shared=true; feature_id=369129 feature_name=fc_user_recent_share_doc_cate1_180d_has_match; depend=fc_user_recent_share_doc_cate1_180d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4071; feature_id=369131 feature_name=fc_user_recent_share_doc_cate1_180d_tob_profile_match; depend=user_recent_share_doc_cate1_180d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4070; feature_id=369130 feature_name=fc_user_recent_share_doc_cate1_1d; depend=user_recent_share_doc_cate1_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4060; shared=true; feature_id=369120 feature_name=fc_user_recent_share_doc_cate1_1d_has_match; depend=fc_user_recent_share_doc_cate1_1d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4062; feature_id=369122 feature_name=fc_user_recent_share_doc_cate1_1d_tob_profile_match; depend=user_recent_share_doc_cate1_1d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4061; feature_id=369121 feature_name=fc_user_recent_share_doc_cate1_1h; depend=user_recent_share_doc_cate1_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4054; shared=true; feature_id=369114 feature_name=fc_user_recent_share_doc_cate1_1h_has_match; depend=fc_user_recent_share_doc_cate1_1h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4056; feature_id=369116 feature_name=fc_user_recent_share_doc_cate1_1h_tob_profile_match; depend=user_recent_share_doc_cate1_1h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4055; feature_id=369115 feature_name=fc_user_recent_share_doc_cate1_30d; depend=user_recent_share_doc_cate1_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4066; shared=true; feature_id=369126 feature_name=fc_user_recent_share_doc_cate1_30d_has_match; depend=fc_user_recent_share_doc_cate1_30d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4068; feature_id=369128 
feature_name=fc_user_recent_share_doc_cate1_30d_tob_profile_match; depend=user_recent_share_doc_cate1_30d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4067; feature_id=369127 feature_name=fc_user_recent_share_doc_cate1_6h; depend=user_recent_share_doc_cate1_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4057; shared=true; feature_id=369117 feature_name=fc_user_recent_share_doc_cate1_6h_has_match; depend=fc_user_recent_share_doc_cate1_6h,f_doc_cate1; method=HasMatch; feature_version=2; slot=4059; feature_id=369119 feature_name=fc_user_recent_share_doc_cate1_6h_tob_profile_match; depend=user_recent_share_doc_cate1_6h,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4058; feature_id=369118 feature_name=fc_user_recent_share_doc_cate1_7d; depend=user_recent_share_doc_cate1_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4063; shared=true; feature_id=369123 feature_name=fc_user_recent_share_doc_cate1_7d_has_match; depend=fc_user_recent_share_doc_cate1_7d,f_doc_cate1; method=HasMatch; feature_version=2; slot=4065; feature_id=369125 feature_name=fc_user_recent_share_doc_cate1_7d_tob_profile_match; depend=user_recent_share_doc_cate1_7d,f_doc_cate1; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4064; feature_id=369124 feature_name=fc_user_recent_share_doc_cate2_180d; depend=user_recent_share_doc_cate2_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4105; shared=true; feature_id=369165 feature_name=fc_user_recent_share_doc_cate2_180d_has_match; depend=fc_user_recent_share_doc_cate2_180d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4107; feature_id=369167 feature_name=fc_user_recent_share_doc_cate2_180d_tob_profile_match; depend=user_recent_share_doc_cate2_180d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4106; feature_id=369166 
feature_name=fc_user_recent_share_doc_cate2_1d; depend=user_recent_share_doc_cate2_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4096; shared=true; feature_id=369156 feature_name=fc_user_recent_share_doc_cate2_1d_has_match; depend=fc_user_recent_share_doc_cate2_1d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4098; feature_id=369158 feature_name=fc_user_recent_share_doc_cate2_1d_tob_profile_match; depend=user_recent_share_doc_cate2_1d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4097; feature_id=369157 feature_name=fc_user_recent_share_doc_cate2_1h; depend=user_recent_share_doc_cate2_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4090; shared=true; feature_id=369150 feature_name=fc_user_recent_share_doc_cate2_1h_has_match; depend=fc_user_recent_share_doc_cate2_1h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4092; feature_id=369152 feature_name=fc_user_recent_share_doc_cate2_1h_tob_profile_match; depend=user_recent_share_doc_cate2_1h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4091; feature_id=369151 feature_name=fc_user_recent_share_doc_cate2_30d; depend=user_recent_share_doc_cate2_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4102; shared=true; feature_id=369162 feature_name=fc_user_recent_share_doc_cate2_30d_has_match; depend=fc_user_recent_share_doc_cate2_30d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4104; feature_id=369164 feature_name=fc_user_recent_share_doc_cate2_30d_tob_profile_match; depend=user_recent_share_doc_cate2_30d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4103; feature_id=369163 feature_name=fc_user_recent_share_doc_cate2_6h; depend=user_recent_share_doc_cate2_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4093; shared=true; feature_id=369153 feature_name=fc_user_recent_share_doc_cate2_6h_has_match; 
depend=fc_user_recent_share_doc_cate2_6h,f_doc_cate2; method=HasMatch; feature_version=2; slot=4095; feature_id=369155 feature_name=fc_user_recent_share_doc_cate2_6h_tob_profile_match; depend=user_recent_share_doc_cate2_6h,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4094; feature_id=369154 feature_name=fc_user_recent_share_doc_cate2_7d; depend=user_recent_share_doc_cate2_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4099; shared=true; feature_id=369159 feature_name=fc_user_recent_share_doc_cate2_7d_has_match; depend=fc_user_recent_share_doc_cate2_7d,f_doc_cate2; method=HasMatch; feature_version=2; slot=4101; feature_id=369161 feature_name=fc_user_recent_share_doc_cate2_7d_tob_profile_match; depend=user_recent_share_doc_cate2_7d,f_doc_cate2; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4100; feature_id=369160 feature_name=fc_user_recent_share_doc_cate3_180d; depend=user_recent_share_doc_cate3_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4087; shared=true; feature_id=369147 feature_name=fc_user_recent_share_doc_cate3_180d_has_match; depend=fc_user_recent_share_doc_cate3_180d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4089; feature_id=369149 feature_name=fc_user_recent_share_doc_cate3_180d_tob_profile_match; depend=user_recent_share_doc_cate3_180d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4088; feature_id=369148 feature_name=fc_user_recent_share_doc_cate3_1d; depend=user_recent_share_doc_cate3_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4078; shared=true; feature_id=369138 feature_name=fc_user_recent_share_doc_cate3_1d_has_match; depend=fc_user_recent_share_doc_cate3_1d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4080; feature_id=369140 feature_name=fc_user_recent_share_doc_cate3_1d_tob_profile_match; depend=user_recent_share_doc_cate3_1d,f_doc_cate3; 
method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4079; feature_id=369139 feature_name=fc_user_recent_share_doc_cate3_1h; depend=user_recent_share_doc_cate3_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4072; shared=true; feature_id=369132 feature_name=fc_user_recent_share_doc_cate3_1h_has_match; depend=fc_user_recent_share_doc_cate3_1h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4074; feature_id=369134 feature_name=fc_user_recent_share_doc_cate3_1h_tob_profile_match; depend=user_recent_share_doc_cate3_1h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4073; feature_id=369133 feature_name=fc_user_recent_share_doc_cate3_30d; depend=user_recent_share_doc_cate3_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4084; shared=true; feature_id=369144 feature_name=fc_user_recent_share_doc_cate3_30d_has_match; depend=fc_user_recent_share_doc_cate3_30d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4086; feature_id=369146 feature_name=fc_user_recent_share_doc_cate3_30d_tob_profile_match; depend=user_recent_share_doc_cate3_30d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4085; feature_id=369145 feature_name=fc_user_recent_share_doc_cate3_6h; depend=user_recent_share_doc_cate3_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4075; shared=true; feature_id=369135 feature_name=fc_user_recent_share_doc_cate3_6h_has_match; depend=fc_user_recent_share_doc_cate3_6h,f_doc_cate3; method=HasMatch; feature_version=2; slot=4077; feature_id=369137 feature_name=fc_user_recent_share_doc_cate3_6h_tob_profile_match; depend=user_recent_share_doc_cate3_6h,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4076; feature_id=369136 feature_name=fc_user_recent_share_doc_cate3_7d; depend=user_recent_share_doc_cate3_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4081; 
shared=true; feature_id=369141 feature_name=fc_user_recent_share_doc_cate3_7d_has_match; depend=fc_user_recent_share_doc_cate3_7d,f_doc_cate3; method=HasMatch; feature_version=2; slot=4083; feature_id=369143 feature_name=fc_user_recent_share_doc_cate3_7d_tob_profile_match; depend=user_recent_share_doc_cate3_7d,f_doc_cate3; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4082; feature_id=369142 feature_name=fc_user_recent_share_doc_id_180d; depend=user_recent_share_doc_id_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4123; shared=true; feature_id=369183 feature_name=fc_user_recent_share_doc_id_180d_has_match; depend=fc_user_recent_share_doc_id_180d,f_doc_id; method=HasMatch; feature_version=2; slot=4125; feature_id=369185 feature_name=fc_user_recent_share_doc_id_180d_tob_profile_match; depend=user_recent_share_doc_id_180d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4124; feature_id=369184 feature_name=fc_user_recent_share_doc_id_1d; depend=user_recent_share_doc_id_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4114; shared=true; feature_id=369174 feature_name=fc_user_recent_share_doc_id_1d_has_match; depend=fc_user_recent_share_doc_id_1d,f_doc_id; method=HasMatch; feature_version=2; slot=4116; feature_id=369176 feature_name=fc_user_recent_share_doc_id_1d_tob_profile_match; depend=user_recent_share_doc_id_1d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4115; feature_id=369175 feature_name=fc_user_recent_share_doc_id_1h; depend=user_recent_share_doc_id_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4108; shared=true; feature_id=369168 feature_name=fc_user_recent_share_doc_id_1h_has_match; depend=fc_user_recent_share_doc_id_1h,f_doc_id; method=HasMatch; feature_version=2; slot=4110; feature_id=369170 feature_name=fc_user_recent_share_doc_id_1h_tob_profile_match; depend=user_recent_share_doc_id_1h,f_doc_id; 
method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4109; feature_id=369169 feature_name=fc_user_recent_share_doc_id_30d; depend=user_recent_share_doc_id_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4120; shared=true; feature_id=369180 feature_name=fc_user_recent_share_doc_id_30d_has_match; depend=fc_user_recent_share_doc_id_30d,f_doc_id; method=HasMatch; feature_version=2; slot=4122; feature_id=369182 feature_name=fc_user_recent_share_doc_id_30d_tob_profile_match; depend=user_recent_share_doc_id_30d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4121; feature_id=369181 feature_name=fc_user_recent_share_doc_id_6h; depend=user_recent_share_doc_id_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4111; shared=true; feature_id=369171 feature_name=fc_user_recent_share_doc_id_6h_has_match; depend=fc_user_recent_share_doc_id_6h,f_doc_id; method=HasMatch; feature_version=2; slot=4113; feature_id=369173 feature_name=fc_user_recent_share_doc_id_6h_tob_profile_match; depend=user_recent_share_doc_id_6h,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4112; feature_id=369172 feature_name=fc_user_recent_share_doc_id_7d; depend=user_recent_share_doc_id_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4117; shared=true; feature_id=369177 feature_name=fc_user_recent_share_doc_id_7d_has_match; depend=fc_user_recent_share_doc_id_7d,f_doc_id; method=HasMatch; feature_version=2; slot=4119; feature_id=369179 feature_name=fc_user_recent_share_doc_id_7d_tob_profile_match; depend=user_recent_share_doc_id_7d,f_doc_id; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4118; feature_id=369178 feature_name=fc_user_recent_share_doc_keyword_180d; depend=user_recent_share_doc_keyword_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4051; shared=true; feature_id=369111 
feature_name=fc_user_recent_share_doc_keyword_180d_has_match; depend=fc_user_recent_share_doc_keyword_180d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4053; feature_id=369113 feature_name=fc_user_recent_share_doc_keyword_180d_tob_profile_match; depend=user_recent_share_doc_keyword_180d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4052; feature_id=369112 feature_name=fc_user_recent_share_doc_keyword_1d; depend=user_recent_share_doc_keyword_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4042; shared=true; feature_id=369102 feature_name=fc_user_recent_share_doc_keyword_1d_has_match; depend=fc_user_recent_share_doc_keyword_1d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4044; feature_id=369104 feature_name=fc_user_recent_share_doc_keyword_1d_tob_profile_match; depend=user_recent_share_doc_keyword_1d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4043; feature_id=369103 feature_name=fc_user_recent_share_doc_keyword_1h; depend=user_recent_share_doc_keyword_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4036; shared=true; feature_id=369096 feature_name=fc_user_recent_share_doc_keyword_1h_has_match; depend=fc_user_recent_share_doc_keyword_1h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4038; feature_id=369098 feature_name=fc_user_recent_share_doc_keyword_1h_tob_profile_match; depend=user_recent_share_doc_keyword_1h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4037; feature_id=369097 feature_name=fc_user_recent_share_doc_keyword_30d; depend=user_recent_share_doc_keyword_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4048; shared=true; feature_id=369108 feature_name=fc_user_recent_share_doc_keyword_30d_has_match; depend=fc_user_recent_share_doc_keyword_30d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4050; feature_id=369110 
feature_name=fc_user_recent_share_doc_keyword_30d_tob_profile_match; depend=user_recent_share_doc_keyword_30d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4049; feature_id=369109 feature_name=fc_user_recent_share_doc_keyword_6h; depend=user_recent_share_doc_keyword_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4039; shared=true; feature_id=369099 feature_name=fc_user_recent_share_doc_keyword_6h_has_match; depend=fc_user_recent_share_doc_keyword_6h,f_doc_keyword; method=HasMatch; feature_version=2; slot=4041; feature_id=369101 feature_name=fc_user_recent_share_doc_keyword_6h_tob_profile_match; depend=user_recent_share_doc_keyword_6h,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4040; feature_id=369100 feature_name=fc_user_recent_share_doc_keyword_7d; depend=user_recent_share_doc_keyword_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4045; shared=true; feature_id=369105 feature_name=fc_user_recent_share_doc_keyword_7d_has_match; depend=fc_user_recent_share_doc_keyword_7d,f_doc_keyword; method=HasMatch; feature_version=2; slot=4047; feature_id=369107 feature_name=fc_user_recent_share_doc_keyword_7d_tob_profile_match; depend=user_recent_share_doc_keyword_7d,f_doc_keyword; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4046; feature_id=369106 feature_name=fc_user_recent_share_doc_tags_180d; depend=user_recent_share_doc_tags_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4141; shared=true; feature_id=369201 feature_name=fc_user_recent_share_doc_tags_180d_has_match; depend=fc_user_recent_share_doc_tags_180d,f_doc_tags; method=HasMatch; feature_version=2; slot=4143; feature_id=369203 feature_name=fc_user_recent_share_doc_tags_180d_tob_profile_match; depend=user_recent_share_doc_tags_180d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4142; feature_id=369202 
feature_name=fc_user_recent_share_doc_tags_1d; depend=user_recent_share_doc_tags_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4132; shared=true; feature_id=369192 feature_name=fc_user_recent_share_doc_tags_1d_has_match; depend=fc_user_recent_share_doc_tags_1d,f_doc_tags; method=HasMatch; feature_version=2; slot=4134; feature_id=369194 feature_name=fc_user_recent_share_doc_tags_1d_tob_profile_match; depend=user_recent_share_doc_tags_1d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4133; feature_id=369193 feature_name=fc_user_recent_share_doc_tags_1h; depend=user_recent_share_doc_tags_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4126; shared=true; feature_id=369186 feature_name=fc_user_recent_share_doc_tags_1h_has_match; depend=fc_user_recent_share_doc_tags_1h,f_doc_tags; method=HasMatch; feature_version=2; slot=4128; feature_id=369188 feature_name=fc_user_recent_share_doc_tags_1h_tob_profile_match; depend=user_recent_share_doc_tags_1h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4127; feature_id=369187 feature_name=fc_user_recent_share_doc_tags_30d; depend=user_recent_share_doc_tags_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4138; shared=true; feature_id=369198 feature_name=fc_user_recent_share_doc_tags_30d_has_match; depend=fc_user_recent_share_doc_tags_30d,f_doc_tags; method=HasMatch; feature_version=2; slot=4140; feature_id=369200 feature_name=fc_user_recent_share_doc_tags_30d_tob_profile_match; depend=user_recent_share_doc_tags_30d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4139; feature_id=369199 feature_name=fc_user_recent_share_doc_tags_6h; depend=user_recent_share_doc_tags_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4129; shared=true; feature_id=369189 feature_name=fc_user_recent_share_doc_tags_6h_has_match; 
depend=fc_user_recent_share_doc_tags_6h,f_doc_tags; method=HasMatch; feature_version=2; slot=4131; feature_id=369191 feature_name=fc_user_recent_share_doc_tags_6h_tob_profile_match; depend=user_recent_share_doc_tags_6h,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4130; feature_id=369190 feature_name=fc_user_recent_share_doc_tags_7d; depend=user_recent_share_doc_tags_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4135; shared=true; feature_id=369195 feature_name=fc_user_recent_share_doc_tags_7d_has_match; depend=fc_user_recent_share_doc_tags_7d,f_doc_tags; method=HasMatch; feature_version=2; slot=4137; feature_id=369197 feature_name=fc_user_recent_share_doc_tags_7d_tob_profile_match; depend=user_recent_share_doc_tags_7d,f_doc_tags; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4136; feature_id=369196 feature_name=fc_user_recent_share_doc_topic_tag_180d; depend=user_recent_share_doc_topic_tag_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4033; shared=true; feature_id=369093 feature_name=fc_user_recent_share_doc_topic_tag_180d_has_match; depend=fc_user_recent_share_doc_topic_tag_180d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4035; feature_id=369095 feature_name=fc_user_recent_share_doc_topic_tag_180d_tob_profile_match; depend=user_recent_share_doc_topic_tag_180d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4034; feature_id=369094 feature_name=fc_user_recent_share_doc_topic_tag_1d; depend=user_recent_share_doc_topic_tag_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4024; shared=true; feature_id=369084 feature_name=fc_user_recent_share_doc_topic_tag_1d_has_match; depend=fc_user_recent_share_doc_topic_tag_1d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4026; feature_id=369086 feature_name=fc_user_recent_share_doc_topic_tag_1d_tob_profile_match; 
depend=user_recent_share_doc_topic_tag_1d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4025; feature_id=369085 feature_name=fc_user_recent_share_doc_topic_tag_1h; depend=user_recent_share_doc_topic_tag_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4018; shared=true; feature_id=369078 feature_name=fc_user_recent_share_doc_topic_tag_1h_has_match; depend=fc_user_recent_share_doc_topic_tag_1h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4020; feature_id=369080 feature_name=fc_user_recent_share_doc_topic_tag_1h_tob_profile_match; depend=user_recent_share_doc_topic_tag_1h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4019; feature_id=369079 feature_name=fc_user_recent_share_doc_topic_tag_30d; depend=user_recent_share_doc_topic_tag_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4030; shared=true; feature_id=369090 feature_name=fc_user_recent_share_doc_topic_tag_30d_has_match; depend=fc_user_recent_share_doc_topic_tag_30d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4032; feature_id=369092 feature_name=fc_user_recent_share_doc_topic_tag_30d_tob_profile_match; depend=user_recent_share_doc_topic_tag_30d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4031; feature_id=369091 feature_name=fc_user_recent_share_doc_topic_tag_6h; depend=user_recent_share_doc_topic_tag_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4021; shared=true; feature_id=369081 feature_name=fc_user_recent_share_doc_topic_tag_6h_has_match; depend=fc_user_recent_share_doc_topic_tag_6h,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4023; feature_id=369083 feature_name=fc_user_recent_share_doc_topic_tag_6h_tob_profile_match; depend=user_recent_share_doc_topic_tag_6h,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4022; feature_id=369082 
feature_name=fc_user_recent_share_doc_topic_tag_7d; depend=user_recent_share_doc_topic_tag_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4027; shared=true; feature_id=369087 feature_name=fc_user_recent_share_doc_topic_tag_7d_has_match; depend=fc_user_recent_share_doc_topic_tag_7d,f_doc_topic_tag; method=HasMatch; feature_version=2; slot=4029; feature_id=369089 feature_name=fc_user_recent_share_doc_topic_tag_7d_tob_profile_match; depend=user_recent_share_doc_topic_tag_7d,f_doc_topic_tag; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4028; feature_id=369088 feature_name=fc_user_recent_share_doc_type_180d; depend=user_recent_share_doc_type_180d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4015; shared=true; feature_id=369075 feature_name=fc_user_recent_share_doc_type_180d_has_match; depend=fc_user_recent_share_doc_type_180d,f_doc_type; method=HasMatch; feature_version=2; slot=4017; feature_id=369077 feature_name=fc_user_recent_share_doc_type_180d_tob_profile_match; depend=user_recent_share_doc_type_180d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4016; feature_id=369076 feature_name=fc_user_recent_share_doc_type_1d; depend=user_recent_share_doc_type_1d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4006; shared=true; feature_id=369066 feature_name=fc_user_recent_share_doc_type_1d_has_match; depend=fc_user_recent_share_doc_type_1d,f_doc_type; method=HasMatch; feature_version=2; slot=4008; feature_id=369068 feature_name=fc_user_recent_share_doc_type_1d_tob_profile_match; depend=user_recent_share_doc_type_1d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4007; feature_id=369067 feature_name=fc_user_recent_share_doc_type_1h; depend=user_recent_share_doc_type_1h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4000; shared=true; feature_id=369060 
feature_name=fc_user_recent_share_doc_type_1h_has_match; depend=fc_user_recent_share_doc_type_1h,f_doc_type; method=HasMatch; feature_version=2; slot=4002; feature_id=369062 feature_name=fc_user_recent_share_doc_type_1h_tob_profile_match; depend=user_recent_share_doc_type_1h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4001; feature_id=369061 feature_name=fc_user_recent_share_doc_type_30d; depend=user_recent_share_doc_type_30d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4012; shared=true; feature_id=369072 feature_name=fc_user_recent_share_doc_type_30d_has_match; depend=fc_user_recent_share_doc_type_30d,f_doc_type; method=HasMatch; feature_version=2; slot=4014; feature_id=369074 feature_name=fc_user_recent_share_doc_type_30d_tob_profile_match; depend=user_recent_share_doc_type_30d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4013; feature_id=369073 feature_name=fc_user_recent_share_doc_type_6h; depend=user_recent_share_doc_type_6h; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4003; shared=true; feature_id=369063 feature_name=fc_user_recent_share_doc_type_6h_has_match; depend=fc_user_recent_share_doc_type_6h,f_doc_type; method=HasMatch; feature_version=2; slot=4005; feature_id=369065 feature_name=fc_user_recent_share_doc_type_6h_tob_profile_match; depend=user_recent_share_doc_type_6h,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4004; feature_id=369064 feature_name=fc_user_recent_share_doc_type_7d; depend=user_recent_share_doc_type_7d; method=TobInstanceProfilePairList; feature_version=2; args=10,2; slot=4009; shared=true; feature_id=369069 feature_name=fc_user_recent_share_doc_type_7d_has_match; depend=fc_user_recent_share_doc_type_7d,f_doc_type; method=HasMatch; feature_version=2; slot=4011; feature_id=369071 feature_name=fc_user_recent_share_doc_type_7d_tob_profile_match; 
depend=user_recent_share_doc_type_7d,f_doc_type; method=TobInstanceProfileMatch; feature_version=2; args=1,3; slot=4010; feature_id=369070 feature_name=fc_user_st_1d_doc_author_id_cart_cp; depend=user_st_1d_doc_author_id_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1359; shared=true; feature_id=368931 feature_name=fc_user_st_1d_doc_author_id_cart_recent; depend=user_st_1d_doc_author_id_cart_recent; method=VectorTopString; feature_version=2; args=30; slot=2098; shared=true; feature_id=369038 feature_name=fc_user_st_1d_doc_author_id_click_cp; depend=user_st_1d_doc_author_id_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1220; shared=true; feature_id=368840 feature_name=fc_user_st_1d_doc_author_id_click_recent; depend=user_st_1d_doc_author_id_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2020; shared=true; feature_id=368984 feature_name=fc_user_st_1d_doc_author_id_conversion_cp; depend=user_st_1d_doc_author_id_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1383; shared=true; feature_id=368955 feature_name=fc_user_st_1d_doc_author_id_conversion_recent; depend=user_st_1d_doc_author_id_conversion_recent; method=VectorTopString; feature_version=2; args=30; slot=2114; shared=true; feature_id=369054 feature_name=fc_user_st_1d_doc_author_id_favorite_cp; depend=user_st_1d_doc_author_id_favorite_cp; method=VectorTopString; feature_version=2; args=10; slot=1221; shared=true; feature_id=368841 feature_name=fc_user_st_1d_doc_author_id_favorite_recent; depend=user_st_1d_doc_author_id_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2021; shared=true; feature_id=368985 feature_name=fc_user_st_1d_doc_author_id_praise_cp; depend=user_st_1d_doc_author_id_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1222; shared=true; feature_id=368842 feature_name=fc_user_st_1d_doc_author_id_praise_recent; depend=user_st_1d_doc_author_id_praise_recent; 
method=VectorTopString; feature_version=2; args=30; slot=2022; shared=true; feature_id=368986 feature_name=fc_user_st_1d_doc_author_id_query_cp; depend=user_st_1d_doc_author_id_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1223; shared=true; feature_id=368843 feature_name=fc_user_st_1d_doc_author_id_query_recent; depend=user_st_1d_doc_author_id_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2023; shared=true; feature_id=368987 feature_name=fc_user_st_1d_doc_cate1_cart_cp; depend=user_st_1d_doc_cate1_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1350; shared=true; feature_id=368922 feature_name=fc_user_st_1d_doc_cate1_cart_recent; depend=user_st_1d_doc_cate1_cart_recent; method=VectorTopString; feature_version=2; args=30; slot=2092; shared=true; feature_id=369032 feature_name=fc_user_st_1d_doc_cate1_click_cp; depend=user_st_1d_doc_cate1_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1208; shared=true; feature_id=368828 feature_name=fc_user_st_1d_doc_cate1_click_recent; depend=user_st_1d_doc_cate1_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2008; shared=true; feature_id=368972 feature_name=fc_user_st_1d_doc_cate1_conversion_cp; depend=user_st_1d_doc_cate1_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1374; shared=true; feature_id=368946 feature_name=fc_user_st_1d_doc_cate1_conversion_recent; depend=user_st_1d_doc_cate1_conversion_recent; method=VectorTopString; feature_version=2; args=30; slot=2108; shared=true; feature_id=369048 feature_name=fc_user_st_1d_doc_cate1_favorite_cp; depend=user_st_1d_doc_cate1_favorite_cp; method=VectorTopString; feature_version=2; args=10; slot=1209; shared=true; feature_id=368829 feature_name=fc_user_st_1d_doc_cate1_favorite_recent; depend=user_st_1d_doc_cate1_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2009; shared=true; feature_id=368973 
feature_name=fc_user_st_1d_doc_cate1_praise_cp; depend=user_st_1d_doc_cate1_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1210; shared=true; feature_id=368830 feature_name=fc_user_st_1d_doc_cate1_praise_recent; depend=user_st_1d_doc_cate1_praise_recent; method=VectorTopString; feature_version=2; args=30; slot=2010; shared=true; feature_id=368974 feature_name=fc_user_st_1d_doc_cate1_query_cp; depend=user_st_1d_doc_cate1_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1211; shared=true; feature_id=368831 feature_name=fc_user_st_1d_doc_cate1_query_recent; depend=user_st_1d_doc_cate1_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2011; shared=true; feature_id=368975 feature_name=fc_user_st_1d_doc_cate2_cart_cp; depend=user_st_1d_doc_cate2_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1353; shared=true; feature_id=368925 feature_name=fc_user_st_1d_doc_cate2_cart_recent; depend=user_st_1d_doc_cate2_cart_recent; method=VectorTopString; feature_version=2; args=30; slot=2094; shared=true; feature_id=369034 feature_name=fc_user_st_1d_doc_cate2_click_cp; depend=user_st_1d_doc_cate2_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1212; shared=true; feature_id=368832 feature_name=fc_user_st_1d_doc_cate2_click_recent; depend=user_st_1d_doc_cate2_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2012; shared=true; feature_id=368976 feature_name=fc_user_st_1d_doc_cate2_conversion_cp; depend=user_st_1d_doc_cate2_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1377; shared=true; feature_id=368949 feature_name=fc_user_st_1d_doc_cate2_conversion_recent; depend=user_st_1d_doc_cate2_conversion_recent; method=VectorTopString; feature_version=2; args=30; slot=2110; shared=true; feature_id=369050 feature_name=fc_user_st_1d_doc_cate2_favorite_cp; depend=user_st_1d_doc_cate2_favorite_cp; method=VectorTopString; feature_version=2; args=10; 
slot=1213; shared=true; feature_id=368833 feature_name=fc_user_st_1d_doc_cate2_favorite_recent; depend=user_st_1d_doc_cate2_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2013; shared=true; feature_id=368977 feature_name=fc_user_st_1d_doc_cate2_praise_cp; depend=user_st_1d_doc_cate2_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1214; shared=true; feature_id=368834 feature_name=fc_user_st_1d_doc_cate2_praise_recent; depend=user_st_1d_doc_cate2_praise_recent; method=VectorTopString; feature_version=2; args=30; slot=2014; shared=true; feature_id=368978 feature_name=fc_user_st_1d_doc_cate2_query_cp; depend=user_st_1d_doc_cate2_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1215; shared=true; feature_id=368835 feature_name=fc_user_st_1d_doc_cate2_query_recent; depend=user_st_1d_doc_cate2_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2015; shared=true; feature_id=368979 feature_name=fc_user_st_1d_doc_cate3_cart_cp; depend=user_st_1d_doc_cate3_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1356; shared=true; feature_id=368928 feature_name=fc_user_st_1d_doc_cate3_cart_recent; depend=user_st_1d_doc_cate3_cart_recent; method=VectorTopString; feature_version=2; args=30; slot=2096; shared=true; feature_id=369036 feature_name=fc_user_st_1d_doc_cate3_click_cp; depend=user_st_1d_doc_cate3_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1216; shared=true; feature_id=368836 feature_name=fc_user_st_1d_doc_cate3_click_recent; depend=user_st_1d_doc_cate3_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2016; shared=true; feature_id=368980 feature_name=fc_user_st_1d_doc_cate3_conversion_cp; depend=user_st_1d_doc_cate3_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1380; shared=true; feature_id=368952 feature_name=fc_user_st_1d_doc_cate3_conversion_recent; depend=user_st_1d_doc_cate3_conversion_recent; 
method=VectorTopString; feature_version=2; args=30; slot=2112; shared=true; feature_id=369052 feature_name=fc_user_st_1d_doc_cate3_favorite_cp; depend=user_st_1d_doc_cate3_favorite_cp; method=VectorTopString; feature_version=2; args=10; slot=1217; shared=true; feature_id=368837 feature_name=fc_user_st_1d_doc_cate3_favorite_recent; depend=user_st_1d_doc_cate3_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2017; shared=true; feature_id=368981 feature_name=fc_user_st_1d_doc_cate3_praise_cp; depend=user_st_1d_doc_cate3_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1218; shared=true; feature_id=368838 feature_name=fc_user_st_1d_doc_cate3_praise_recent; depend=user_st_1d_doc_cate3_praise_recent; method=VectorTopString; feature_version=2; args=30; slot=2018; shared=true; feature_id=368982 feature_name=fc_user_st_1d_doc_cate3_query_cp; depend=user_st_1d_doc_cate3_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1219; shared=true; feature_id=368839 feature_name=fc_user_st_1d_doc_cate3_query_recent; depend=user_st_1d_doc_cate3_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2019; shared=true; feature_id=368983 feature_name=fc_user_st_1d_doc_id_cart_cp; depend=user_st_1d_doc_id_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1344; shared=true; feature_id=368916 feature_name=fc_user_st_1d_doc_id_cart_recent; depend=user_st_1d_doc_id_cart_recent; method=VectorTopString; feature_version=2; args=30; slot=2088; shared=true; feature_id=369028 feature_name=fc_user_st_1d_doc_id_click_cp; depend=user_st_1d_doc_id_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1200; shared=true; feature_id=368820 feature_name=fc_user_st_1d_doc_id_click_recent; depend=user_st_1d_doc_id_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2000; shared=true; feature_id=368964 feature_name=fc_user_st_1d_doc_id_conversion_cp; 
depend=user_st_1d_doc_id_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1368; shared=true; feature_id=368940 feature_name=fc_user_st_1d_doc_id_conversion_recent; depend=user_st_1d_doc_id_conversion_recent; method=VectorTopString; feature_version=2; args=30; slot=2104; shared=true; feature_id=369044 feature_name=fc_user_st_1d_doc_id_favorite_cp; depend=user_st_1d_doc_id_favorite_cp; method=VectorTopString; feature_version=2; args=10; slot=1201; shared=true; feature_id=368821 feature_name=fc_user_st_1d_doc_id_favorite_recent; depend=user_st_1d_doc_id_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2001; shared=true; feature_id=368965 feature_name=fc_user_st_1d_doc_id_praise_cp; depend=user_st_1d_doc_id_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1202; shared=true; feature_id=368822 feature_name=fc_user_st_1d_doc_id_praise_recent; depend=user_st_1d_doc_id_praise_recent; method=VectorTopString; feature_version=2; args=30; slot=2002; shared=true; feature_id=368966 feature_name=fc_user_st_1d_doc_id_query_cp; depend=user_st_1d_doc_id_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1203; shared=true; feature_id=368823 feature_name=fc_user_st_1d_doc_id_query_recent; depend=user_st_1d_doc_id_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2003; shared=true; feature_id=368967 feature_name=fc_user_st_1d_doc_keyword_cart_cp; depend=user_st_1d_doc_keyword_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1365; shared=true; feature_id=368937 feature_name=fc_user_st_1d_doc_keyword_cart_recent; depend=user_st_1d_doc_keyword_cart_recent; method=VectorTopString; feature_version=2; args=30; slot=2102; shared=true; feature_id=369042 feature_name=fc_user_st_1d_doc_keyword_click_cp; depend=user_st_1d_doc_keyword_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1228; shared=true; feature_id=368848 
feature_name=fc_user_st_1d_doc_keyword_click_recent; depend=user_st_1d_doc_keyword_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2028; shared=true; feature_id=368992 feature_name=fc_user_st_1d_doc_keyword_conversion_cp; depend=user_st_1d_doc_keyword_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1389; shared=true; feature_id=368961 feature_name=fc_user_st_1d_doc_keyword_conversion_recent; depend=user_st_1d_doc_keyword_conversion_recent; method=VectorTopString; feature_version=2; args=30; slot=2118; shared=true; feature_id=369058 feature_name=fc_user_st_1d_doc_keyword_favorite_cp; depend=user_st_1d_doc_keyword_favorite_cp; method=VectorTopString; feature_version=2; args=10; slot=1229; shared=true; feature_id=368849 feature_name=fc_user_st_1d_doc_keyword_favorite_recent; depend=user_st_1d_doc_keyword_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2029; shared=true; feature_id=368993 feature_name=fc_user_st_1d_doc_keyword_praise_cp; depend=user_st_1d_doc_keyword_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1230; shared=true; feature_id=368850 feature_name=fc_user_st_1d_doc_keyword_praise_recent; depend=user_st_1d_doc_keyword_praise_recent; method=VectorTopString; feature_version=2; args=30; slot=2030; shared=true; feature_id=368994 feature_name=fc_user_st_1d_doc_keyword_query_cp; depend=user_st_1d_doc_keyword_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1231; shared=true; feature_id=368851 feature_name=fc_user_st_1d_doc_keyword_query_recent; depend=user_st_1d_doc_keyword_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2031; shared=true; feature_id=368995 feature_name=fc_user_st_1d_doc_tags_cart_cp; depend=user_st_1d_doc_tags_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1362; shared=true; feature_id=368934 feature_name=fc_user_st_1d_doc_tags_cart_recent; depend=user_st_1d_doc_tags_cart_recent; 
method=VectorTopString; feature_version=2; args=30; slot=2100; shared=true; feature_id=369040 feature_name=fc_user_st_1d_doc_tags_click_cp; depend=user_st_1d_doc_tags_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1224; shared=true; feature_id=368844 feature_name=fc_user_st_1d_doc_tags_click_recent; depend=user_st_1d_doc_tags_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2024; shared=true; feature_id=368988 feature_name=fc_user_st_1d_doc_tags_conversion_cp; depend=user_st_1d_doc_tags_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1386; shared=true; feature_id=368958 feature_name=fc_user_st_1d_doc_tags_conversion_recent; depend=user_st_1d_doc_tags_conversion_recent; method=VectorTopString; feature_version=2; args=30; slot=2116; shared=true; feature_id=369056 feature_name=fc_user_st_1d_doc_tags_favorite_cp; depend=user_st_1d_doc_tags_favorite_cp; method=VectorTopString; feature_version=2; args=10; slot=1225; shared=true; feature_id=368845 feature_name=fc_user_st_1d_doc_tags_favorite_recent; depend=user_st_1d_doc_tags_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2025; shared=true; feature_id=368989 feature_name=fc_user_st_1d_doc_tags_praise_cp; depend=user_st_1d_doc_tags_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1226; shared=true; feature_id=368846 feature_name=fc_user_st_1d_doc_tags_praise_recent; depend=user_st_1d_doc_tags_praise_recent; method=VectorTopString; feature_version=2; args=30; slot=2026; shared=true; feature_id=368990 feature_name=fc_user_st_1d_doc_tags_query_cp; depend=user_st_1d_doc_tags_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1227; shared=true; feature_id=368847 feature_name=fc_user_st_1d_doc_tags_query_recent; depend=user_st_1d_doc_tags_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2027; shared=true; feature_id=368991 feature_name=fc_user_st_1d_doc_title_terms_cart_cp; 
depend=user_st_1d_doc_title_terms_cart_cp; method=VectorTopString; feature_version=2; args=10; slot=1347; shared=true; feature_id=368919 feature_name=fc_user_st_1d_doc_title_terms_cart_recent; depend=user_st_1d_doc_title_terms_cart_recent; method=VectorTopString; feature_version=2; args=30; slot=2090; shared=true; feature_id=369030 feature_name=fc_user_st_1d_doc_title_terms_click_cp; depend=user_st_1d_doc_title_terms_click_cp; method=VectorTopString; feature_version=2; args=10; slot=1204; shared=true; feature_id=368824 feature_name=fc_user_st_1d_doc_title_terms_click_recent; depend=user_st_1d_doc_title_terms_click_recent; method=VectorTopString; feature_version=2; args=30; slot=2004; shared=true; feature_id=368968 feature_name=fc_user_st_1d_doc_title_terms_conversion_cp; depend=user_st_1d_doc_title_terms_conversion_cp; method=VectorTopString; feature_version=2; args=10; slot=1371; shared=true; feature_id=368943 feature_name=fc_user_st_1d_doc_title_terms_conversion_recent; depend=user_st_1d_doc_title_terms_conversion_recent; method=VectorTopString; feature_version=2; args=30; slot=2106; shared=true; feature_id=369046 feature_name=fc_user_st_1d_doc_title_terms_favorite_cp; depend=user_st_1d_doc_title_terms_favorite_cp; method=VectorTopString; feature_version=2; args=10; slot=1205; shared=true; feature_id=368825 feature_name=fc_user_st_1d_doc_title_terms_favorite_recent; depend=user_st_1d_doc_title_terms_favorite_recent; method=VectorTopString; feature_version=2; args=30; slot=2005; shared=true; feature_id=368969 feature_name=fc_user_st_1d_doc_title_terms_praise_cp; depend=user_st_1d_doc_title_terms_praise_cp; method=VectorTopString; feature_version=2; args=10; slot=1206; shared=true; feature_id=368826 feature_name=fc_user_st_1d_doc_title_terms_praise_recent; depend=user_st_1d_doc_title_terms_praise_recent; method=VectorTopString; feature_version=2; args=30; slot=2006; shared=true; feature_id=368970 feature_name=fc_user_st_1d_doc_title_terms_query_cp; 
depend=user_st_1d_doc_title_terms_query_cp; method=VectorTopString; feature_version=2; args=10; slot=1207; shared=true; feature_id=368827 feature_name=fc_user_st_1d_doc_title_terms_query_recent; depend=user_st_1d_doc_title_terms_query_recent; method=VectorTopString; feature_version=2; args=30; slot=2007; shared=true; feature_id=368971 feature_name=fc_user_st_7d_doc_author_id_cart_cp; depend=user_st_7d_doc_author_id_cart_cp; method=VectorTopString; feature_version=2; args=30; slot=1360; shared=true; feature_id=368932 feature_name=fc_user_st_7d_doc_author_id_cart_recent; depend=user_st_7d_doc_author_id_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2099; shared=true; feature_id=369039 feature_name=fc_user_st_7d_doc_author_id_click_cp; depend=user_st_7d_doc_author_id_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1276; shared=true; feature_id=368872 feature_name=fc_user_st_7d_doc_author_id_click_recent; depend=user_st_7d_doc_author_id_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2076; shared=true; feature_id=369016 feature_name=fc_user_st_7d_doc_author_id_conversion_cp; depend=user_st_7d_doc_author_id_conversion_cp; method=VectorTopString; feature_version=2; args=30; slot=1384; shared=true; feature_id=368956 feature_name=fc_user_st_7d_doc_author_id_conversion_recent; depend=user_st_7d_doc_author_id_conversion_recent; method=VectorTopString; feature_version=2; args=50; slot=2115; shared=true; feature_id=369055 feature_name=fc_user_st_7d_doc_author_id_favorite_cp; depend=user_st_7d_doc_author_id_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1277; shared=true; feature_id=368873 feature_name=fc_user_st_7d_doc_author_id_favorite_recent; depend=user_st_7d_doc_author_id_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2077; shared=true; feature_id=369017 feature_name=fc_user_st_7d_doc_author_id_praise_cp; depend=user_st_7d_doc_author_id_praise_cp; 
method=VectorTopString; feature_version=2; args=30; slot=1278; shared=true; feature_id=368874 feature_name=fc_user_st_7d_doc_author_id_praise_recent; depend=user_st_7d_doc_author_id_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2078; shared=true; feature_id=369018 feature_name=fc_user_st_7d_doc_author_id_query_cp; depend=user_st_7d_doc_author_id_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1279; shared=true; feature_id=368875 feature_name=fc_user_st_7d_doc_author_id_query_recent; depend=user_st_7d_doc_author_id_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2079; shared=true; feature_id=369019 feature_name=fc_user_st_7d_doc_cate1_cart_cp; depend=user_st_7d_doc_cate1_cart_cp; method=VectorTopString; feature_version=2; args=30; slot=1351; shared=true; feature_id=368923 feature_name=fc_user_st_7d_doc_cate1_cart_recent; depend=user_st_7d_doc_cate1_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2093; shared=true; feature_id=369033 feature_name=fc_user_st_7d_doc_cate1_click_cp; depend=user_st_7d_doc_cate1_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1264; shared=true; feature_id=368860 feature_name=fc_user_st_7d_doc_cate1_click_recent; depend=user_st_7d_doc_cate1_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2064; shared=true; feature_id=369004 feature_name=fc_user_st_7d_doc_cate1_conversion_cp; depend=user_st_7d_doc_cate1_conversion_cp; method=VectorTopString; feature_version=2; args=30; slot=1375; shared=true; feature_id=368947 feature_name=fc_user_st_7d_doc_cate1_conversion_recent; depend=user_st_7d_doc_cate1_conversion_recent; method=VectorTopString; feature_version=2; args=50; slot=2109; shared=true; feature_id=369049 feature_name=fc_user_st_7d_doc_cate1_favorite_cp; depend=user_st_7d_doc_cate1_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1265; shared=true; feature_id=368861 
feature_name=fc_user_st_7d_doc_cate1_favorite_recent; depend=user_st_7d_doc_cate1_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2065; shared=true; feature_id=369005 feature_name=fc_user_st_7d_doc_cate1_praise_cp; depend=user_st_7d_doc_cate1_praise_cp; method=VectorTopString; feature_version=2; args=30; slot=1266; shared=true; feature_id=368862 feature_name=fc_user_st_7d_doc_cate1_praise_recent; depend=user_st_7d_doc_cate1_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2066; shared=true; feature_id=369006 feature_name=fc_user_st_7d_doc_cate1_query_cp; depend=user_st_7d_doc_cate1_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1267; shared=true; feature_id=368863 feature_name=fc_user_st_7d_doc_cate1_query_recent; depend=user_st_7d_doc_cate1_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2067; shared=true; feature_id=369007 feature_name=fc_user_st_7d_doc_cate2_cart_cp; depend=user_st_7d_doc_cate2_cart_cp; method=VectorTopString; feature_version=2; args=30; slot=1354; shared=true; feature_id=368926 feature_name=fc_user_st_7d_doc_cate2_cart_recent; depend=user_st_7d_doc_cate2_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2095; shared=true; feature_id=369035 feature_name=fc_user_st_7d_doc_cate2_click_cp; depend=user_st_7d_doc_cate2_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1268; shared=true; feature_id=368864 feature_name=fc_user_st_7d_doc_cate2_click_recent; depend=user_st_7d_doc_cate2_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2068; shared=true; feature_id=369008 feature_name=fc_user_st_7d_doc_cate2_conversion_cp; depend=user_st_7d_doc_cate2_conversion_cp; method=VectorTopString; feature_version=2; args=30; slot=1378; shared=true; feature_id=368950 feature_name=fc_user_st_7d_doc_cate2_conversion_recent; depend=user_st_7d_doc_cate2_conversion_recent; method=VectorTopString; feature_version=2; 
args=50; slot=2111; shared=true; feature_id=369051 feature_name=fc_user_st_7d_doc_cate2_favorite_cp; depend=user_st_7d_doc_cate2_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1269; shared=true; feature_id=368865 feature_name=fc_user_st_7d_doc_cate2_favorite_recent; depend=user_st_7d_doc_cate2_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2069; shared=true; feature_id=369009 feature_name=fc_user_st_7d_doc_cate2_praise_cp; depend=user_st_7d_doc_cate2_praise_cp; method=VectorTopString; feature_version=2; args=30; slot=1270; shared=true; feature_id=368866 feature_name=fc_user_st_7d_doc_cate2_praise_recent; depend=user_st_7d_doc_cate2_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2070; shared=true; feature_id=369010 feature_name=fc_user_st_7d_doc_cate2_query_cp; depend=user_st_7d_doc_cate2_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1271; shared=true; feature_id=368867 feature_name=fc_user_st_7d_doc_cate2_query_recent; depend=user_st_7d_doc_cate2_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2071; shared=true; feature_id=369011 feature_name=fc_user_st_7d_doc_cate3_cart_cp; depend=user_st_7d_doc_cate3_cart_cp; method=VectorTopString; feature_version=2; args=30; slot=1357; shared=true; feature_id=368929 feature_name=fc_user_st_7d_doc_cate3_cart_recent; depend=user_st_7d_doc_cate3_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2097; shared=true; feature_id=369037 feature_name=fc_user_st_7d_doc_cate3_click_cp; depend=user_st_7d_doc_cate3_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1272; shared=true; feature_id=368868 feature_name=fc_user_st_7d_doc_cate3_click_recent; depend=user_st_7d_doc_cate3_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2072; shared=true; feature_id=369012 feature_name=fc_user_st_7d_doc_cate3_conversion_cp; depend=user_st_7d_doc_cate3_conversion_cp; 
method=VectorTopString; feature_version=2; args=30; slot=1381; shared=true; feature_id=368953 feature_name=fc_user_st_7d_doc_cate3_conversion_recent; depend=user_st_7d_doc_cate3_conversion_recent; method=VectorTopString; feature_version=2; args=50; slot=2113; shared=true; feature_id=369053 feature_name=fc_user_st_7d_doc_cate3_favorite_cp; depend=user_st_7d_doc_cate3_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1273; shared=true; feature_id=368869 feature_name=fc_user_st_7d_doc_cate3_favorite_recent; depend=user_st_7d_doc_cate3_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2073; shared=true; feature_id=369013 feature_name=fc_user_st_7d_doc_cate3_praise_cp; depend=user_st_7d_doc_cate3_praise_cp; method=VectorTopString; feature_version=2; args=30; slot=1274; shared=true; feature_id=368870 feature_name=fc_user_st_7d_doc_cate3_praise_recent; depend=user_st_7d_doc_cate3_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2074; shared=true; feature_id=369014 feature_name=fc_user_st_7d_doc_cate3_query_cp; depend=user_st_7d_doc_cate3_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1275; shared=true; feature_id=368871 feature_name=fc_user_st_7d_doc_cate3_query_recent; depend=user_st_7d_doc_cate3_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2075; shared=true; feature_id=369015 feature_name=fc_user_st_7d_doc_id_cart_cp; depend=user_st_7d_doc_id_cart_cp; method=VectorTopString; feature_version=2; args=30; slot=1345; shared=true; feature_id=368917 feature_name=fc_user_st_7d_doc_id_cart_recent; depend=user_st_7d_doc_id_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2089; shared=true; feature_id=369029 feature_name=fc_user_st_7d_doc_id_click_cp; depend=user_st_7d_doc_id_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1256; shared=true; feature_id=368852 feature_name=fc_user_st_7d_doc_id_click_recent; 
depend=user_st_7d_doc_id_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2056; shared=true; feature_id=368996 feature_name=fc_user_st_7d_doc_id_conversion_cp; depend=user_st_7d_doc_id_conversion_cp; method=VectorTopString; feature_version=2; args=30; slot=1369; shared=true; feature_id=368941 feature_name=fc_user_st_7d_doc_id_conversion_recent; depend=user_st_7d_doc_id_conversion_recent; method=VectorTopString; feature_version=2; args=50; slot=2105; shared=true; feature_id=369045 feature_name=fc_user_st_7d_doc_id_favorite_cp; depend=user_st_7d_doc_id_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1257; shared=true; feature_id=368853 feature_name=fc_user_st_7d_doc_id_favorite_recent; depend=user_st_7d_doc_id_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2057; shared=true; feature_id=368997 feature_name=fc_user_st_7d_doc_id_praise_cp; depend=user_st_7d_doc_id_praise_cp; method=VectorTopString; feature_version=2; args=30; slot=1258; shared=true; feature_id=368854 feature_name=fc_user_st_7d_doc_id_praise_recent; depend=user_st_7d_doc_id_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2058; shared=true; feature_id=368998 feature_name=fc_user_st_7d_doc_id_query_cp; depend=user_st_7d_doc_id_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1259; shared=true; feature_id=368855 feature_name=fc_user_st_7d_doc_id_query_recent; depend=user_st_7d_doc_id_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2059; shared=true; feature_id=368999 feature_name=fc_user_st_7d_doc_keyword_cart_cp; depend=user_st_7d_doc_keyword_cart_cp; method=VectorTopString; feature_version=2; args=30; slot=1366; shared=true; feature_id=368938 feature_name=fc_user_st_7d_doc_keyword_cart_recent; depend=user_st_7d_doc_keyword_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2103; shared=true; feature_id=369043 
feature_name=fc_user_st_7d_doc_keyword_click_cp; depend=user_st_7d_doc_keyword_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1284; shared=true; feature_id=368880 feature_name=fc_user_st_7d_doc_keyword_click_recent; depend=user_st_7d_doc_keyword_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2084; shared=true; feature_id=369024 feature_name=fc_user_st_7d_doc_keyword_conversion_cp; depend=user_st_7d_doc_keyword_conversion_cp; method=VectorTopString; feature_version=2; args=30; slot=1390; shared=true; feature_id=368962 feature_name=fc_user_st_7d_doc_keyword_conversion_recent; depend=user_st_7d_doc_keyword_conversion_recent; method=VectorTopString; feature_version=2; args=50; slot=2119; shared=true; feature_id=369059 feature_name=fc_user_st_7d_doc_keyword_favorite_cp; depend=user_st_7d_doc_keyword_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1285; shared=true; feature_id=368881 feature_name=fc_user_st_7d_doc_keyword_favorite_recent; depend=user_st_7d_doc_keyword_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2085; shared=true; feature_id=369025 feature_name=fc_user_st_7d_doc_keyword_praise_cp; depend=user_st_7d_doc_keyword_praise_cp; method=VectorTopString; feature_version=2; args=30; slot=1286; shared=true; feature_id=368882 feature_name=fc_user_st_7d_doc_keyword_praise_recent; depend=user_st_7d_doc_keyword_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2086; shared=true; feature_id=369026 feature_name=fc_user_st_7d_doc_keyword_query_cp; depend=user_st_7d_doc_keyword_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1287; shared=true; feature_id=368883 feature_name=fc_user_st_7d_doc_keyword_query_recent; depend=user_st_7d_doc_keyword_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2087; shared=true; feature_id=369027 feature_name=fc_user_st_7d_doc_tags_cart_cp; depend=user_st_7d_doc_tags_cart_cp; 
method=VectorTopString; feature_version=2; args=30; slot=1363; shared=true; feature_id=368935 feature_name=fc_user_st_7d_doc_tags_cart_recent; depend=user_st_7d_doc_tags_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2101; shared=true; feature_id=369041 feature_name=fc_user_st_7d_doc_tags_click_cp; depend=user_st_7d_doc_tags_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1280; shared=true; feature_id=368876 feature_name=fc_user_st_7d_doc_tags_click_recent; depend=user_st_7d_doc_tags_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2080; shared=true; feature_id=369020 feature_name=fc_user_st_7d_doc_tags_conversion_cp; depend=user_st_7d_doc_tags_conversion_cp; method=VectorTopString; feature_version=2; args=30; slot=1387; shared=true; feature_id=368959 feature_name=fc_user_st_7d_doc_tags_conversion_recent; depend=user_st_7d_doc_tags_conversion_recent; method=VectorTopString; feature_version=2; args=50; slot=2117; shared=true; feature_id=369057 feature_name=fc_user_st_7d_doc_tags_favorite_cp; depend=user_st_7d_doc_tags_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1281; shared=true; feature_id=368877 feature_name=fc_user_st_7d_doc_tags_favorite_recent; depend=user_st_7d_doc_tags_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2081; shared=true; feature_id=369021 feature_name=fc_user_st_7d_doc_tags_praise_cp; depend=user_st_7d_doc_tags_praise_cp; method=VectorTopString; feature_version=2; args=30; slot=1282; shared=true; feature_id=368878 feature_name=fc_user_st_7d_doc_tags_praise_recent; depend=user_st_7d_doc_tags_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2082; shared=true; feature_id=369022 feature_name=fc_user_st_7d_doc_tags_query_cp; depend=user_st_7d_doc_tags_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1283; shared=true; feature_id=368879 feature_name=fc_user_st_7d_doc_tags_query_recent; 
depend=user_st_7d_doc_tags_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2083; shared=true; feature_id=369023 feature_name=fc_user_st_7d_doc_title_terms_cart_cp; depend=user_st_7d_doc_title_terms_cart_cp; method=VectorTopString; feature_version=2; args=30; slot=1348; shared=true; feature_id=368920 feature_name=fc_user_st_7d_doc_title_terms_cart_recent; depend=user_st_7d_doc_title_terms_cart_recent; method=VectorTopString; feature_version=2; args=50; slot=2091; shared=true; feature_id=369031 feature_name=fc_user_st_7d_doc_title_terms_click_cp; depend=user_st_7d_doc_title_terms_click_cp; method=VectorTopString; feature_version=2; args=30; slot=1260; shared=true; feature_id=368856 feature_name=fc_user_st_7d_doc_title_terms_click_recent; depend=user_st_7d_doc_title_terms_click_recent; method=VectorTopString; feature_version=2; args=50; slot=2060; shared=true; feature_id=369000 feature_name=fc_user_st_7d_doc_title_terms_conversion_cp; depend=user_st_7d_doc_title_terms_conversion_cp; method=VectorTopString; feature_version=2; args=30; slot=1372; shared=true; feature_id=368944 feature_name=fc_user_st_7d_doc_title_terms_conversion_recent; depend=user_st_7d_doc_title_terms_conversion_recent; method=VectorTopString; feature_version=2; args=50; slot=2107; shared=true; feature_id=369047 feature_name=fc_user_st_7d_doc_title_terms_favorite_cp; depend=user_st_7d_doc_title_terms_favorite_cp; method=VectorTopString; feature_version=2; args=30; slot=1261; shared=true; feature_id=368857 feature_name=fc_user_st_7d_doc_title_terms_favorite_recent; depend=user_st_7d_doc_title_terms_favorite_recent; method=VectorTopString; feature_version=2; args=50; slot=2061; shared=true; feature_id=369001 feature_name=fc_user_st_7d_doc_title_terms_praise_cp; depend=user_st_7d_doc_title_terms_praise_cp; method=VectorTopString; feature_version=2; args=30; slot=1262; shared=true; feature_id=368858 feature_name=fc_user_st_7d_doc_title_terms_praise_recent; 
depend=user_st_7d_doc_title_terms_praise_recent; method=VectorTopString; feature_version=2; args=50; slot=2062; shared=true; feature_id=369002 feature_name=fc_user_st_7d_doc_title_terms_query_cp; depend=user_st_7d_doc_title_terms_query_cp; method=VectorTopString; feature_version=2; args=30; slot=1263; shared=true; feature_id=368859 feature_name=fc_user_st_7d_doc_title_terms_query_recent; depend=user_st_7d_doc_title_terms_query_recent; method=VectorTopString; feature_version=2; args=50; slot=2063; shared=true; feature_id=369003 ================================================ FILE: monolith/native_training/data/tf_example_to_example_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from absl import app, logging import tensorflow as tf import numpy as np from monolith.native_training.data.feature_utils import tf_example_to_example from monolith.native_training.data.parsers import parse_examples # The following functions can be used to convert a value to a type compatible # with tf.train.Example. def _bytes_feature(value): """Returns a bytes_list from a string / byte.""" if isinstance(value, type(tf.constant(0))): value = value.numpy( ) # BytesList won't unpack a string from an EagerTensor. 
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _float_feature(value): """Returns a float_list from a float / double.""" return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) def _int64_feature(value): """Returns an int64_list from a bool / enum / int / uint.""" return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # @tf.py_function(Tout=tf.string) def serialize_example(feature0, feature1, feature2, feature3, feature4): """ Creates a tf.train.Example message ready to be written to a file. """ # Create a dictionary mapping the feature name to the tf.train.Example-compatible # data type. feature = { 'feature0': _int64_feature(feature0), 'feature1': _int64_feature(feature1), 'feature2': _bytes_feature(feature2), 'feature3': _float_feature(feature3), 'feature4': _float_feature(feature4), } # Create a Features message using tf.train.Example. example_proto = tf.train.Example(features=tf.train.Features(feature=feature)) return example_proto.SerializeToString() def get_fid_v2(slot: int, signature: int): fid_v2_mask = (1 << 48) - 1 return (slot << 48) | (signature & fid_v2_mask) def calc_hash_value(val: float): return int(math.log2(abs(val) + 1)) class TFExampleToExampleTest(tf.test.TestCase): def test_tf_example_to_example(self): tf.compat.v1.disable_v2_behavior() logging.set_verbosity(logging.INFO) # The number of observations in the dataset. n_observations = int(1e4) # Boolean feature, encoded as False or True. feature0 = np.random.choice([False, True], n_observations) # Integer feature, random from 0 to 4. feature1 = np.random.randint(0, 5, n_observations) # String feature. strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat']) feature2 = strings[feature1] # Float feature, from a standard normal distribution. feature3 = np.random.randn(n_observations) feature4 = np.random.randn(n_observations) filename = '/tmp/test.tfrecord' # Write the `tf.train.Example` observations to the file. 
with tf.io.TFRecordWriter(filename) as writer: for i in range(n_observations): example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i], feature4[i]) writer.write(example) filenames = [filename] dataset = tf.data.TFRecordDataset(filenames) # logging.info(dataset) # for raw_record in dataset.take(1): # example = tf.train.Example() # example.ParseFromString(raw_record.numpy()) # print(example) def map_fn(tensor: tf.Tensor): return tf_example_to_example(tensor, sparse_features={ "feature0": 1, "feature1": 2, "feature4": 3 }, dense_features=["feature2"], label="feature3", instance_weight=None) def parse_fn(variant: tf.Tensor): return parse_examples( variant, sparse_features=["not_existed1", "feature0", "feature1", "feature4"], dense_features=[ "label", "feature2", "feature3", "not_existed2", "instance_weight" ], dense_feature_types=[ tf.float32, tf.string, tf.float32, tf.float32, tf.float32 ], dense_feature_shapes=[1, 1, 1, 1, 1], ) batch_size = 2 dataset = dataset.map(map_fn) dataset = dataset.batch(batch_size) dataset = dataset.map(parse_fn) session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True) session_config.graph_options.rewrite_options.disable_meta_optimizer = True with tf.compat.v1.Session(config=session_config) as sess: it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() element["label"] = tf.reshape(element['label'], shape=(-1,)) element["feature2"] = tf.reshape(element['feature2'], shape=(-1,)) element["feature3"] = tf.reshape(element['feature3'], shape=(-1,)) element["not_existed2"] = tf.reshape(element['not_existed2'], shape=(-1,)) element["instance_weight"] = tf.reshape(element['instance_weight'], shape=(-1,)) for i in range(n_observations // batch_size): features = sess.run(fetches=element) self.assertAllEqual(features['not_existed1'].values.shape, (0,)) feature0_fids = [ get_fid_v2(1, val) for val in feature0[i * batch_size:(i + 1) * batch_size] ] feature1_fids = [ get_fid_v2(2, val) for val 
in feature1[i * batch_size:(i + 1) * batch_size] ] feature4_fids = [ get_fid_v2(3, calc_hash_value(val)) for val in feature4[i * batch_size:(i + 1) * batch_size] ] self.assertAllEqual(features['feature0'].values, feature0_fids) self.assertAllEqual(features['feature1'].values, feature1_fids) self.assertAllEqual(features['feature4'].values, feature4_fids) self.assertAllClose(features['label'], feature3[i * batch_size:(i + 1) * batch_size]) self.assertAllClose(features['feature3'], [0, 0]) self.assertAllClose(features['not_existed2'], [0, 0]) self.assertAllClose(features['instance_weight'], [1.0, 1.0]) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/data/training_instance/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_custom_op_library") package(default_visibility = ["//visibility:public"]) cc_library( name = "fid", hdrs = ["cc/fid.h"], deps = [], ) cc_test( name = "fid_test", srcs = ["cc/fid_test.cc"], deps = [ ":fid", ":reader_util", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "reader_util", srcs = ["cc/reader_util.cc"], hdrs = ["cc/reader_util.h"], deps = [ "//third_party/nlohmann:json", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs", "@org_tensorflow//tensorflow/core/platform:logging", ], ) cc_test( name = "reader_util_test", srcs = ["cc/reader_util_test.cc"], deps = [ ":reader_util", "@com_google_googletest//:gtest_main", ], ) 
# Buffer pool shared by the compressed input buffers below.
cc_library(
    name = "cached_mem_pool",
    srcs = ["cc/cached_mem_pool.cc"],
    hdrs = ["cc/cached_mem_pool.h"],
    deps = [
        "@com_google_glog//:glog",
    ],
)

cc_test(
    name = "cached_mem_pool_test",
    srcs = ["cc/cached_mem_pool_test.cc"],
    deps = [
        ":cached_mem_pool",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "snappy_inputbuffer",
    srcs = [
        "cc/snappy_inputbuffer.cc",
    ],
    hdrs = [
        "cc/snappy_inputbuffer.h",
    ],
    deps = [
        ":cached_mem_pool",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
    ],
)

cc_library(
    name = "zstd_inputbuffer",
    srcs = [
        "cc/zstd_inputbuffer.cc",
    ],
    hdrs = [
        "cc/zstd_inputbuffer.h",
    ],
    deps = [
        ":cached_mem_pool",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@zstd",
    ],
)

cc_library(
    name = "ue_compress",
    srcs = ["cc/ue_compress.cc"],
    hdrs = ["cc/ue_compress.h"],
    deps = [
        "//idl:compression_qtz8mm",
        "//idl:proto_parser_cc_proto",
        "@com_google_glog//:glog",
    ],
)

cc_test(
    name = "ue_compress_test",
    srcs = ["cc/ue_compress_test.cc"],
    deps = [
        ":ue_compress",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "data_format_options",
    hdrs = ["cc/data_format_options.h"],
)

# data_reader.h is included directly by dependent targets (for example
# data_read_write_test), so it is published through hdrs rather than srcs.
# NOTE(review): pb_variant.h is left private; move it to hdrs as well if any
# dependent includes it directly.
cc_library(
    name = "data_reader",
    srcs = [
        "cc/data_reader.cc",
        "cc/pb_variant.cc",
        "cc/pb_variant.h",
    ],
    hdrs = [
        "cc/data_reader.h",
    ],
    deps = [
        ":data_format_options",
        ":reader_util",
        ":snappy_inputbuffer",
        ":zstd_inputbuffer",
        ":ue_compress",
        "//idl:example_cc_proto",
        "//idl:proto_parser_cc_proto",
        "//monolith/native_training/runtime/ops:traceme",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
    ],
)

cc_library(
    name = "data_writer",
    srcs = [
        "cc/data_writer.cc",
    ],
    hdrs = [
        "cc/data_writer.h",
    ],
    deps = [
        ":data_format_options",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
    ],
)

tf_cc_test(
    name = "data_read_write_test",
    srcs = [
        "cc/data_read_write_test.cc",
    ],
    deps = [
        ":data_reader",
        ":data_writer",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "parse_instance_lib",
    srcs = ["cc/parse_instance_lib.cc"],
    hdrs = ["cc/parse_instance_lib.h"],
    deps = [
        ":data_reader",
        ":reader_util",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/types:span",
        "@com_google_glog//:glog",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs",
    ],
)

cc_library(
    name = "instance_utils",
    srcs = ["cc/instance_utils.cc"],
    hdrs = ["cc/instance_utils.h"],
    deps = [
        ":reader_util",
        "//idl:example_cc_proto",
        "//idl:proto_parser_cc_proto",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
    ],
)

cc_test(
    name = "instance_utils_test",
    srcs = ["cc/instance_utils_test.cc"],
    deps = [
        ":instance_utils",
        "@com_google_absl//absl/time",
        "@com_google_glog//:glog",
        "@com_google_googletest//:gtest_main",
    ],
)

tf_cc_binary(
    name = "instance_processor",
    srcs = [
        "cc/instance_processor.cc",
    ],
    copts = ["-fexceptions"],
    deps = [
        ":data_reader",
        ":instance_utils",
        "//third_party/nlohmann:json",
        "@org_tensorflow//tensorflow/core:tensorflow",
    ],
)

tf_cc_binary(
    name = "instance_reader",
    srcs = [
        "cc/instance_reader.cc",
    ],
    copts = ["-fexceptions"],
    deps = [
        ":data_reader",
        ":fid",
        ":instance_utils",
        "//monolith/native_training/data/transform:transforms",
        "//third_party/cli11:cli11",
        "//third_party/nlohmann:json",
        "@org_tensorflow//tensorflow/core:tensorflow",
    ],
)

cc_library(
    name = "pb_datasource_lib",
    srcs = [
        "cc/instance_dataset_kernel.cc",
        "cc/parse_instance_kernel.cc",
    ],
    deps = [
        ":data_reader",
        ":instance_utils",
        ":parse_instance_lib",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)

cc_library(
    name = "pb_datasource_ops",
    srcs = [
        "cc/instance_dataset_ops.cc",
        "cc/parse_instance_ops.cc",
    ],
    copts = ["-DNDEBUG"],
    deps = [
        ":pb_datasource_lib",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
    ],
    alwayslink = 1,
)

# Python library built from python/instance_dataset_op.py.
py_library(
    name = "instance_dataset_ops_py",
    srcs = [
        "python/instance_dataset_op.py",
    ],
    deps = [
        "//monolith:utils",
        "//monolith/native_training:runner_utils",
        "//monolith/native_training/distribute:distributed_dataset",
        "//monolith/native_training/data:datasets_py",
        "//monolith/native_training/hooks:ckpt_hooks",
        "//monolith/native_training/runtime/ops:gen_monolith_ops",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

# Test target whose entry point is python/parse_instance_ops_test.py.
py_test(
    name = "parse_instance_ops_py_test",
    srcs = [
        "python/parse_instance_ops_test.py",
    ],
    main = "python/parse_instance_ops_test.py",
    deps = [
        ":instance_dataset_ops_py",
        ":parse_instance_ops_py",
        "//idl:proto_parser_py_proto",
    ],
)

# Declared as a py_binary (not a py_test) although its source is named
# *_test_stdin.py, so it is not run as part of `bazel test`.
py_binary(
    name = "instance_dataset_op_py_test_stdin",
    srcs = [
        "python/instance_dataset_op_test_stdin.py",
    ],
    main = "python/instance_dataset_op_test_stdin.py",
    deps = [
        ":instance_dataset_ops_py",
        ":parse_instance_ops_py",
        "//idl:proto_parser_py_proto",
    ],
)

py_library(
    name = "parser_utils",
    srcs = [
        "python/parser_utils.py",
    ],
    deps = [
        "//monolith/native_training:ragged_utils",
    ],
)

# Python library built from python/parse_instance_ops.py; depends on
# :parser_utils above.
py_library(
    name = "parse_instance_ops_py",
    srcs = [
        "python/parse_instance_ops.py",
    ],
    deps = [
        ":parser_utils",
        "//idl:proto_parser_py_proto",
        "//monolith:utils",
        "//monolith/native_training/runtime/ops:gen_monolith_ops",
    ],
)

# Makes these two source files referenceable as labels from other packages.
exports_files([
    "cc/parse_instance_kernel.cc",
    "cc/parse_instance_ops.cc",
])

================================================
FILE: monolith/native_training/data/training_instance/cc/cached_mem_pool.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "cached_mem_pool.h" #include "glog/logging.h" namespace tensorflow { namespace monolith_tf { std::mutex session_mutex; CachedMemPool* CachedMemPool::cached_mem_pool = nullptr; CachedMemPool* CachedMemPool::init(size_t buffer_size) { std::unique_lock lock(session_mutex); if (cached_mem_pool == nullptr) { cached_mem_pool = new CachedMemPool(buffer_size); } return cached_mem_pool; } std::unique_ptr CachedMemPool::allocate() { std::unique_lock lock(alloc_mtx_); if (cached_buffers_.empty()) { total_requested_++; return std::make_unique(buffer_size_); } else { auto buffer = std::move(cached_buffers_.back()); cached_buffers_.pop_back(); return buffer; } } void CachedMemPool::deallocate(std::unique_ptr& buffer) { std::unique_lock lock(alloc_mtx_); cached_buffers_.emplace_back(std::move(buffer)); } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/cached_mem_pool.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_CACHED_MEM_POOL_H_ #define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_CACHED_MEM_POOL_H_ #include #include #include namespace tensorflow { namespace monolith_tf { class CachedMemPool { public: static CachedMemPool* init(size_t buffer_size); std::unique_ptr allocate(); void deallocate(std::unique_ptr& buffer); // Test Only method. size_t get_buffer_size() { std::unique_lock lock(alloc_mtx_); return cached_buffers_.size(); } private: explicit CachedMemPool(size_t buffer_size) : buffer_size_(buffer_size) {} ~CachedMemPool() { cached_buffers_.clear(); } size_t buffer_size_; std::mutex alloc_mtx_; size_t total_requested_ = 0; std::vector> cached_buffers_; static CachedMemPool* cached_mem_pool; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_CACHED_MEM_POOL_H_ ================================================ FILE: monolith/native_training/data/training_instance/cc/cached_mem_pool_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "cached_mem_pool.h" #include #include #include "gtest/gtest.h" namespace tensorflow { namespace monolith_tf { TEST(CachedMemPoolTest, Basic) { CachedMemPool* mem_pool = CachedMemPool::init(1024 * 1024); auto buffer = mem_pool->allocate(); EXPECT_EQ(mem_pool->get_buffer_size(), 0); mem_pool->deallocate(buffer); EXPECT_EQ(mem_pool->get_buffer_size(), 1); } TEST(CachedMemPoolTest, RecursiveAllocation) { CachedMemPool* mem_pool = CachedMemPool::init(1024 * 1024); std::vector> buffers; for (int i = 0; i < 30; i++) { auto buffer = mem_pool->allocate(); buffers.emplace_back(std::move(buffer)); } EXPECT_EQ(mem_pool->get_buffer_size(), 0); for (auto& buffer : buffers) { mem_pool->deallocate(buffer); } EXPECT_EQ(mem_pool->get_buffer_size(), 30); for (int i = 0; i < 30; i++) { auto buffer = mem_pool->allocate(); buffers.emplace_back(std::move(buffer)); } EXPECT_EQ(mem_pool->get_buffer_size(), 0); } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/data_format_options.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_FORMAT_OPTIONS_H_
#define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_FORMAT_OPTIONS_H_

#include

namespace tensorflow {
namespace monolith_tf {

// Four independent boolean switches, all defaulting to false.
// NOTE(review): what each flag changes is not visible in this header;
// presumably they describe how records are framed on disk — confirm against
// the reader/writer implementations.
struct DataFormatOptions {
  bool lagrangex_header = false;
  bool kafka_dump_prefix = false;
  bool has_sort_id = false;
  bool kafka_dump = false;
};

// Streams all four flags as "name: value" pairs separated by single spaces,
// on one line with no trailing newline. Returns `os` to allow chaining.
inline std::ostream& operator<<(std::ostream& os,
                                const DataFormatOptions& opts) {
  return os << "lagrangex_header: " << opts.lagrangex_header
            << " kafka_dump_prefix: " << opts.kafka_dump_prefix
            << " has_sort_id: " << opts.has_sort_id
            << " kafka_dump: " << opts.kafka_dump;
}

}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_FORMAT_OPTIONS_H_

================================================
FILE: monolith/native_training/data/training_instance/cc/data_read_write_test.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/data/training_instance/cc/data_reader.h"
#include "monolith/native_training/data/training_instance/cc/data_writer.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

// Parameterized fixture; the parameter is read as a DataFormatOptions in the
// test body below.
class ReadWriteTest : public ::testing::TestWithParam {};

// Reads one record from `reader` into `out`, discarding the pb_type and
// data_source_key header fields that ReadPBBytes also reports.
Status ReadBytes(BaseStreamReaderTmpl* reader, absl::string_view* out) {
  uint8_t pb_type;
  uint32_t data_source_key;
  return reader->ReadPBBytes(&pb_type, &data_source_key, out);
}

// Round trip: writes 16 records of 0..15 'a' characters with the options
// under test, then reads them back with the same options and checks that
// each payload matches.
TEST_P(ReadWriteTest, Basic) {
  DataFormatOptions options = GetParam();
  std::string s;
  StringStreamWriter writer(options, &s);
  for (int i = 0; i < 16; ++i) {
    EXPECT_TRUE(writer.WriteRecord(std::string(i, 'a')).ok());
  }
  ZeroCopyStringViewStreamReader reader(options, s);
  absl::string_view out;
  for (int i = 0; i < 16; ++i) {
    auto status = ReadBytes(&reader, &out);
    EXPECT_TRUE(status.ok()) << status;
    EXPECT_THAT(out, std::string(i, 'a')) << i;
  }
}

// Enumerates all 16 combinations of the four boolean options; bit k of `i`
// drives the k-th flag.
std::vector GenerateOptions() {
  std::vector res;
  for (int i = 0; i < 16; ++i) {
    DataFormatOptions options;
    options.lagrangex_header = i & 1;
    options.kafka_dump_prefix = i / 2 & 1;
    options.has_sort_id = i / 4 & 1;
    options.kafka_dump = i / 8 & 1;
    res.push_back(options);
  }
  return res;
}

INSTANTIATE_TEST_SUITE_P(ReadWriteTestAll, ReadWriteTest,
                         testing::ValuesIn(GenerateOptions()));

}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/data_reader.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/training_instance/cc/data_reader.h" #include #include #include "monolith/native_training/data/training_instance/cc/snappy_inputbuffer.h" #include "monolith/native_training/data/training_instance/cc/zstd_inputbuffer.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { namespace monolith_tf { using EFeature = ::monolith::io::proto::Feature; using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using FeatureListType = ::monolith::io::proto::FeatureListType; using IFeature = ::idl::matrix::proto::Feature; using Instance = ::parser::proto::Instance; using LineId = ::idl::matrix::proto::LineId; data_format::DataFormat data_format::StringToDataFormat( const std::string &type) { static const std::string EXAMPLEBATCH_STR = "examplebatch"; static const std::string EXAMPLE_STR = "example"; static const std::string INSTANCE_STR = "instance"; static const std::string PLAINTEXT_STR = "plaintext"; if (type == PLAINTEXT_STR) { return data_format::PLAINTEXT; } else if (type == EXAMPLEBATCH_STR) { return data_format::EXAMPLEBATCH; } else if (type == EXAMPLE_STR) { return data_format::EXAMPLE; } else if (type == 
INSTANCE_STR) { return data_format::INSTANCE; } else { LOG(WARNING) << "unknow input:" << type; return data_format::UNKNOW; } return data_format::UNKNOW; } static const size_t kLengthSize = 8; const int kDEFAULT_SNAPPY_BUFFER_SIZE = 64 * 1024 * 1024; // 64MB Status AddFeature(const std::string &name, const EFeature &efeat, Instance *instance) { if (name == "__LINE_ID__") { const auto &line_id = efeat.bytes_list().value(0); bool ok = instance->mutable_line_id()->ParseFromArray(line_id.data(), line_id.size()); if (!ok) { return errors::FailedPrecondition("Failed to parse the LineId"); } } else if (name == "__LABEL__") { const auto &float_list = efeat.float_list(); for (const auto &value : float_list.value()) { instance->add_label(value); } } else if (name == "instance_weight") { float instance_weight = 1.0; if (efeat.float_list().value_size() > 0) { instance_weight = efeat.float_list().value(0); } instance->set_instance_weight(instance_weight); } else { switch (efeat.type_case()) { case EFeature::TypeCase::TYPE_NOT_SET: break; case EFeature::TypeCase::kFidV1List: for (const auto &value : efeat.fid_v1_list().value()) { instance->add_fid(value); } break; default: IFeature *ifeat = instance->add_feature(); ifeat->set_name(name); switch (efeat.type_case()) { case EFeature::TypeCase::kFidV2List: for (const auto &fid : efeat.fid_v2_list().value()) { ifeat->add_fid(fid); } break; case EFeature::TypeCase::kFloatList: for (const auto &fv : efeat.float_list().value()) { ifeat->add_float_value(fv); } break; case EFeature::TypeCase::kInt64List: for (const auto &iv : efeat.int64_list().value()) { ifeat->add_int64_value(iv); } break; case EFeature::TypeCase::kBytesList: for (const auto &bv : efeat.bytes_list().value()) { ifeat->add_bytes_value(bv); } break; case EFeature::TypeCase::kFidV2Lists: for (const auto &elist : efeat.fid_v2_lists().list()) { auto *ilist = ifeat->add_fid_list(); for (const auto &value : elist.value()) { ilist->add_value(value); } } break; case 
EFeature::TypeCase::kFloatLists: for (const auto &elist : efeat.float_lists().list()) { auto *ilist = ifeat->add_float_list(); for (const auto &value : elist.value()) { ilist->add_value(value); } } break; case EFeature::TypeCase::kInt64Lists: for (const auto &elist : efeat.int64_lists().list()) { auto *ilist = ifeat->add_int64_list(); for (const auto &value : elist.value()) { ilist->add_value(value); } } break; case EFeature::TypeCase::kBytesLists: for (const auto &elist : efeat.bytes_lists().list()) { auto *ilist = ifeat->add_bytes_list(); for (const auto &value : elist.value()) { ilist->add_value(value); } } break; default: break; } break; } } return Status::OK(); } void ExtendExample(Example *pb, FeatureNameMapper *mapper /* = nullptr*/) { bool has_line_id = false, has_label = false, has_instance_weight = false; for (uint i = 0; i < pb->named_feature_size(); ++i) { auto &named_feature = *(pb->mutable_named_feature(i)); if (mapper) { int id; int32_t sorted_id = -1; if (mapper->GetIdByName(named_feature.name(), &id, &sorted_id)) { named_feature.set_sorted_id(sorted_id); } } if (named_feature.name() == "__LINE_ID__") { has_line_id = true; const auto &line_id = named_feature.feature().bytes_list().value(0); pb->mutable_line_id()->ParseFromArray(line_id.data(), line_id.size()); } else if (named_feature.name() == "__LABEL__") { has_label = true; const auto &float_list = named_feature.feature().float_list(); for (const auto &value : float_list.value()) { pb->add_label(value); } } else if (named_feature.name() == "instance_weight") { has_instance_weight = true; float instance_weight = 1.0; if (named_feature.feature().float_list().value_size() > 0) { instance_weight = named_feature.feature().float_list().value(0); } pb->set_instance_weight(instance_weight); } if (has_line_id && has_label && has_instance_weight) { break; } } } Status ExampleToInstance(Example *example, Instance *instance) { for (const auto &named_feature : example->named_feature()) { std::string name = 
named_feature.name(); const EFeature &efeat = named_feature.feature(); TF_RETURN_IF_ERROR(AddFeature(name, efeat, instance)); } // (todo): named_raw_feature is not supported in instance return Status::OK(); } Status InstanceToExample(Instance *instance, Example *example) { int index = 0; if (instance->has_line_id()) { example->mutable_line_id()->CopyFrom(instance->line_id()); } if (instance->label_size() > 0) { for (const auto &value : instance->label()) { example->add_label(value); } } if (instance->has_instance_weight()) { example->set_instance_weight(instance->instance_weight()); } else { example->set_instance_weight(1.0); } if (instance->value_size() > 0) { auto *named_feature = example->add_named_feature(); named_feature->set_name("value"); named_feature->set_id(index++); auto *efeat = named_feature->mutable_feature(); auto *float_list = efeat->mutable_float_list(); for (const auto &value : instance->value()) { float_list->add_value(value); } } std::unordered_map slot_to_efeat_; const auto &fids = instance->fid(); for (const auto &fid : fids) { int slot_id = fid >> 54; auto it = slot_to_efeat_.find(slot_id); ::monolith::io::proto::FidList *fid_v1_list = nullptr; if (it != slot_to_efeat_.end()) { fid_v1_list = it->second; } else { auto *named_feature = example->add_named_feature(); named_feature->set_name(absl::StrCat("fc_slot_", slot_id)); named_feature->set_id(index++); fid_v1_list = named_feature->mutable_feature()->mutable_fid_v1_list(); slot_to_efeat_.emplace(slot_id, fid_v1_list); } fid_v1_list->add_value(fid); } for (const auto &ifeat : instance->feature()) { auto *named_feature = example->add_named_feature(); named_feature->set_name(ifeat.name()); named_feature->set_id(index++); auto *efeat = named_feature->mutable_feature(); if (ifeat.fid_size() > 0) { auto *list = efeat->mutable_fid_v2_list(); for (const auto &value : ifeat.fid()) { list->add_value(value); } } else if (ifeat.float_value_size() > 0) { auto *list = efeat->mutable_float_list(); for 
(const auto &value : ifeat.float_value()) { list->add_value(value); } } else if (ifeat.int64_value_size() > 0) { auto *list = efeat->mutable_int64_list(); for (const auto &value : ifeat.int64_value()) { list->add_value(value); } } else if (ifeat.bytes_value_size() > 0) { auto *bytes_list = efeat->mutable_bytes_list(); for (const auto &value : ifeat.bytes_value()) { bytes_list->add_value(value); } } else if (ifeat.fid_list_size() > 0) { auto *elists = efeat->mutable_fid_v2_lists(); for (const auto &ilist : ifeat.fid_list()) { auto *elist = elists->add_list(); for (const auto &value : ilist.value()) { elist->add_value(value); } } } else if (ifeat.float_list_size() > 0) { auto *elists = efeat->mutable_float_lists(); for (const auto &ilist : ifeat.float_list()) { auto *list = elists->add_list(); for (const auto &value : ilist.value()) { list->add_value(value); } } } else if (ifeat.int64_list_size() > 0) { auto *elists = efeat->mutable_int64_lists(); for (const auto &ilist : ifeat.int64_list()) { auto *list = elists->add_list(); for (const auto &value : ilist.value()) { list->add_value(value); } } } else if (ifeat.bytes_list_size() > 0) { auto *elists = efeat->mutable_bytes_lists(); for (const auto &ilist : ifeat.bytes_list()) { auto *list = elists->add_list(); for (const auto &value : ilist.value()) { list->add_value(value); } } } else { LOG(INFO) << absl::StrCat("empty ", ifeat.name()); } } return Status::OK(); } Status ExampleBatchToInstance(ExampleBatch *example_batch, int index, Instance *instance) { for (const auto &named_feature_list : example_batch->named_feature_list()) { // NamedFeatureList const std::string &name = named_feature_list.name(); const EFeature &efeat = named_feature_list.type() == FeatureListType::SHARED ? 
named_feature_list.feature(0) : named_feature_list.feature(index); TF_RETURN_IF_ERROR(AddFeature(name, efeat, instance)); } instance->set_data_source_key(example_batch->data_source_key()); // (todo): named_raw_feature is not supported in instance return Status::OK(); } Status ExampleBatchToExample(ExampleBatch *example_batch, int index, Example *example, FeaturePruningType feature_pruning_type, FeatureNameMapper *mapper) { profiler::TraceMe activity([]() { return "ExampleBatchToExample"; }); for (const auto &named_feature : example_batch->named_feature_list()) { if (named_feature.type() != FeatureListType::SHARED) { if (example_batch->batch_size() != named_feature.feature_size()) { std::string err_log = absl::StrFormat( "ExampleBatch batch_size should be equal to named_feature size, " "while got %d vs %d for feature_name %s", example_batch->batch_size(), named_feature.feature_size(), named_feature.name()); LOG(ERROR) << err_log; return errors::OutOfRange(err_log); } if (index >= named_feature.feature_size()) { std::string err_log = absl::StrFormat( "index should be less than named_feature size, " "while got %d vs %d for feautre_name %s", index, named_feature.feature_size(), named_feature.name()); LOG(ERROR) << err_log; return errors::OutOfRange(err_log); } } else { if (named_feature.feature_size() == 0) { std::string err_log = absl::StrFormat( "named_feature size should be positive while got %d for " "feature_name %s", named_feature.feature_size(), named_feature.name()); LOG(ERROR) << err_log; return errors::OutOfRange(err_log); } } const auto &efeat = named_feature.type() == FeatureListType::SHARED ? 
named_feature.feature(0) : named_feature.feature(index); if (named_feature.name() == "__LINE_ID__") { const auto &line_id = efeat.bytes_list().value(0); bool ok = example->mutable_line_id()->ParseFromArray(line_id.data(), line_id.size()); if (!ok) { return errors::FailedPrecondition("Failed to parse the LineId"); } } else if (named_feature.name() == "__LABEL__") { const auto &float_list = efeat.float_list(); for (const auto &value : float_list.value()) { example->add_label(value); } } else if (named_feature.name() == "instance_weight") { float instance_weight = 1.0; if (efeat.float_list().value_size() > 0) { instance_weight = efeat.float_list().value(0); } example->set_instance_weight(instance_weight); } else if (feature_pruning_type != PRUNING_FEATURE) { // FeatureNameMapper if (mapper == nullptr) { return errors::InvalidArgument( "FeatureNameMapper should be specified, while we got " "mapper==nullptr"); } if (mapper->IsAvailable()) { int32_t id = -1; int32_t sorted_id = -1; bool found = mapper->GetIdByName(named_feature.name(), &id, &sorted_id); if (found) { auto *out = example->add_named_feature(); out->set_name(named_feature.name()); out->set_id(named_feature.id()); out->set_sorted_id(sorted_id); out->mutable_feature()->MergeFrom(efeat); } } else { LOG_FIRST_N(INFO, 1) << "FeatureNameMapper is not available!"; auto *out = example->add_named_feature(); out->set_name(named_feature.name()); out->set_id(named_feature.id()); out->mutable_feature()->MergeFrom(efeat); } } } if (feature_pruning_type != PRUNING_RAW_FEATURE) { for (const auto &named_feature : example_batch->named_raw_feature_list()) { const auto &efeat = named_feature.type() == FeatureListType::SHARED ? 
named_feature.raw_feature(0) : named_feature.raw_feature(index); auto *out = example->add_named_raw_feature(); out->set_name(named_feature.name()); out->set_id(named_feature.id()); out->mutable_raw_feature()->MergeFrom(efeat); } } example->set_data_source_key(example_batch->data_source_key()); return Status::OK(); } InputStreamReader::InputStreamReader( const DataFormatOptions &options, std::unique_ptr input_stream) : BaseStreamReader(options), input_stream_(std::move(input_stream)), last_read_failed_(false) {} namespace { std::unique_ptr CreateInputFileStream( RandomAccessFile *f, const InputCompressType compression_type, int64 buffer_size) { buffer_size = buffer_size ? buffer_size : kDEFAULT_SNAPPY_BUFFER_SIZE; if (compression_type == InputCompressType::SNAPPY) { return std::make_unique(f, buffer_size, buffer_size); } else if (compression_type == InputCompressType::ZSTD) { auto s = std::make_unique(f); return std::make_unique( s.release(), buffer_size, buffer_size); } else if (compression_type == InputCompressType::ZLIB || compression_type == InputCompressType::GZIP) { const io::ZlibCompressionOptions zlib_options = compression_type == InputCompressType::ZLIB ? 
io::ZlibCompressionOptions::DEFAULT() : io::ZlibCompressionOptions::GZIP(); auto s = std::make_unique(f); return std::make_unique( s.release(), static_cast(buffer_size), static_cast(buffer_size), zlib_options); } else { auto s = std::make_unique(f); return std::make_unique(s.release(), buffer_size, true); } } } // namespace FileStreamReader::FileStreamReader(const DataFormatOptions &options, std::unique_ptr f, const InputCompressType compression_type, int64 buffer_size) : InputStreamReader(options, CreateInputFileStream( f.get(), compression_type, buffer_size)), f_(std::move(f)) {} Status InputStreamReader::ReadNBytes(size_t n, tstring *result) { if (n >= SIZE_MAX - sizeof(uint32)) { return errors::DataLoss("record size too large"); } TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(n, result)); if (result->size() != n) { last_read_failed_ = true; if (result->empty()) { return errors::OutOfRange("eof"); } else { return errors::DataLoss("truncated record"); } } return Status::OK(); } uint64 InputStreamReader::GetOffset() { return input_stream_->Tell(); } Status InputStreamReader::SetOffset(uint64 *offset) { int64 curr_pos = input_stream_->Tell(); int64 desired_pos = static_cast(*offset); if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ || (curr_pos == desired_pos && last_read_failed_)) { last_read_failed_ = false; TF_RETURN_IF_ERROR(input_stream_->Reset()); TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos)); } else if (curr_pos < desired_pos) { TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos)); } DCHECK_EQ(desired_pos, input_stream_->Tell()); return Status::OK(); } // TODO(leqi.zou): Make input stream async and static. // Currently the problem is if we are unable to read N bytes, // the code is not cancellable (CTRL + C not working). 
StdinStreamReader::StdinStreamReader(const DataFormatOptions &options, int64 buffer_size) : BaseStreamReader(options), buffer_size_(buffer_size) { LOG_FIRST_N(INFO, 1) << "Init stdin read buffer_size: " << buffer_size_; input_stream_.reset(&std::cin, [](...) {}); buffer_.reset(new char[buffer_size]); } Status StdinStreamReader::ReadNBytes(size_t n, tstring *result) { if (n > buffer_size_) { // return errors::DataLoss( // "Buffer size may be too small! Should be bigger than ", n); // TODO(ltli): 临时支持推荐平台项目,sandbox 测试数据格式错误过不了检查 LOG(WARNING) << "Data header abnormal! read too big size: " << n << ". pls check PBDataset params: lagrangex_header, sort_id"; LOG(INFO) << "stdin meets EOF"; return errors::OutOfRange("eof"); } if (!input_stream_->read(buffer_.get(), n)) { if (input_stream_->eof()) { LOG(INFO) << "stdin meets EOF"; return errors::OutOfRange("eof"); } else { return errors::DataLoss("streaming load broken"); } } offset_ += n; result->assign(buffer_.get(), n); return Status::OK(); } uint64 StdinStreamReader::GetOffset() { return offset_; } Status StdinStreamReader::SetOffset(uint64 *offset) { if (offset_ < *offset) { tstring buf; size_t size = *offset - offset_; bool result = static_cast(input_stream_->read(buf.data(), size)); if (input_stream_->eof()) { return errors::OutOfRange("eof"); } else if (input_stream_->fail() || !result) { return errors::DataLoss("streaming load broken"); } offset_ = *offset; return Status::OK(); } if (offset_ == *offset) { return Status::OK(); } else { return errors::FailedPrecondition( "Cannot set the offset of stdin ahead of current position"); } } PBIterator::PBIterator(std::unique_ptr reader, FeaturePruningType feature_pruning_type) : feature_pruning_type_(feature_pruning_type), reader_(std::move(reader)), counter_(std::make_unique()) {} Status PBIterator::next(uint64 *offset, uint32_t *data_source_key, tstring *serialized) { uint8_t pb_type; reader_->SetOffset(offset); TF_RETURN_IF_ERROR( reader_->ReadPBBytes(&pb_type, 
data_source_key, serialized)); return Status::OK(); } Status PBIterator::next(uint64 *offset, Instance *pb) { tstring buf; uint32_t data_source_key; TF_RETURN_IF_ERROR(next(offset, &data_source_key, &buf)); bool ok = pb->ParseFromArray(buf.data(), buf.size()); pb->set_data_source_key(data_source_key); if (ok) { return Status::OK(); } else { return errors::FailedPrecondition("Failed to parse the Instance."); } } Status PBIterator::next(uint64 *offset, Example *pb) { tstring buf; uint32_t data_source_key; TF_RETURN_IF_ERROR(next(offset, &data_source_key, &buf)); bool ok = pb->ParseFromArray(buf.data(), buf.size()); pb->set_data_source_key(data_source_key); if (ok) { ExtendExample(pb); counter_->AddByteSize(pb->ByteSizeLong()); if (feature_pruning_type_ == PRUNING_FEATURE) { auto *named_features = pb->mutable_named_feature(); named_features->erase(named_features->begin(), named_features->end()); } else if (feature_pruning_type_ == PRUNING_RAW_FEATURE) { auto *named_raw_feature = pb->mutable_named_raw_feature(); named_raw_feature->erase(named_raw_feature->cbegin(), named_raw_feature->cend()); } counter_->AddByteSizePruned(pb->ByteSizeLong()); LOG_EVERY_N_SEC(INFO, 180) << counter_->DebugString(); return Status::OK(); } else { return errors::FailedPrecondition("Failed to parse the Example."); } } Status PBIterator::next(uint64 *offset, ExampleBatch *pb) { tstring buf; uint32_t data_source_key; TF_RETURN_IF_ERROR(next(offset, &data_source_key, &buf)); bool ok = pb->ParseFromArray(buf.data(), buf.size()); pb->set_data_source_key(data_source_key); counter_->AddByteSize(pb->ByteSizeLong()); if (feature_pruning_type_ == PRUNING_FEATURE) { auto *named_feature_list = pb->mutable_named_feature_list(); auto it = named_feature_list->begin(); while (it != named_feature_list->end()) { if (it->name() != "__LABEL__" && it->name() != "__LINE_ID__" && it->name() != "instance_weight") { // if erase, it will move to the next element named_feature_list->erase(it); } else { ++it; } } } 
else if (feature_pruning_type_ == PRUNING_RAW_FEATURE) { auto *named_raw_feature_list = pb->mutable_named_raw_feature_list(); named_raw_feature_list->erase(named_raw_feature_list->begin(), named_raw_feature_list->end()); } counter_->AddByteSizePruned(pb->ByteSizeLong()); LOG_EVERY_N_SEC(INFO, 180) << counter_->DebugString(); if (ok) { return Status::OK(); } else { return errors::FailedPrecondition("Failed to parse the ExampleBatch."); } } uint64 PBIterator::GetOffset() { return reader_->GetOffset(); } Status PBIterator::SetOffset(uint64 *offset) { return reader_->SetOffset(offset); } ExampleBatchIterator::ExampleBatchIterator( std::unique_ptr reader, FeaturePruningType feature_pruning_type, FeatureNameMapper *mapper) : PBIterator(std::move(reader), feature_pruning_type), mapper_(mapper) { arena_ = std::make_unique(); cur_ = google::protobuf::Arena::CreateMessage(arena_.get()); } Status ExampleBatchIterator::next_internal(uint64 *offset) { if (index_ < batch_size_ - 1) { index_++; return Status::OK(); } profiler::TraceMe activity([]() { return "ReadAndDeserialize"; }); uint8_t pb_type; uint32_t data_source_key; tstring buf; reader_->SetOffset(offset); arena_ = std::make_unique(); cur_ = google::protobuf::Arena::CreateMessage(arena_.get()); TF_RETURN_IF_ERROR(reader_->ReadPBBytes(&pb_type, &data_source_key, &buf)); bool ok = cur_->ParseFromArray(buf.data(), buf.size()); counter_->AddByteSize(cur_->ByteSizeLong()); cur_->set_data_source_key(data_source_key); if (!ok) { return errors::FailedPrecondition("Failed to parse the ExampleBatch."); } else { index_ = 0; batch_size_ = cur_->batch_size(); return Status::OK(); } } Status ExampleBatchIterator::next(uint64 *offset, uint32_t *data_source_key, tstring *serialized) { uint8_t pb_type; reader_->SetOffset(offset); TF_RETURN_IF_ERROR( reader_->ReadPBBytes(&pb_type, data_source_key, serialized)); return Status::OK(); } Status ExampleBatchIterator::next(uint64 *offset, ExampleBatch *pb) { uint8_t pb_type; uint32_t 
data_source_key; tstring buf; reader_->SetOffset(offset); TF_RETURN_IF_ERROR(reader_->ReadPBBytes(&pb_type, &data_source_key, &buf)); bool ok = pb->ParseFromArray(buf.data(), buf.size()); pb->set_data_source_key(data_source_key); counter_->AddByteSize(pb->ByteSizeLong()); if (feature_pruning_type_ == PRUNING_FEATURE) { auto *named_feature_list = pb->mutable_named_feature_list(); auto it = named_feature_list->begin(); while (it != named_feature_list->end()) { if (it->name() != "__LABEL__" && it->name() != "__LINE_ID__" && it->name() != "instance_weight") { // if erase, it will move to the next element named_feature_list->erase(it); } else { ++it; } } } else if (feature_pruning_type_ == PRUNING_RAW_FEATURE) { auto *named_raw_feature_list = pb->mutable_named_raw_feature_list(); named_raw_feature_list->erase(named_raw_feature_list->begin(), named_raw_feature_list->end()); } counter_->AddByteSizePruned(pb->ByteSizeLong()); LOG_EVERY_N_SEC(INFO, 180) << counter_->DebugString(); if (!ok) { return errors::FailedPrecondition("Failed to parse the ExampleBatch."); } else { return Status::OK(); } } Status ExampleBatchIterator::next(uint64 *offset, Instance *pb) { TF_RETURN_IF_ERROR(next_internal(offset)); return ExampleBatchToInstance(cur_, index_, pb); } Status ExampleBatchIterator::next(uint64 *offset, Example *pb) { profiler::TraceMe activity([]() { return "ExampleBatchIteratorNext"; }); TF_RETURN_IF_ERROR(next_internal(offset)); Status s = ExampleBatchToExample(cur_, index_, pb, feature_pruning_type_, mapper_); counter_->AddByteSizePruned(pb->ByteSizeLong()); LOG_EVERY_N_SEC(INFO, 3600) << counter_->DebugString(); return s; } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/data_reader.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_READER_H_ #define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_READER_H_ #include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "monolith/native_training/data/training_instance/cc/data_format_options.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" namespace tensorflow { namespace monolith_tf { enum FeaturePruningType { AS_IS = 0, PRUNING_FEATURE = 1, PRUNING_RAW_FEATURE = 2 }; namespace data_format { enum DataFormat { UNKNOW = 0, PLAINTEXT = 1, INSTANCE = 2, EXAMPLE = 3, EXAMPLEBATCH = 4 }; DataFormat StringToDataFormat(const std::string &type); }; // namespace data_format void ExtendExample(::monolith::io::proto::Example *pb, FeatureNameMapper *mapper = nullptr); Status ExampleToInstance(::monolith::io::proto::Example *example, ::parser::proto::Instance *instance); Status InstanceToExample(::parser::proto::Instance *instance, 
::monolith::io::proto::Example *example); Status ExampleBatchToInstance( ::monolith::io::proto::ExampleBatch *example_batch, int index, ::parser::proto::Instance *instance); Status ExampleBatchToExample(::monolith::io::proto::ExampleBatch *example_batch, int index, ::monolith::io::proto::Example *example, FeaturePruningType feature_pruning_type, FeatureNameMapper *mapper); template class BaseStreamReaderTmpl { public: explicit BaseStreamReaderTmpl(const DataFormatOptions &options) : options_(options) {} virtual ~BaseStreamReaderTmpl() = default; Status ReadPBBytes(uint8_t *pb_type, uint32_t *data_source_key, T *record) { TF_RETURN_IF_ERROR(ReadDataHeader(pb_type, data_source_key)); size_t size; TF_RETURN_IF_ERROR(ReadBinarySize(&size)); // Don't know whether FALLBACK_RESERVE_VALUE is in use. if (size == 0xfefefefe) { return errors::InvalidArgument("DEADBEEF value found"); } TF_RETURN_IF_ERROR(ReadNBytes(size, record)); return Status::OK(); } virtual uint64 GetOffset() = 0; virtual Status SetOffset(uint64 *offset) = 0; protected: virtual Status ReadNBytes(size_t n, T *result) = 0; private: Status ReadDataHeader(uint8_t *pb_type, uint32_t *data_source_key) { size_t size = 0, aggregate_page_sortid_size = 0; if (options_.lagrangex_header) { // *dtype = ins_type == 0 ? 
PROTO_INSTANCE : EXAMPLE_PB; TF_RETURN_IF_ERROR(ReadBinarySize(&size)); uint64_t lgx_header = static_cast(size); *pb_type = static_cast(lgx_header & 0xff); uint32_t source = static_cast(lgx_header); *data_source_key = (source >> 8) << 8; } else { *pb_type = 0; if (options_.kafka_dump_prefix) { TF_RETURN_IF_ERROR(ReadBinarySize(&size)); if (size == 0) { TF_RETURN_IF_ERROR(ReadBinarySize(&size)); } else { aggregate_page_sortid_size = size; } } if (options_.has_sort_id) { if (aggregate_page_sortid_size == 0) { TF_RETURN_IF_ERROR(ReadBinarySize(&size)); } else { size = aggregate_page_sortid_size; } T sort_id; TF_RETURN_IF_ERROR(ReadNBytes(size, &sort_id)); } if (options_.kafka_dump) { TF_RETURN_IF_ERROR(ReadBinarySize(&size)); } } return Status::OK(); } Status ReadBinarySize(size_t *size) { T result; TF_RETURN_IF_ERROR(ReadNBytes(sizeof(size_t), &result)); *size = static_cast(core::DecodeFixed64(result.data())); return Status::OK(); } DataFormatOptions options_; }; using BaseStreamReader = BaseStreamReaderTmpl; class StdinStreamReader : public BaseStreamReader { public: explicit StdinStreamReader(const DataFormatOptions &options, int64 buffer_size = 64 * 1024 * 1024); ~StdinStreamReader() override = default; uint64 GetOffset() override; Status SetOffset(uint64 *offset) override; protected: Status ReadNBytes(size_t n, tstring *result) override; private: std::shared_ptr input_stream_; std::unique_ptr buffer_; uint64 offset_; int64 buffer_size_; TF_DISALLOW_COPY_AND_ASSIGN(StdinStreamReader); }; class InputStreamReader : public BaseStreamReader { public: explicit InputStreamReader( const DataFormatOptions &options, std::unique_ptr input_stream); ~InputStreamReader() override = default; uint64 GetOffset() override; Status SetOffset(uint64 *offset) override; private: Status ReadNBytes(size_t n, tstring *result) override; std::unique_ptr input_stream_; bool last_read_failed_; TF_DISALLOW_COPY_AND_ASSIGN(InputStreamReader); }; enum InputCompressType { UNKNOW = 0, NO = 1, 
SNAPPY = 2, ZSTD = 3, ZLIB = 4, GZIP = 5, MAX = 6 }; class FileStreamReader : public InputStreamReader { public: explicit FileStreamReader(const DataFormatOptions &options, std::unique_ptr f, const InputCompressType compression_type, int64 buffer_size = 64 * 1024 * 1024); static InputCompressType GetCompressType(const bool use_snappy, const int32 compression_type) { if (compression_type < InputCompressType::UNKNOW || compression_type >= InputCompressType::MAX) { LOG(FATAL) << "GetCompressType error : compression_type" << compression_type; } InputCompressType ret = InputCompressType::NO; if (use_snappy) { if (compression_type != InputCompressType::SNAPPY && compression_type != InputCompressType::UNKNOW) { LOG(FATAL) << "GetCompressType error: " << use_snappy << "," << compression_type; } ret = InputCompressType::SNAPPY; } else { if (compression_type == InputCompressType::UNKNOW) { ret = InputCompressType::NO; } else { ret = static_cast(compression_type); } } return ret; } private: std::unique_ptr f_; }; template class StringStreamReader : public BaseStreamReaderTmpl { public: explicit StringStreamReader(const DataFormatOptions &options, T content) : BaseStreamReaderTmpl(options), content_(std::move(content)), cur_(0) {} Status ReadNBytes(size_t n, T *result) override { if (cur_ + n > content_.size()) { return errors::FailedPrecondition("request n error"); } if (n > 0 && cur_ == content_.size()) { return errors::OutOfRange("Size exceeds he content size."); } *result = T(content_.data() + cur_, n); cur_ += n; return Status::OK(); } uint64 GetOffset() override { return cur_; } Status SetOffset(uint64 *offset) override { cur_ = *offset; return Status::OK(); } private: T content_; int64 cur_; }; using ZeroCopyStringViewStreamReader = StringStreamReader; class PBIterator { public: PBIterator() = default; explicit PBIterator(std::unique_ptr reader, FeaturePruningType feature_pruning_type); virtual ~PBIterator() = default; virtual Status next(uint64 *offset, uint32_t 
*data_source_key, tstring *serialized);
  // Overloads that read the next record into a parsed protobuf message.
  virtual Status next(uint64 *offset, ::parser::proto::Instance *pb);
  virtual Status next(uint64 *offset, ::monolith::io::proto::Example *pb);
  virtual Status next(uint64 *offset, ::monolith::io::proto::ExampleBatch *pb);

  uint64 GetOffset();
  Status SetOffset(uint64 *offset);

 protected:
  FeaturePruningType feature_pruning_type_ = PRUNING_RAW_FEATURE;
  std::unique_ptr reader_;
  std::unique_ptr counter_;

  TF_DISALLOW_COPY_AND_ASSIGN(PBIterator);
};

// PBIterator variant constructed with a FeatureNameMapper in addition to the
// reader and pruning type. It keeps a current ExampleBatch pointer together
// with an index and a batch size.
// NOTE(review): presumably it hands out one row of the current batch per
// call - confirm against the .cc file.
class ExampleBatchIterator : public PBIterator {
 public:
  ExampleBatchIterator() = default;
  explicit ExampleBatchIterator(std::unique_ptr reader,
                                FeaturePruningType feature_pruning_type,
                                FeatureNameMapper *mapper);

  // NOTE(review): these two overloads match virtual functions of PBIterator
  // but are not marked `override`, unlike the two below.
  Status next(uint64 *offset, uint32_t *data_source_key, tstring *serialized);
  Status next(uint64 *offset, ::monolith::io::proto::ExampleBatch *pb);
  Status next(uint64 *offset, ::parser::proto::Instance *pb) override;
  Status next(uint64 *offset, ::monolith::io::proto::Example *pb) override;

 private:
  Status next_internal(uint64 *offset);

  int index_ = 0, batch_size_ = 0;
  // NOTE(review): cur_ and mapper_ have no default member initializer.
  monolith::io::proto::ExampleBatch *cur_;
  std::unique_ptr arena_;
  FeatureNameMapper *mapper_;

  TF_DISALLOW_COPY_AND_ASSIGN(ExampleBatchIterator);
};

/*
Sketch of the interface a THanler policy class has to provide for
PBIteratorWithDataFormatTrans below:

class THanler {
  struct CurOutput : public PBIteratorWithDataFormatTransBaseOutput {
  };
  template
  Status HandleReaderNextStauts(const Status &s, const TResult &result) {
    return errors::Unimplemented("not implement");
  }
  template
  Status HandleResult(TResult &&result, CurOutput *output) {
    return errors::Unimplemented("not implement");
  }
};
*/

// Base of every THanler::CurOutput; carries the status returned by the
// underlying PBIterator::next call.
struct PBIteratorWithDataFormatTransBaseOutput {
  Status reader_status;
};

// Reads one record from a PBIterator in the representation required by
// `output_pb_type`, converting between Example and Instance when the input
// and output formats differ, and forwards the result to the THanler policy.
template
class PBIteratorWithDataFormatTrans : public THanler {
 public:
  PBIteratorWithDataFormatTrans(data_format::DataFormat input_pb_type,
                                data_format::DataFormat output_pb_type)
      : input_pb_type_(input_pb_type), output_pb_type_(output_pb_type) {}

  // Reads the next record from `reader`.
  //
  // The raw reader status is stored in output->reader_status and passed to
  // THanler::HandleReaderNextStauts; when that returns a non-OK status it is
  // returned immediately. Otherwise the (possibly converted) record is moved
  // into THanler::HandleResult and its status is returned.
  Status GetNext(PBIterator *reader, typename THanler::CurOutput *output,
                 uint64 *offset) {
    Status s;
    if (output_pb_type_
== data_format::PLAINTEXT) {
      // Plain text output: hand over the serialized bytes untouched.
      tstring serialized;
      uint32_t data_source_key;
      output->reader_status =
          reader->next(offset, &data_source_key, &serialized);
      s = THanler::HandleReaderNextStauts(output->reader_status, serialized);
      if (!s.ok()) return s;
      s = THanler::HandleResult(std::move(serialized), output);
    } else if (input_pb_type_ == data_format::EXAMPLE &&
               output_pb_type_ == data_format::INSTANCE) {
      // example -> instance
      ::monolith::io::proto::Example exa_pb;
      output->reader_status = reader->next(offset, &exa_pb);
      s = THanler::HandleReaderNextStauts(output->reader_status, exa_pb);
      if (!s.ok()) return s;
      ::parser::proto::Instance ins_pb;
      ExampleToInstance(&exa_pb, &ins_pb);
      s = THanler::HandleResult(std::move(ins_pb), output);
    } else if (input_pb_type_ == data_format::INSTANCE &&
               output_pb_type_ == data_format::EXAMPLE) {
      // instance -> example
      ::parser::proto::Instance ins_pb;
      output->reader_status = reader->next(offset, &ins_pb);
      s = THanler::HandleReaderNextStauts(output->reader_status, ins_pb);
      if (!s.ok()) return s;
      ::monolith::io::proto::Example exa_pb;
      InstanceToExample(&ins_pb, &exa_pb);
      s = THanler::HandleResult(std::move(exa_pb), output);
    } else if (output_pb_type_ == data_format::EXAMPLE) {
      // any -> example
      ::monolith::io::proto::Example exa_pb;
      output->reader_status = reader->next(offset, &exa_pb);
      s = THanler::HandleReaderNextStauts(output->reader_status, exa_pb);
      if (!s.ok()) return s;
      s = THanler::HandleResult(std::move(exa_pb), output);
    } else if (output_pb_type_ == data_format::INSTANCE) {
      // any -> instance
      ::parser::proto::Instance ins_pb;
      output->reader_status = reader->next(offset, &ins_pb);
      s = THanler::HandleReaderNextStauts(output->reader_status, ins_pb);
      if (!s.ok()) return s;
      s = THanler::HandleResult(std::move(ins_pb), output);
    } else {
      // any -> example_batch
      ::monolith::io::proto::ExampleBatch eb_pb;
      output->reader_status = reader->next(offset, &eb_pb);
      s = THanler::HandleReaderNextStauts(output->reader_status, eb_pb);
      if (!s.ok()) return s;
      s =
THanler::HandleResult(std::move(eb_pb), output);
    }
    return s;
  }

  // NOTE(review): declared in the opposite order to the constructor's
  // initializer list; members are initialized in declaration order.
  data_format::DataFormat output_pb_type_;
  data_format::DataFormat input_pb_type_;
};

}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_READER_H_

================================================
FILE: monolith/native_training/data/training_instance/cc/data_writer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/data/training_instance/cc/data_writer.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/coding.h"

namespace tensorflow {
namespace monolith_tf {

BaseStreamWriter::BaseStreamWriter(DataFormatOptions options)
    : options_(std::move(options)) {}

// Writes the zero-filled framing fields selected by the format options.
// With `lagrangex_header` a single 8-byte field is written; otherwise 16
// bytes for `kafka_dump_prefix`, 8 for `has_sort_id` and 8 for `kafka_dump`
// are written in that order, each only when its option is set.
Status BaseStreamWriter::PrepareHeader() {
  const std::string eight_zero_bytes(8, '\0');

  if (options_.lagrangex_header) {
    return Write(eight_zero_bytes);
  }

  if (options_.kafka_dump_prefix) {
    const std::string sixteen_zero_bytes(16, '\0');
    TF_RETURN_IF_ERROR(Write(sixteen_zero_bytes));
  }
  if (options_.has_sort_id) {
    TF_RETURN_IF_ERROR(Write(eight_zero_bytes));
  }
  if (options_.kafka_dump) {
    TF_RETURN_IF_ERROR(Write(eight_zero_bytes));
  }
  return Status::OK();
}

// Writes one record: the framing header, the record length encoded by
// core::EncodeFixed64 into 8 bytes, then the payload itself.
Status BaseStreamWriter::WriteRecord(absl::string_view record) {
  const Status header_status = PrepareHeader();
  if (!header_status.ok()) {
    return header_status;
  }

  char length_prefix[8];
  core::EncodeFixed64(length_prefix, record.size());
  const absl::string_view length_view(length_prefix, 8);
  TF_RETURN_IF_ERROR(Write(length_view));

  return Write(record);
}

// Appends the bytes to the caller-provided output string.
Status StringStreamWriter::Write(absl::string_view s) {
  absl::StrAppend(out_, s);
  return Status::OK();
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/data_writer.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_WRITER_H_ #define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_WRITER_H_ #include "absl/strings/string_view.h" #include "monolith/native_training/data/training_instance/cc/data_format_options.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { namespace monolith_tf { class BaseStreamWriter { public: explicit BaseStreamWriter(DataFormatOptions options); Status WriteRecord(absl::string_view record); protected: virtual Status Write(absl::string_view s) = 0; private: Status PrepareHeader(); private: DataFormatOptions options_; }; class StringStreamWriter : public BaseStreamWriter { public: explicit StringStreamWriter(DataFormatOptions options, std::string* out) : BaseStreamWriter(std::move(options)), out_(out) {} private: Status Write(absl::string_view s) override; std::string* out_; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_DATA_WRITER_H_ ================================================ FILE: monolith/native_training/data/training_instance/cc/fid.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_FID_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_FID_H_ #include #include union FIDV2; union FIDV1 { struct Underlying { uint64_t signature : 54; uint64_t slot : 10; Underlying(uint64_t slot, uint64_t signature) : slot(slot), signature(signature) {} }; Underlying underlying; uint64_t value; FIDV1() : underlying(0, 0) {} FIDV1(uint64_t slot, int64_t signature) : underlying(slot, signature) { if (slot >= 1024) { throw std::invalid_argument("slot should be less than 1024, while got " + std::to_string(slot)); } } FIDV1(uint64_t fid_v1_value) : value(fid_v1_value) {} operator uint64_t() const { return this->value; } [[nodiscard]] uint64_t slot() const { return this->underlying.slot; } [[nodiscard]] uint64_t signature() const { return this->underlying.signature; } [[nodiscard]] std::string DebugString() const { std::stringstream ss; ss << value << "(v1|slot=" << underlying.slot << "|sig=" << underlying.signature << ")"; return ss.str(); } [[nodiscard]] FIDV2 ConvertAsV2() const; }; union FIDV2 { struct Underlying { uint64_t signature : 48; uint64_t slot : 15; uint64_t reserved : 1; Underlying(uint64_t slot, uint64_t signature) : reserved(0), slot(slot), signature(signature) {} }; Underlying underlying; uint64_t value; FIDV2() : underlying(0, 0) {} FIDV2(uint64_t slot, uint64_t signature) : underlying(slot, signature) { if (slot >= 32768) { throw std::invalid_argument("slot should be less than 32768, while got " + std::to_string(slot)); } } FIDV2(uint64_t fid_v2_value) : value(fid_v2_value) { if (this->underlying.reserved == 1) { throw std::invalid_argument("slot should be less than 32768, while got " + std::to_string(this->slot() + 32768)); } } operator uint64_t() const { return value; } [[nodiscard]] uint64_t slot() const { return this->underlying.slot; } [[nodiscard]] uint64_t signature() const { return this->underlying.signature; } [[nodiscard]] std::string DebugString() 
const { std::stringstream ss; ss << value << "(v2|slot=" << underlying.slot << "|sig=" << underlying.signature << ")"; return ss.str(); } }; FIDV2 FIDV1::ConvertAsV2() const { return {this->underlying.slot, this->underlying.signature}; } namespace std { template <> struct hash { std::size_t operator()(FIDV1 fid) const { return std::hash()(fid); } }; template <> struct hash { std::size_t operator()(FIDV2 fid) const { return std::hash()(fid); } }; } // namespace std #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_FID_H_ ================================================ FILE: monolith/native_training/data/training_instance/cc/fid_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/data/training_instance/cc/fid.h"
#include "monolith/native_training/data/training_instance/cc/reader_util.h"

namespace {

using tensorflow::monolith_tf::GetFidV1;
using tensorflow::monolith_tf::GetFidV2;

// Checks the FIDV1 union against the GetFidV1 helper: size, slot/signature
// accessors, signature truncation and rejection of out-of-range slots.
TEST(FIDTest, FIDV1) {
  // The union must occupy exactly 8 bytes.
  EXPECT_EQ(sizeof(FIDV1), 8);

  // Normal case: slot and signature both fit.
  FIDV1 fid1(1, 100);
  EXPECT_EQ(fid1.slot(), 1);
  EXPECT_EQ(fid1.signature(), 100);
  EXPECT_EQ(fid1, GetFidV1(1, 100));

  // Corner case 1: largest valid slot; bit 54 of the signature does not fit
  // in the 54-bit field, so the stored signature is expected to be 0.
  FIDV1 fid2(1023, 1LL << 54);
  EXPECT_EQ(fid2.slot(), 1023);
  EXPECT_EQ(fid2.signature(), 0);
  EXPECT_EQ(fid2, GetFidV1(1023, 1LL << 54));

  // Corner case 2: slot 1024 is out of range, the constructor is expected to
  // throw. The expectations after the constructor are not reached when it
  // throws.
  EXPECT_THROW(
      {
        FIDV1 fid3(1024, 1LL << 54);
        EXPECT_EQ(fid3.slot(), 0);
        EXPECT_EQ(fid3.signature(), 0);
        EXPECT_EQ(fid3, GetFidV1(1024, 1LL << 54));
      },
      std::invalid_argument);

  // Corner case 3: slot 1025 is out of range as well.
  EXPECT_THROW(
      {
        FIDV1 fid4(1025, 1LL << 54 | 1);
        EXPECT_EQ(fid4.slot(), 1);
        EXPECT_EQ(fid4.signature(), 1);
        EXPECT_EQ(fid4, GetFidV1(1025, 1LL << 54 | 1));
      },
      std::invalid_argument);
}

// Same checks for FIDV2 (15-bit slot, 48-bit signature) against GetFidV2.
TEST(FIDTest, FIDV2) {
  // The union must occupy exactly 8 bytes.
  EXPECT_EQ(sizeof(FIDV2), 8);

  // Normal case.
  FIDV2 fid1(1, 100);
  EXPECT_EQ(fid1.slot(), 1);
  EXPECT_EQ(fid1.signature(), 100);
  EXPECT_EQ(fid1, GetFidV2(1, 100));

  // Corner case 1: a slot that would be invalid for V1 is valid for V2; the
  // signature bit 54 is outside the 48-bit field.
  FIDV2 fid2(1024, 1LL << 54);
  EXPECT_EQ(fid2.slot(), 1024);
  EXPECT_EQ(fid2.signature(), 0);
  EXPECT_EQ(fid2, GetFidV2(1024, 1LL << 54));

  // Corner case 2: largest valid slot, signature bit 48 is truncated.
  FIDV2 fid3(32767, 1LL << 48);
  EXPECT_EQ(fid3.slot(), 32767);
  EXPECT_EQ(fid3.signature(), 0);
  EXPECT_EQ(fid3, GetFidV2(32767, 1LL << 48));

  // Corner case 3: slot 32768 is out of range, the constructor is expected
  // to throw.
  EXPECT_THROW(
      {
        FIDV2 fid4(32768, 1LL << 48);
        EXPECT_EQ(fid4.slot(), 0);
        EXPECT_EQ(fid4.signature(), 0);
        // GetFidV2 has a tiny bug: the result is shifted left and right by
        // one to clear its top bit before comparing.
        EXPECT_EQ(fid4, (GetFidV2(32768, 1LL << 48) << 1) >> 1);
      },
      std::invalid_argument);

  // Corner case 4: slot 32769 is out of range as well.
  EXPECT_THROW(
      {
        FIDV2 fid5(32769, 1LL << 48 | 1);
        EXPECT_EQ(fid5.slot(), 1);
        EXPECT_EQ(fid5.signature(), 1);
        // GetFidV2 has a tiny bug: see corner case 3.
        EXPECT_EQ(fid5, (GetFidV2(32769, 1LL << 48 | 1) << 1) >> 1);
      },
std::invalid_argument);
}

// Checks that FIDV1::ConvertAsV2 agrees with convert_fid_v1_to_v2.
TEST(FIDTest, FIDV1ConvertV2) {
  // Normal case.
  FIDV1 fid_v1(1, 100);
  FIDV2 fid_v2 = fid_v1.ConvertAsV2();
  EXPECT_EQ(fid_v2.slot(), 1);
  EXPECT_EQ(fid_v2.signature(), 100);
  EXPECT_EQ(fid_v2, convert_fid_v1_to_v2(fid_v1));

  // Corner case 1: largest valid V1 slot with a truncated signature.
  FIDV1 fid_v1_1(1023, 1LL << 54);
  FIDV2 fid_v2_1 = fid_v1_1.ConvertAsV2();
  EXPECT_EQ(fid_v2_1.slot(), 1023);
  EXPECT_EQ(fid_v2_1.signature(), 0);
  EXPECT_EQ(fid_v2_1, convert_fid_v1_to_v2(fid_v1_1));

  // Corner case 2: the V1 constructor is expected to throw for slot 1024, so
  // the conversion is never reached.
  EXPECT_THROW(
      {
        FIDV1 fid_v1_2(1024, 1LL << 54);
        FIDV2 fid_v2_2 = fid_v1_2.ConvertAsV2();
        EXPECT_EQ(fid_v2_2.slot(), 0);
        EXPECT_EQ(fid_v2_2.signature(), 0);
        EXPECT_EQ(fid_v2_2, convert_fid_v1_to_v2(fid_v1_2));
      },
      std::invalid_argument);
}

}  // namespace

================================================
FILE: monolith/native_training/data/training_instance/cc/instance_dataset_kernel.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"

#include "monolith/native_training/data/training_instance/cc/data_reader.h"

namespace tensorflow {
namespace data {
namespace monolith_tf {

using ::tensorflow::monolith_tf::BaseStreamReader;
using ::tensorflow::monolith_tf::DataFormatOptions;
using ::tensorflow::monolith_tf::FileStreamReader;
using ::tensorflow::monolith_tf::InputCompressType;
using ::tensorflow::monolith_tf::PBIterator;
using ::tensorflow::monolith_tf::PRUNING_RAW_FEATURE;
using ::tensorflow::monolith_tf::StdinStreamReader;

// Record-format options extended with the compression settings of this op.
struct DsOptions : DataFormatOptions {
  bool use_snappy = false;
  int32 compression_type = InputCompressType::UNKNOW;
};

// This is the instance dataset op and used in the estimator as input fn.
// It yields one scalar string tensor (a serialized record) per element, read
// either from `file_name` or, when the name is empty, from stdin.
class InstanceDatasetOp : public DatasetOpKernel {
 public:
  static constexpr const char* const kDatasetType = "PbInstance";
  static constexpr const char* const kFileName = "file_name";
  static constexpr const char* const kUseSnappy = "use_snappy";
  static constexpr const char* const kHasSortId = "has_sort_id";
  static constexpr const char* const kKafkaDump = "kafka_dump";
  static constexpr const char* const kKafkaDumpPrefix = "kafka_dump_prefix";
  static constexpr const char* const kCompressionType = "compression_type";

  explicit InstanceDatasetOp(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompressionType, &compression_type_));
  }

  ~InstanceDatasetOp() {}

 private:
  // Parses the scalar inputs into DsOptions and creates the dataset.
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    tstring file_name;
    DsOptions options;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<tstring>(ctx, kFileName, &file_name));
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kUseSnappy, &options.use_snappy));
    options.compression_type = compression_type_;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kHasSortId, &options.has_sort_id));
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kKafkaDump, &options.kafka_dump));
    OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, kKafkaDumpPrefix,
                                                  &options.kafka_dump_prefix));
    output_ = new Dataset(ctx, file_name, options);
    *output = output_;
  }

  class Dataset : public DatasetBase {
   public:
    explicit Dataset(OpKernelContext* ctx, const tstring& file_name,
                     const DsOptions& options)
        : DatasetBase(DatasetContext(ctx)),
          file_name_(file_name),
          options_(options) {}

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(
          Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetType)});
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
      return *dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>{TensorShape({})};
      return *shapes;
    }

    string DebugString() const override {
      return ("This is the customized Instance Dataset: " + file_name_);
    }

    Status CheckExternalState() const override { return Status::OK(); }

   private:
    // Serializes the dataset as a graph node: the five scalar inputs in the
    // op's declared order plus the compression_type attribute.
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* filename = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(file_name_, &filename));
      Node* use_snappy = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(options_.use_snappy, &use_snappy));
      Node* has_sort_id = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(options_.has_sort_id, &has_sort_id));
      Node* kafka_dump = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(options_.kafka_dump, &kafka_dump));
      Node* kafka_dump_prefix = nullptr;
      TF_RETURN_IF_ERROR(
          b->AddScalar(options_.kafka_dump_prefix, &kafka_dump_prefix));
      AttrValue compression_type;
      b->BuildAttrValue(options_.compression_type, &compression_type);
      TF_RETURN_IF_ERROR(b->AddDataset(
          this,
          {filename, use_snappy, has_sort_id, kafka_dump, kafka_dump_prefix},
          {{kCompressionType, compression_type}}, output));
      return Status::OK();
    }

    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      // Produces the next serialized record. The reader is created lazily on
      // the first call. Any read failure drops the reader; OutOfRange is
      // translated into end_of_sequence, every other error is returned.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        out_tensors->reserve(1);
        mutex_lock l(mu_);
        if (!reader_) {
          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
        }
        out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
                                  TensorShape({}));
        uint32_t data_source_key;
        Status s = reader_->next(&offset_, &data_source_key,
                                 &out_tensors->back().scalar<tstring>()());
        if (s.ok()) {
          static monitoring::CounterCell* bytes_counter =
              metrics::GetTFDataBytesReadCounter(kDatasetType);
          bytes_counter->IncrementBy(
              out_tensors->back().scalar<tstring>()().size());
          *end_of_sequence = false;
          num_random_samples_++;
          offset_ = reader_->GetOffset();
          return Status::OK();
        }
        // Nothing was read: remove the placeholder tensor again.
        out_tensors->pop_back();
        ResetStreamsLocked();
        if (errors::IsOutOfRange(s)) {
          *end_of_sequence = true;
          return Status::OK();
        }
        return s;
      }

     private:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeSourceNode(std::move(args));
      }

      // Stores the sample counter and the reader offset.
      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        LOG(INFO) << "Save function is not supported yet.";
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
                                               num_random_samples_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("offset_"), offset_));
        return Status::OK();
      }

      // Restores the values written by SaveInternal. The offset is reset to
      // 0 when reading from stdin (empty file name).
      // NOTE(review): the restored offset is not applied to a reader here or
      // in SetupStreamsLocked.
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        LOG(INFO) << "Restore function is not supported yet.";
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
                                              &num_random_samples_));
        int64 offset;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset_"), &offset));
        if (dataset()->file_name_.empty()) {
          offset_ = 0;
        } else {
          offset_ = offset;
        }
        return Status::OK();
      }

      // Sets up reader streams to read from filename, or from stdin when the
      // file name is empty.
      Status SetupStreamsLocked(Env* env) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        std::unique_ptr<BaseStreamReader> stream_reader;
        if (dataset()->file_name_.empty()) {
          stream_reader =
              std::make_unique<StdinStreamReader>(dataset()->options_);
        } else {
          std::unique_ptr<RandomAccessFile> f;
          TF_RETURN_IF_ERROR(
              env->NewRandomAccessFile(dataset()->file_name_, &f));
          auto compression_type = FileStreamReader::GetCompressType(
              dataset()->options_.use_snappy,
              dataset()->options_.compression_type);
          stream_reader = std::make_unique<FileStreamReader>(
              dataset()->options_, std::move(f), compression_type);
        }
        reader_ = absl::make_unique<PBIterator>(std::move(stream_reader),
                                                PRUNING_RAW_FEATURE);
        return Status::OK();
      }

      // Resets all reader streams.
      void ResetStreamsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        reader_.reset();
      }

      mutex mu_;
      std::unique_ptr<PBIterator> reader_ TF_GUARDED_BY(mu_);
      int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0;
      uint64 offset_ TF_GUARDED_BY(mu_) = 0;
    };

    tstring file_name_;
    DsOptions options_;
  };

  int32 compression_type_;
  Dataset* output_ = nullptr;
};

namespace {
REGISTER_KERNEL_BUILDER(Name("InstanceDataset").Device(DEVICE_CPU),
                        InstanceDatasetOp);
}  // namespace

}  // namespace monolith_tf
}  // namespace data
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/instance_dataset_ops.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

REGISTER_OP("InstanceDataset")
    .Input("file_name: string")
    .Input("use_snappy: bool")
    .Input("has_sort_id: bool")
    .Input("kafka_dump: bool")
    .Input("kafka_dump_prefix: bool")
    .Output("handle: variant")
    .Attr("compression_type: int = 0")
    .SetDoNotOptimize()  // Source dataset ops must disable constant folding.
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // Every boolean flag is consumed as a scalar by the kernel, so all
      // four of them (inputs 1..4, including kafka_dump_prefix) are
      // validated here.
      for (int flag_input = 1; flag_input <= 4; ++flag_input) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(flag_input), 0, &unused));
      }
      return shape_inference::ScalarShape(c);
    });

}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/instance_processor.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "absl/strings/str_format.h" #include "absl/time/clock.h" #include "gflags/gflags.h" #include "idl/matrix/proto/proto_parser.pb.h" #include "monolith/native_training/data/training_instance/cc/data_reader.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "third_party/nlohmann/json.hpp" DEFINE_bool(kafka_dump, false, "kafka_dump"); DEFINE_bool(kafka_dump_prefix, false, "kafka_dump_prefix"); DEFINE_bool(has_sort_id, true, "has_sort_id"); DEFINE_string(has_fids, "", "The instance of interest should contain at least one of the " "given fids, or it will be dropped"); DEFINE_string(has_actions, "", "The instance of interest should contain at least one of the " "given actions, or it will be dropped"); DEFINE_string( filter_fids, "", "The instance will be dropped if it contains any one of the given fids."); DEFINE_string( select_fids, "", "The instance of interest should contain all of the given fids, or it " "will be dropped."); DEFINE_int64( req_time_min, 0, "The instance of interest should satisfy line_id.req_time >= req_time_min"); DEFINE_int32(buffer_size, 32, "The buffer number of instance"); using tensorflow::Status; using tensorflow::tstring; using tensorflow::uint64; using tensorflow::monolith_tf::IsInstanceOfInterest; using ::tensorflow::monolith_tf::PBIterator; using ::tensorflow::monolith_tf::DataFormatOptions; using ::tensorflow::monolith_tf::StdinStreamReader; using ::tensorflow::monolith_tf::StrToIntegerSet; int main(int argc, char* argv[]) { ::gflags::ParseCommandLineFlags(&argc, &argv, true); auto has_fids = StrToIntegerSet(FLAGS_has_fids); auto filter_fids = StrToIntegerSet(FLAGS_filter_fids); auto select_fids = StrToIntegerSet(FLAGS_select_fids); auto has_actions = StrToIntegerSet(FLAGS_has_actions); absl::Time t = absl::FromUnixSeconds(FLAGS_req_time_min); nlohmann::json json; json["kafka_dump"] = FLAGS_kafka_dump; json["kafka_dump_prefix"] = FLAGS_kafka_dump_prefix; json["has_sort_id"] = 
FLAGS_has_sort_id; json["has_fids"] = has_fids; json["filter_fids"] = filter_fids; json["select_fids"] = select_fids; json["has_actions"] = has_actions; json["req_time_min"] = FLAGS_req_time_min; json["req_time_min_human_readable"] = absl::FormatTime(t); std::cerr << absl::StrFormat("%s Instance processor config:\n%s", absl::FormatTime(absl::Now()), json.dump(2)) << std::endl; DataFormatOptions options; options.kafka_dump = FLAGS_kafka_dump; options.kafka_dump_prefix = FLAGS_kafka_dump_prefix; options.has_sort_id = FLAGS_has_sort_id; PBIterator reader(std::make_unique(options), tensorflow::monolith_tf::PRUNING_RAW_FEATURE); uint64 offset = 0, count = 0, total = 0; tstring sort_id, serialized_instance; std::stringstream ss; uint32_t data_source_key; while (reader.next(&offset, &data_source_key, &serialized_instance) == Status::OK()) { offset = reader.GetOffset(); parser::proto::Instance instance; instance.ParseFromArray(serialized_instance.data(), serialized_instance.size()); ++total; if (IsInstanceOfInterest(instance, filter_fids, has_fids, select_fids, has_actions, FLAGS_req_time_min, {})) { std::string serialized_instance = instance.SerializeAsString(); uint64_t size_of_sort_id = sort_id.length(); uint64_t size_of_pb = serialized_instance.length(); ss.write(reinterpret_cast(&size_of_sort_id), sizeof(size_of_sort_id)); ss.write(sort_id.data(), sort_id.length()); ss.write(reinterpret_cast(&size_of_pb), sizeof(size_of_pb)); ss.write(const_cast(serialized_instance.data()), serialized_instance.length()); ++count; if (count % FLAGS_buffer_size == 0) { std::string output = ss.str(); ss.str(""); std::cout.write(output.data(), output.length()); std::cout.flush(); } } if (total % 1000000 == 0) { std::cerr << absl::StrFormat( "%s Instance processor input_num = %ld, output_num = %ld.", absl::FormatTime(absl::Now()), total, count) << std::endl; } } if (count % FLAGS_buffer_size) { std::string output = ss.str(); std::cout.write(output.data(), output.length()); 
std::cout.flush();
}
std::cerr << absl::StrFormat(
                 "%s Instance processor input_num = %ld, output_num = %ld. "
                 "Successfully finished!",
                 absl::FormatTime(absl::Now()), total, count)
          << std::endl;
return 0;
}

================================================
FILE: monolith/native_training/data/training_instance/cc/instance_reader.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Command-line tool that reads Instance / Example / ExampleBatch records from
// a file or stdin, optionally runs them through a transform, and prints each
// record as JSON or as protobuf debug text.
//
// NOTE(review): in this copy the header names of the two bare `#include`
// lines below and the angle-bracket template parameter/argument lists
// throughout the file appear to be missing -- confirm against the original.

#include
#include

#include "absl/strings/str_format.h"
#include "absl/time/clock.h"
#include "idl/matrix/proto/proto_parser.pb.h"
#include "monolith/native_training/data/training_instance/cc/data_reader.h"
#include "monolith/native_training/data/training_instance/cc/fid.h"
#include "monolith/native_training/data/transform/cc/transforms.h"
#include "monolith/native_training/data/transform/transform_config.pb.h"
#include "tensorflow/core/platform/base64.h"
#include "third_party/cli11/CLI11.hpp"
#include "third_party/nlohmann/json.hpp"

namespace tf = tensorflow;
using idl::matrix::proto::LineId;
using monolith::io::proto::Example;
using monolith::io::proto::ExampleBatch;
using monolith::io::proto::Feature;
using monolith::native_training::data::TransformConfig;
using parser::proto::Instance;
using tf::monolith_tf::FeatureNameMapper;
using tf::monolith_tf::FeaturePruningType;
using tf::monolith_tf::FileStreamReader;
using tf::monolith_tf::InputCompressType;
using tf::monolith_tf::PBIterator;
using tf::monolith_tf::StdinStreamReader;
using tf::monolith_tf::TransformInterface;

// Values of the command-line flags; filled in by AddOptions() + CLI parsing.
struct Options {
  int verbose_level = 0;
  // Input path; when empty the tool reads from stdin.
  std::string filepath;
  // Record type: "instance", "example" or "example_batch".
  std::string dtype = "instance";
  // "none" or "snappy"; only consulted when reading from a file.
  std::string compression_type = "none";
  // TransformConfig in protobuf text format (parsed in main()).
  std::string config;
  bool lagrangex_header = false;
  bool kafka_dump = false;
  bool kafka_dump_prefix = false;
  bool has_sort_id = true;
  // Upper bound on the number of ReadOne() calls made by ReadAndSerialize().
  int64_t limit = std::numeric_limits::max();
  // "json" or "pbtxt".
  std::string output_format = "json";
  // When true, records are read but not printed.
  bool silent = false;
};

// Registers every flag on `app`, binding each one to a field of `options`.
// The dtype, compression_type and format flags get a validator that returns
// an error message for unknown values and an empty string otherwise.
void AddOptions(CLI::App& app, Options* options) {
  app.add_option("-v,--verbose", options->verbose_level,
                 "Verbose level, default: 0");
  app.add_option("-i,--input", options->filepath,
                 "Input filepath, read from stdin if empty!");
  app.add_option("-d,--dtype", options->dtype,
                 "Data type, default: instance, choices = [instance, example, "
                 "example_batch]")
      ->check([](std::string choice) {
        std::unordered_set choices = {"instance", "example", "example_batch"};
        if (!choices.count(choice)) {
          return absl::StrFormat("Invalid dtype: %s", choice);
        }
        return std::string();
      });
  app.add_option("-c,--compression_type", options->compression_type,
                 "Compression type, default: none, choices = [none, snappy]")
      ->check([](std::string choice) {
        std::unordered_set choices = {"none", "snappy"};
        if (!choices.count(choice)) {
          return absl::StrFormat("Invalid compression_type: %s", choice);
        }
        return std::string();
      });
  app.add_option("--lagrangex_header", options->lagrangex_header,
                 "default: false");
  app.add_option("-k,--kafka_dump", options->kafka_dump, "default: false");
  app.add_option("--kafka_dump_prefix", options->kafka_dump_prefix,
                 "default: false");
  app.add_option("--has_sort_id", options->has_sort_id, "default: true");
  app.add_option("--config", options->config,
                 "Transform config, plain text, e.g. configs { basic_config { "
                 "filter_by_fid { select_fids: 18428264561369945341 } } }");
  app.add_option("-l,--limit", options->limit,
                 "Output limit number records, default: inf");
  app.add_option("-f,--format", options->output_format,
                 "Output format, default: json, choices = [json, pbtxt]")
      ->check([](std::string choice) {
        std::unordered_set choices = {"json", "pbtxt"};
        if (!choices.count(choice)) {
          return absl::StrFormat("Invalid output format: %s", choice);
        }
        return std::string();
      });
  app.add_option("--silent", options->silent,
                 "Output nothing but statistics information.");
}

// Pulls serialized records from a stream reader, parses them into protobuf
// messages, passes them through the configured transform and hands them out
// one at a time. Transformed records are queued in a per-type buffer because
// one input record may yield zero or several outputs.
class InputReader {
 public:
  // Builds the stream reader (stdin when `options.filepath` is empty,
  // otherwise a random-access file; TF_CHECK_OK guards the file open), wraps
  // it in a record iterator, and creates the transform from `config`.
  explicit InputReader(const Options& options, TransformConfig config)
      : config_(std::move(config)),
        offset_(0),
        total_(0),
        count_(0),
        end_of_sequence_(false) {
    tf::monolith_tf::DataFormatOptions ds_options{
        options.lagrangex_header, options.kafka_dump_prefix,
        options.has_sort_id, options.kafka_dump};
    tf::Env* env = tf::Env::Default();
    std::unique_ptr stream_reader;
    if (options.filepath.empty()) {
      stream_reader = std::make_unique(ds_options);
    } else {
      std::unique_ptr f;
      TF_CHECK_OK(env->NewRandomAccessFile(options.filepath, &f));
      // Any compression_type other than "none" is treated as snappy.
      stream_reader = std::make_unique(
          ds_options, std::move(f),
          options.compression_type == "none" ? InputCompressType::NO
                                             : InputCompressType::SNAPPY);
    }
    if (options.dtype == "instance" || options.dtype == "example") {
      reader_ = absl::make_unique(
          std::move(stream_reader), FeaturePruningType::AS_IS);
    } else {
      // Every other dtype (i.e. example_batch) additionally gets a mapper
      // object that is owned by this class and passed to the iterator.
      mapper_ = std::make_unique();
      reader_ = absl::make_unique(
          std::move(stream_reader), FeaturePruningType::AS_IS, mapper_.get());
    }
    transform_ = tf::monolith_tf::NewTransformFromConfig(config_);
  }

  // Produces the next transformed record into `output`.
  //
  // When the buffer for T is empty, records are read until one of them yields
  // at least one transformed output, the reader stops returning OK, or the
  // end-of-sequence flag is set. `total_` counts parsed input records and
  // `count_` counts transformed outputs.
  //
  // Returns true when a record was swapped into `output`; false when the
  // buffer stayed empty or when a record failed to parse (logged as ERROR).
  // A std::out_of_range raised while reading is logged at INFO and any other
  // std::exception at ERROR; both set the end-of-sequence flag.
  template
  bool ReadOne(T* output) {
    if (IsBufferEmpty()) {
      tf::tstring serialized_instance;
      try {
        uint32_t data_source_key = 0;
        while (!end_of_sequence_ &&
               reader_->next(&offset_, &data_source_key, &serialized_instance)
                   .ok()) {
          offset_ = reader_->GetOffset();
          std::shared_ptr sample = std::make_shared();
          if (!sample->ParseFromArray(serialized_instance.data(),
                                      serialized_instance.size())) {
            LOG(ERROR) << "Unable to parse data. Data might be corrupted";
            return false;
          }
          ++total_;
          std::vector> outputs;
          Transform(sample, &outputs);
          count_ += outputs.size();
          for (const auto& sample : outputs) {
            PushIntoBuffer(sample);
          }
          // Stop as soon as there is something to hand out; otherwise the
          // transform dropped the record and the next one is tried.
          if (!outputs.empty()) {
            break;
          }
        }
      } catch (const std::out_of_range& e) {
        end_of_sequence_ = true;
        LOG(INFO) << e.what();
      } catch (const std::exception& e) {
        end_of_sequence_ = true;
        LOG(ERROR) << e.what();
      }
    }
    if (!IsBufferEmpty()) {
      std::shared_ptr front;
      PopFromBuffer(&front);
      front->Swap(output);
      return true;
    }
    return false;
  }

  // Like ReadOne(), but stores the record's wire-format bytes in
  // `serialized`. `serialized` is left untouched when false is returned.
  template
  bool ReadOneSerialized(std::string* serialized) {
    T t;
    bool success = ReadOne(&t);
    if (success) {
      *serialized = t.SerializeAsString();
    }
    return success;
  }

 private:
  // Three type-selected overloads: two delegate to `transform_`, the third
  // appends `input` to `output` unchanged.
  // NOTE(review): the enable_if conditions are missing in this copy;
  // presumably the pass-through overload is the ExampleBatch one, since
  // main() rejects a transform config for example_batch -- confirm.
  template
  typename std::enable_if::value, void>::type
  Transform(std::shared_ptr input, std::vector>* output) {
    transform_->Transform(input, output);
  }

  template
  typename std::enable_if::value, void>::type
  Transform(std::shared_ptr input, std::vector>* output) {
    transform_->Transform(input, output);
  }

  template
  typename std::enable_if::value, void>::type
  Transform(std::shared_ptr input, std::vector>* output) {
    output->push_back(input);
  }

  // Type-selected accessors for the three record queues below.
  template
  typename std::enable_if::value, bool>::type
  IsBufferEmpty() {
    return instance_buffer_.empty();
  }

  template
  typename std::enable_if::value, bool>::type
  IsBufferEmpty() {
    return example_buffer_.empty();
  }

  template
  typename std::enable_if::value, bool>::type
  IsBufferEmpty() {
    return example_batch_buffer_.empty();
  }

  template
  typename std::enable_if::value, void>::type
  PushIntoBuffer(std::shared_ptr t) {
    instance_buffer_.push(t);
  }

  template
  typename std::enable_if::value, void>::type
  PushIntoBuffer(std::shared_ptr t) {
    example_buffer_.push(t);
  }

  template
  typename std::enable_if::value, void>::type
  PushIntoBuffer(std::shared_ptr t) {
    example_batch_buffer_.push(t);
  }

  // Pops the oldest queued record into `*t`; callers check IsBufferEmpty()
  // first.
  template
  typename std::enable_if::value, void>::type
  PopFromBuffer(std::shared_ptr* t) {
    *t = instance_buffer_.front();
    instance_buffer_.pop();
  }

  template
  typename std::enable_if::value, void>::type
  PopFromBuffer(std::shared_ptr* t) {
    *t = example_buffer_.front();
    example_buffer_.pop();
  }

  template
  typename std::enable_if::value, void>::type
  PopFromBuffer(std::shared_ptr* t) {
    *t = example_batch_buffer_.front();
    example_batch_buffer_.pop();
  }

  TransformConfig config_;
  std::unique_ptr transform_;
  std::unique_ptr reader_;
  std::unique_ptr mapper_;  // Only set for the non-instance/example path.
  std::queue> instance_buffer_;
  std::queue> example_buffer_;
  std::queue> example_batch_buffer_;
  tf::uint64 offset_;     // Reader offset after the last successful next().
  uint64_t total_;        // Number of input records parsed so far.
  uint64_t count_;        // Number of records produced by the transform.
  bool end_of_sequence_;  // Set once reading raised an exception.
};

// Reads up to `options.limit` records through `reader` into `*t` and, unless
// `options.silent` is set, writes each one to stdout: JSON post-processed by
// `callback_fn` when the output format is "json", DebugString() otherwise.
// Stops early at the first failed ReadOne().
// NOTE(review): the status returned by MessageToJsonString is not checked.
template
void ReadAndSerialize(
    InputReader& reader, T* t, const Options& options,
    const std::function& callback_fn) {
  auto json_options = google::protobuf::util::JsonOptions();
  json_options.add_whitespace = true;
  json_options.preserve_proto_field_names = true;
  for (int64_t i = 0; i < options.limit; ++i) {
    if (reader.ReadOne(t)) {
      if (!options.silent) {
        std::string output;
        if (options.output_format == "json") {
          google::protobuf::util::MessageToJsonString(*t, &output,
                                                      json_options);
          output = callback_fn(output);
        } else {
          output = t->DebugString();
        }
        std::cout.write(output.data(), output.length());
        std::cout.flush();
      }
    } else {
      break;
    }
  }
}

// JSON serializers for fid wrappers: both emit the fid's DebugString().
void to_json(nlohmann::json& j, const FIDV1& fid) { j = fid.DebugString(); }

void to_json(nlohmann::json& j, const FIDV2& fid) { j = fid.DebugString(); }

// Post-processes the JSON text of one record and returns it re-dumped with a
// 2-space indent:
//   * fid arrays (Instance "fid" / "feature[].fid", Example
//     "fid_v1_list.value" / "fid_v2_list.value") are parsed from decimal
//     strings, sorted ascending and stored back through the v1/v2 vectors;
//   * for an ExampleBatch feature named "__LINE_ID__", the single base64
//     bytes value is decoded, parsed as LineId and replaced by its JSON form.
// Aborts (LOG(FATAL) / CHECK) on unparsable JSON, non-numeric fids or an
// unparsable LineId.
std::string JsonCallbackFn(const std::string& serialized) {
  nlohmann::json json;
  try {
    json = nlohmann::json::parse(serialized);
  } catch (const std::exception& e) {
    LOG(FATAL) << e.what() << "\nserialized:\n" << serialized;
  }

  // Appends every element of the JSON array `j`, parsed as an unsigned
  // integer, to `fids` and sorts the result.
  auto CollectFID = [](const nlohmann::json& j, std::vector* fids) {
    CHECK(j.is_array());
    fids->reserve(j.size());
    for (absl::string_view fid_str : j) {
      uint64_t fid = 0;
      CHECK(absl::SimpleAtoi(fid_str, &fid));
      fids->emplace_back(fid);
    }
    std::sort(fids->begin(), fids->end());
  };

  // instance fid(v1)
  if (json.contains("fid") && json["fid"].is_array()) {
    std::vector fids;
    CollectFID(json["fid"], &fids);
    std::vector fids_v1(fids.begin(), fids.end());
    json["fid"] = fids_v1;
  }

  // instance feature(v2)
  if (json.contains("feature") && json["feature"].is_array()) {
    for (nlohmann::json& element : json["feature"]) {
      if (element.contains("fid") && element["fid"].is_array()) {
        std::vector fids;
        CollectFID(element["fid"], &fids);
        std::vector fids_v2(fids.begin(), fids.end());
        element["fid"] = fids_v2;
      }
    }
  }

  // Rewrites one Example/ExampleBatch feature object in place. `is_line_id`
  // enables the bytes_list -> LineId expansion.
  auto ReplaceFeatureFn = [&](nlohmann::json& feature, bool is_line_id) {
    if (feature.contains("fid_v1_list")) {
      nlohmann::json& fid_v1_list = feature["fid_v1_list"];
      if (fid_v1_list.contains("value") && fid_v1_list["value"].is_array()) {
        std::vector fids;
        CollectFID(fid_v1_list["value"], &fids);
        std::vector fids_v1(fids.begin(), fids.end());
        fid_v1_list["value"] = fids_v1;
      }
    } else if (feature.contains("fid_v2_list")) {
      nlohmann::json& fid_v2_list = feature["fid_v2_list"];
      if (fid_v2_list.contains("value") && fid_v2_list["value"].is_array()) {
        std::vector fids;
        CollectFID(fid_v2_list["value"], &fids);
        std::vector fids_v2(fids.begin(), fids.end());
        fid_v2_list["value"] = fids_v2;
      }
    } else if (feature.contains("bytes_list") && is_line_id) {
      nlohmann::json& bytes_list = feature["bytes_list"];
      if (bytes_list.contains("value") && bytes_list["value"].is_array()) {
        CHECK_EQ(bytes_list["value"].size(), 1);
        LineId line_id;
        std::string based64_encoded = bytes_list["value"][0];
        std::string serialized;
        // NOTE(review): the result of Base64Decode is not checked; a decode
        // failure would only surface through the ParseFromArray CHECK below.
        tf::Base64Decode(based64_encoded, &serialized);
        CHECK(line_id.ParseFromArray(serialized.data(), serialized.size()))
            << serialized;
        auto json_options = google::protobuf::util::JsonOptions();
        json_options.add_whitespace = true;
        json_options.preserve_proto_field_names = true;
        std::string output;
        google::protobuf::util::MessageToJsonString(line_id, &output,
                                                    json_options);
        bytes_list["value"] = nlohmann::json::parse(output);
      }
    }
  };

  // example named_feature
  if (json.contains("named_feature") && json["named_feature"].is_array()) {
    for (nlohmann::json& element : json["named_feature"]) {
      if (element.contains("feature")) {
        ReplaceFeatureFn(element["feature"], false);
      }
    }
  }

  // ExampleBatch named_feature_list
  if (json.contains("named_feature_list") &&
      json["named_feature_list"].is_array()) {
    for (nlohmann::json& element : json["named_feature_list"]) {
      if (element.contains("feature") && element["feature"].is_array()) {
        bool is_line_id =
            element.contains("name") && element["name"] == "__LINE_ID__";
        for (nlohmann::json& feature : element["feature"]) {
          ReplaceFeatureFn(feature, is_line_id);
        }
      }
    }
  }

  return json.dump(2);
}

// Entry point: parses the flags, parses --config as a text-format
// TransformConfig (echoed to stderr), then dumps records of the requested
// dtype. A transform config combined with dtype=example_batch is fatal.
int main(int argc, char* argv[]) {
  CLI::App app("instance_reader");
  app.set_version_flag("--version", "0.0.1");
  Options options;
  AddOptions(app, &options);
  CLI11_PARSE(app, argc, argv)

  if (options.dtype == "example_batch" && !options.config.empty()) {
    LOG(FATAL) << "Transform cannot process ExampleBatch!";
  }

  TransformConfig config;
  CHECK(google::protobuf::TextFormat::ParseFromString(options.config, &config));
  std::cerr << config.DebugString();

  InputReader reader(options, config);
  // NOTE(review): `output` is never used in this function.
  std::string output;
  if (options.dtype == "instance") {
    Instance instance;
    ReadAndSerialize(reader, &instance, options, JsonCallbackFn);
  } else if (options.dtype == "example") {
    Example example;
    ReadAndSerialize(reader, &example, options, JsonCallbackFn);
  } else if (options.dtype == "example_batch") {
    ExampleBatch example_batch;
    ReadAndSerialize(reader, &example_batch, options, JsonCallbackFn);
  } else {
    throw std::invalid_argument(
        absl::StrFormat("Invalid dtype=%s", options.dtype));
  }

  return 0;
}

================================================
FILE: monolith/native_training/data/training_instance/cc/instance_utils.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "monolith/native_training/data/training_instance/cc/instance_utils.h"

namespace tensorflow {
namespace monolith_tf {}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/instance_utils.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_INSTANCE_UTILS_H_
#define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_INSTANCE_UTILS_H_

// Helpers for filtering Instance / Example records by fid, slot, action and
// request time.
//
// NOTE(review): in this copy the header name of the bare `#include` below and
// the angle-bracket template parameter/argument lists appear to be missing --
// confirm against the original file.

#include

#include "absl/strings/numbers.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "idl/matrix/proto/example.pb.h"
#include "idl/matrix/proto/proto_parser.pb.h"
#include "monolith/native_training/data/training_instance/cc/reader_util.h"

namespace tensorflow {
namespace monolith_tf {

// Parses a comma-separated list of integers into an ordered set.
// Empty pieces (e.g. from a trailing comma or an empty input) are skipped.
// Throws std::invalid_argument("Invalid integer string: <piece>") for a piece
// that absl::SimpleAtoi cannot parse into T.
template
std::set StrToIntegerSet(const std::string& str) {
  static_assert(
      std::is_same::value || std::is_same::value,
      "Template typename T should be uint64_t or int32_t!");
  std::set integers;
  std::set splits = absl::StrSplit(str, ",");
  for (const auto& s : splits) {
    if (!s.empty()) {
      T fid;
      if (absl::SimpleAtoi(s, &fid)) {
        integers.insert(fid);
      } else {
        throw std::invalid_argument(
            absl::StrFormat("Invalid integer string: %s", s));
      }
    }
  }
  return integers;
}

// Adds the fids of `instance` to `fid_set`.
// NOTE(review): only the top-level `instance.fid()` list is read here, while
// CollectSlotIntoSet below also walks `instance.feature()[].fid()`; fids that
// exist only inside features are therefore invisible to the fid filters --
// confirm this is intended.
template
typename std::enable_if::value, void>::type
CollectFidIntoSet(const T& instance, std::set* fid_set) {
  const auto& instance_fids = instance.fid();
  fid_set->insert(instance_fids.begin(), instance_fids.end());
}

// Adds every fid found in the fid_v1_list and fid_v2_list of each named
// feature of `example` to `fid_set`.
template
typename std::enable_if::value, void>::type
CollectFidIntoSet(const T& example, std::set* fid_set) {
  for (const auto& named_feature : example.named_feature()) {
    if (named_feature.feature().has_fid_v1_list()) {
      const auto& fids = named_feature.feature().fid_v1_list().value();
      fid_set->insert(fids.begin(), fids.end());
    }
    if (named_feature.feature().has_fid_v2_list()) {
      const auto& fids = named_feature.feature().fid_v2_list().value();
      fid_set->insert(fids.begin(), fids.end());
    }
  }
}

// Adds the slot of every fid of `instance` to `slot_set`: top-level fids are
// decoded with slot_id_v1(), fids inside features with slot_id_v2().
template
typename std::enable_if::value, void>::type
CollectSlotIntoSet(const T& instance, std::set* slot_set) {
  for (uint64_t fid : instance.fid()) {
    int slot = slot_id_v1(fid);
    slot_set->insert(slot);
  }
  for (const auto& f : instance.feature()) {
    for (uint64_t fid : f.fid()) {
      int slot = slot_id_v2(fid);
      slot_set->insert(slot);
    }
  }
}

// Adds the slot of every fid of `example` to `slot_set`: fid_v1_list entries
// are decoded with slot_id_v1(), fid_v2_list entries with slot_id_v2().
template
typename std::enable_if::value, void>::type
CollectSlotIntoSet(const T& example, std::set* slot_set) {
  for (const auto& named_feature : example.named_feature()) {
    if (named_feature.feature().has_fid_v1_list()) {
      const auto& fids = named_feature.feature().fid_v1_list().value();
      for (uint64_t fid : fids) {
        int slot = slot_id_v1(fid);
        slot_set->insert(slot);
      }
    }
    if (named_feature.feature().has_fid_v2_list()) {
      const auto& fids = named_feature.feature().fid_v2_list().value();
      for (uint64_t fid : fids) {
        int slot = slot_id_v2(fid);
        slot_set->insert(slot);
      }
    }
  }
}

// Returns true when `pb` passes every filter. An empty filter set disables
// that filter. The record is dropped (false) when:
//   * its line_id().req_time() is earlier than `req_time_min`;
//   * it contains any fid of `filter_fids`;
//   * it contains none of `has_fids`;
//   * it lacks at least one fid of `select_fids`;
//   * it lacks at least one slot of `select_slots`;
//   * its line_id().actions() contains none of `has_actions`.
// The fid and slot sets of `pb` are collected up front, before any of the
// set-based filters is consulted.
template
bool IsInstanceOfInterest(const T& pb, const std::set& filter_fids,
                          const std::set& has_fids,
                          const std::set& select_fids,
                          const std::set& has_actions,
                          int64_t req_time_min,
                          const std::set& select_slots) {
  if (pb.line_id().req_time() < req_time_min) {
    return false;
  }

  std::set fid_set;
  CollectFidIntoSet(pb, &fid_set);
  std::set slot_set;
  CollectSlotIntoSet(pb, &slot_set);
  const auto& actions = pb.line_id().actions();
  std::set instance_actions_set(actions.begin(), actions.end());

  if (!filter_fids.empty()) {
    std::set intersection;
    std::set_intersection(fid_set.begin(), fid_set.end(), filter_fids.begin(),
                          filter_fids.end(),
                          std::inserter(intersection, intersection.begin()));
    // If the instance contains any one of the given `filter_fids`, it will be
    // dropped.
    if (!intersection.empty()) {
      return false;
    }
  }

  if (!has_fids.empty()) {
    std::set intersection;
    std::set_intersection(fid_set.begin(), fid_set.end(), has_fids.begin(),
                          has_fids.end(),
                          std::inserter(intersection, intersection.begin()));
    // If the instance does not contain any one of the given `has_fids`, it will
    // be dropped.
    if (intersection.empty()) {
      return false;
    }
  }

  if (!select_fids.empty()) {
    std::set intersection;
    std::set_intersection(fid_set.begin(), fid_set.end(), select_fids.begin(),
                          select_fids.end(),
                          std::inserter(intersection, intersection.begin()));
    // If the instance does not contain all of the given `select_fids`, it will
    // be dropped.
    if (intersection.size() < select_fids.size()) {
      return false;
    }
  }

  if (!select_slots.empty()) {
    std::set intersection;
    std::set_intersection(slot_set.begin(), slot_set.end(),
                          select_slots.begin(), select_slots.end(),
                          std::inserter(intersection, intersection.begin()));
    // If the instance does not contain all of the given `select_slots`, it will
    // be dropped.
    if (intersection.size() < select_slots.size()) {
      return false;
    }
  }

  if (!has_actions.empty()) {
    std::set intersection;
    std::set_intersection(instance_actions_set.begin(),
                          instance_actions_set.end(), has_actions.begin(),
                          has_actions.end(),
                          std::inserter(intersection, intersection.begin()));
    // If the instance does not contain any one of the given `has_actions`, it
    // will be dropped.
    if (intersection.empty()) {
      return false;
    }
  }

  return true;
}

}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_INSTANCE_UTILS_H_

================================================
FILE: monolith/native_training/data/training_instance/cc/instance_utils_test.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "absl/time/clock.h" #include "glog/logging.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace tensorflow { namespace monolith_tf { namespace { using ::testing::Eq; TEST(StrToIntegerSet, StrToFIDs) { std::set fids1 = StrToIntegerSet(""); EXPECT_TRUE(fids1.empty()); std::set fids2 = StrToIntegerSet("6461985998153810495"); EXPECT_EQ(fids2.size(), 1); EXPECT_EQ(*fids2.begin(), 6461985998153810495ull); std::set fids3 = StrToIntegerSet("6457882839108881377,6436927642569553426,"); EXPECT_EQ(fids3.size(), 2); EXPECT_EQ(*fids3.begin(), 6436927642569553426ull); EXPECT_EQ(*std::next(fids3.begin(), 1), 6457882839108881377ul); std::set fids4 = StrToIntegerSet("6461985998153810495,"); EXPECT_EQ(fids4.size(), 1); EXPECT_EQ(*fids4.begin(), 6461985998153810495ull); try { StrToIntegerSet("6461985998153810495,abc"); } catch (const std::invalid_argument& e) { EXPECT_THAT(std::string(e.what()), Eq("Invalid integer string: abc")); } catch (const std::exception& e) { LOG(ERROR) << "Unexpected exception thrown: " << e.what() << std::endl; } } TEST(StrToIntegerSet, StrToActions) { std::set actions1 = StrToIntegerSet(""); EXPECT_TRUE(actions1.empty()); std::set actions2 = StrToIntegerSet("-1"); EXPECT_EQ(actions2.size(), 1); EXPECT_EQ(*actions2.begin(), -1); std::set actions3 = StrToIntegerSet("-1,3,"); EXPECT_EQ(actions3.size(), 2); EXPECT_EQ(*actions3.begin(), -1); EXPECT_EQ(*std::next(actions3.begin(), 1), 3); std::set actions4 = StrToIntegerSet("1,"); EXPECT_EQ(actions4.size(), 1); EXPECT_EQ(*actions4.begin(), 1); try { StrToIntegerSet("1,abc"); } catch (const std::invalid_argument& e) { EXPECT_THAT(std::string(e.what()), Eq("Invalid integer string: abc")); } catch (const std::exception& e) { LOG(ERROR) << "Unexpected exception thrown: " << e.what() << std::endl; } } 
// Exercises each filter of IsInstanceOfInterest on an Instance that carries
// three top-level fids, req_time == now and the actions {-1, 1}.
// NOTE(review): the element types of the std::set declarations in these tests
// appear to be missing in this copy -- confirm against the original file.
TEST(IsInstanceOfInterest, Basic) {
  parser::proto::Instance instance;
  instance.mutable_fid()->Add(6436927642569553426ull);
  instance.mutable_fid()->Add(6457882839108881377ull);
  instance.mutable_fid()->Add(6461985998153810495ull);
  int64_t now = absl::ToUnixSeconds(absl::Now());
  instance.mutable_line_id()->set_req_time(now);
  instance.mutable_line_id()->mutable_actions()->Add(-1);
  instance.mutable_line_id()->mutable_actions()->Add(1);

  std::set filter_fids = {6436927642569553426ull};
  std::set has_fids = {6436927642569553426ull};
  std::set select_fids = {6436927642569553426ull, 6457882839108881377ull};
  // 1626537600 -> 2021-07-18 00:00:00 (UTC+8), i.e. 2021-07-17 16:00:00 UTC.
  int64_t req_time_min = 1626537600;

  // req_time filter: the bound is inclusive (req_time == req_time_min passes).
  EXPECT_TRUE(IsInstanceOfInterest(instance, {}, {}, {}, {}, req_time_min, {}));
  EXPECT_TRUE(IsInstanceOfInterest(instance, {}, {}, {}, {}, now, {}));
  EXPECT_TRUE(!IsInstanceOfInterest(instance, {}, {}, {}, {}, now + 1, {}));
  // filter_fids drops the instance because it contains the listed fid.
  EXPECT_TRUE(!IsInstanceOfInterest(instance, filter_fids, {}, {}, {},
                                    req_time_min, {}));
  EXPECT_TRUE(
      IsInstanceOfInterest(instance, {}, has_fids, {}, {}, req_time_min, {}));
  EXPECT_TRUE(IsInstanceOfInterest(instance, {}, {}, select_fids, {},
                                   req_time_min, {}));
  // filter_fids takes effect even when the other fid filters would pass.
  EXPECT_TRUE(!IsInstanceOfInterest(instance, filter_fids, has_fids,
                                    select_fids, {}, req_time_min, {}));
  // has_actions needs only one matching action (5 is absent, 1 is present).
  EXPECT_TRUE(
      IsInstanceOfInterest(instance, {}, {}, {}, {-1}, req_time_min, {}));
  EXPECT_TRUE(
      IsInstanceOfInterest(instance, {}, {}, {}, {1, 5}, req_time_min, {}));
  EXPECT_TRUE(IsInstanceOfInterest(instance, {}, has_fids, select_fids, {1, 5},
                                   req_time_min, {}));
}

// Despite the test-suite name, this checks CollectSlotIntoSet and the
// select_slots filter on an Instance with two v1 fids and two v2 features.
TEST(CollectFidIntoSet, Instance) {
  parser::proto::Instance instance;
  instance.mutable_fid()->Add(GetFidV1(2, 200));
  instance.mutable_fid()->Add(GetFidV1(3, 300));
  auto f1 = instance.mutable_feature()->Add();
  f1->mutable_fid()->Add(GetFidV2(1024, 102400));
  auto f2 = instance.mutable_feature()->Add();
  f2->mutable_fid()->Add(GetFidV2(4096, 409600));

  std::set slots, slots_expected{2, 3, 1024, 4096}, intersection;
  CollectSlotIntoSet(instance, &slots);
  // Every expected slot must be present; extra collected slots would not be
  // detected by this intersection-size check.
  std::set_intersection(slots.begin(), slots.end(), slots_expected.begin(),
                        slots_expected.end(),
                        std::inserter(intersection, intersection.begin()));
  EXPECT_EQ(intersection.size(), slots_expected.size());

  // select_slots requires all listed slots; slot 10 is absent.
  std::set select_slots1 = {2, 3, 1024, 4096}, select_slots2 = {2, 10};
  EXPECT_TRUE(IsInstanceOfInterest(instance, {}, {}, {}, {}, 0, select_slots1));
  EXPECT_TRUE(
      !IsInstanceOfInterest(instance, {}, {}, {}, {}, 0, select_slots2));
}

// Same slot checks as above for an Example built from two fid_v1_list and two
// fid_v2_list named features.
TEST(CollectFidIntoSet, Example) {
  monolith::io::proto::Example example;
  auto f1 = example.mutable_named_feature()->Add();
  f1->set_name("user_id");
  f1->mutable_feature()->mutable_fid_v1_list()->mutable_value()->Add(
      GetFidV1(2, 200));
  auto f2 = example.mutable_named_feature()->Add();
  f2->set_name("item_id");
  f2->mutable_feature()->mutable_fid_v1_list()->mutable_value()->Add(
      GetFidV1(3, 300));
  auto f3 = example.mutable_named_feature()->Add();
  f3->set_name("gender");
  f3->mutable_feature()->mutable_fid_v2_list()->mutable_value()->Add(
      GetFidV2(1024, 102400));
  auto f4 = example.mutable_named_feature()->Add();
  f4->set_name("age");
  f4->mutable_feature()->mutable_fid_v2_list()->mutable_value()->Add(
      GetFidV2(4096, 409600));

  std::set slots, slots_expected{2, 3, 1024, 4096}, intersection;
  CollectSlotIntoSet(example, &slots);
  std::set_intersection(slots.begin(), slots.end(), slots_expected.begin(),
                        slots_expected.end(),
                        std::inserter(intersection, intersection.begin()));
  EXPECT_EQ(intersection.size(), slots_expected.size());

  std::set select_slots1 = {2, 3, 1024, 4096}, select_slots2 = {2, 10};
  EXPECT_TRUE(IsInstanceOfInterest(example, {}, {}, {}, {}, 0, select_slots1));
  EXPECT_TRUE(!IsInstanceOfInterest(example, {}, {}, {}, {}, 0, select_slots2));
}

}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/parse_instance_kernel.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/container/flat_hash_map.h" #include "google/protobuf/descriptor.h" #include "idl/matrix/proto/proto_parser.pb.h" #include "monolith/native_training/data/training_instance/cc/parse_instance_lib.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { namespace monolith_tf { namespace { using Instance = ::parser::proto::Instance; Status GetParserConfig(OpKernelConstruction *ctx, InstanceParserConfig *c) { TF_RETURN_IF_ERROR(ctx->GetAttr("fidv1_features", &c->fidv1_features)); TF_RETURN_IF_ERROR(ctx->GetAttr("fidv2_features", &c->fidv2_features)); TF_RETURN_IF_ERROR( ctx->GetAttr("float_feature_dims", &c->float_feature_dims)); TF_RETURN_IF_ERROR(ctx->GetAttr("float_features", &c->float_features)); if (c->float_features.size() != c->float_feature_dims.size()) { return errors::InvalidArgument( "Num of float features and float feature dims do not match"); } TF_RETURN_IF_ERROR( ctx->GetAttr("int64_feature_dims", &c->int64_feature_dims)); TF_RETURN_IF_ERROR(ctx->GetAttr("int64_features", &c->int64_features)); if (c->int64_features.size() != c->int64_feature_dims.size()) { return errors::InvalidArgument( "Num of int64 features and int64 feature dims do not match"); } TF_RETURN_IF_ERROR( 
ctx->GetAttr("string_feature_dims", &c->string_feature_dims)); TF_RETURN_IF_ERROR(ctx->GetAttr("string_features", &c->string_features)); if (c->string_features.size() != c->string_feature_dims.size()) { return errors::InvalidArgument( "Num of string features and string feature dims do not match"); } TF_RETURN_IF_ERROR( ctx->GetAttr("misc_float_features", &c->misc_float_features)); TF_RETURN_IF_ERROR(ctx->GetAttr("misc_float_dims", &c->misc_float_dims)); if (c->misc_float_features.size() != c->misc_float_dims.size()) { return errors::InvalidArgument( "Num of float features do not match it dims the size of " "misc_float_features is ", c->misc_float_features.size(), ", while the size of misc_float_dims is ", c->misc_float_dims.size()); } TF_RETURN_IF_ERROR( ctx->GetAttr("misc_int64_features", &c->misc_int64_features)); TF_RETURN_IF_ERROR(ctx->GetAttr("misc_int64_dims", &c->misc_int64_dims)); if (c->misc_int64_features.size() != c->misc_int64_dims.size()) { return errors::InvalidArgument( "Num of features do not match it dims the size of " "misc_features is ", c->misc_int64_features.size(), ", while the size of misc_dims is ", c->misc_int64_dims.size()); } TF_RETURN_IF_ERROR( ctx->GetAttr("misc_string_features", &c->misc_string_features)); TF_RETURN_IF_ERROR(ctx->GetAttr("misc_string_dims", &c->misc_string_dims)); if (c->misc_string_features.size() != c->misc_string_dims.size()) { return errors::InvalidArgument( "Num of features do not match it dims the size of " "misc_features is ", c->misc_string_features.size(), ", while the size of misc_dims is ", c->misc_string_dims.size()); } return Status::OK(); } bool ParseInstance(const tstring &serialized, Instance *instance) { return instance->ParseFromArray(serialized.data(), serialized.size()); } class ParseInstancesOp : public OpKernel { public: explicit ParseInstancesOp(OpKernelConstruction *ctx) : OpKernel(ctx) { InstanceParserConfig config; OP_REQUIRES_OK(ctx, GetParserConfig(ctx, &config)); config.collapse_batch_dim = 
false; parser_ = std::make_unique(config); OP_REQUIRES_OK(ctx, parser_->Init()); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *serialized; OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); TTypes::ConstVec serialized_protos = serialized->vec(); const int batch_size = serialized_protos.dimension(0); std::vector instances(batch_size); for (int i = 0; i < batch_size; ++i) { OP_REQUIRES(ctx, ParseInstance(serialized_protos(i), &instances[i]), errors::FailedPrecondition("Failed to parse the Instance.")); } InstanceParser::Output output; OP_REQUIRES_OK(ctx, parser_->Parse(ctx, instances, &output)); for (int i = 0; i < static_cast(output.tensors.size()); ++i) { ctx->set_output(i, output.tensors[i]); } } private: std::unique_ptr parser_; }; // This class is mainly for testing parser. // Do not use in the model code directly. class RawParseInstanceOp : public OpKernel { public: explicit RawParseInstanceOp(OpKernelConstruction *ctx) : OpKernel(ctx) { InstanceParserConfig config; OP_REQUIRES_OK(ctx, GetParserConfig(ctx, &config)); OP_REQUIRES_OK( ctx, ctx->GetAttr("collapse_batch_dim", &config.collapse_batch_dim)); std::string fid_output_type; OP_REQUIRES_OK(ctx, ctx->GetAttr("fid_output_type", &fid_output_type)); absl::flat_hash_map str_to_enum = { {"REGULAR", InstanceParserConfig::REGULAR}, {"CONCAT", InstanceParserConfig::CONCAT}, }; config.fid_output_type = str_to_enum.at(fid_output_type); parser_ = std::make_unique(config); OP_REQUIRES_OK(ctx, parser_->Init()); } void Compute(OpKernelContext *ctx) override { // Grab the input tensor const Tensor *serialized; OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized)); auto serialized_flat = serialized->flat(); std::vector instances(serialized_flat.size()); for (size_t i = 0; i < instances.size(); ++i) { OP_REQUIRES(ctx, ParseInstance(serialized_flat(i), &instances[i]), errors::FailedPrecondition("Failed to parse the Instance.")); } InstanceParser::Output output; 
OP_REQUIRES_OK(ctx, parser_->Parse(ctx, instances, &output)); OpOutputList l; ctx->output_list("tensors", &l); for (size_t i = 0; i < output.tensors.size(); ++i) { l.set(i, output.tensors[i]); } } private: std::unique_ptr parser_; }; REGISTER_KERNEL_BUILDER(Name("MonolithParseInstances").Device(DEVICE_CPU), ParseInstancesOp); REGISTER_KERNEL_BUILDER(Name("MonolithRawParseInstance").Device(DEVICE_CPU), RawParseInstanceOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/parse_instance_lib.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "glog/logging.h" #include "tensorflow/core/platform/errors.h" #include "idl/matrix/compression/float16.h" #include "monolith/native_training/data/training_instance/cc/parse_instance_lib.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "monolith/native_training/data/training_instance/cc/ue_compress.h" namespace tensorflow { namespace monolith_tf { namespace { using ::google::protobuf::FieldDescriptor; using ::idl::matrix::proto::Feature; using ::parser::proto::Instance; using tensorflow::monolith_tf::UECompress; // The spec that will be used by parser. 
// It includes some preprocessed data struct InstanceParserSpec : InstanceParserConfig { explicit InstanceParserSpec(const InstanceParserConfig &config) : InstanceParserConfig(config) {} Status Init() { fidv1_features_set = {fidv1_features.begin(), fidv1_features.end()}; fidv2_features_set = {fidv2_features.begin(), fidv2_features.end()}; int index = 0; for (int slot : fidv1_features) { slot_to_index[slot] = index++; } for (const std::string &name : fidv2_features) { fidv2_name_to_index[name] = index++; } n_ragged_tensors = fidv1_features.size() + fidv2_features.size(); float_features_set = {float_features.begin(), float_features.end()}; for (size_t i = 0; i < float_features.size(); ++i) { float_feature_name_to_index[float_features[i]] = i; } n_float_tensors = float_features.size(); int64_features_set = {int64_features.begin(), int64_features.end()}; for (size_t i = 0; i < int64_features.size(); ++i) { int64_feature_name_to_index[int64_features[i]] = i; } n_int64_tensors = int64_features.size(); string_features_set = {string_features.begin(), string_features.end()}; for (size_t i = 0; i < string_features.size(); ++i) { string_feature_name_to_index[string_features[i]] = i; } n_string_tensors = string_features.size(); return Status::OK(); } // Fid features attrs absl::flat_hash_set fidv1_features_set; absl::flat_hash_map slot_to_index; absl::flat_hash_set fidv2_features_set; absl::flat_hash_map fidv2_name_to_index; int n_ragged_tensors; // Float features attrs int n_float_tensors; absl::flat_hash_set float_features_set; absl::flat_hash_map float_feature_name_to_index; // Int64 features attrs int n_int64_tensors; absl::flat_hash_set int64_features_set; absl::flat_hash_map int64_feature_name_to_index; // String features attrs int n_string_tensors; absl::flat_hash_set string_features_set; absl::flat_hash_map string_feature_name_to_index; }; class RaggedTensorProcessor { public: explicit RaggedTensorProcessor(const InstanceParserSpec *spec) : spec_(*spec) {} virtual 
~RaggedTensorProcessor() = default; // Process the ragged tensor. // The output will be added to the output. virtual Status ParseRaggedTensors(OpKernelContext *ctx, absl::Span instances, InstanceParser::Output *output) = 0; protected: const InstanceParserSpec &spec() const { return spec_; } // A util function for child class to use. template void IterateFidFeatures(const Instance &instance, Func func) { const bool apply_fid_v2 = !spec_.fidv2_features.empty(); for (const uint64_t fid : instance.fid()) { int slot_id = slot_id_v1(fid); if (!spec_.fidv1_features_set.contains(slot_id)) continue; uint64_t converted_fid = apply_fid_v2 ? convert_fid_v1_to_v2(slot_id, fid) : fid; func(spec_.slot_to_index.at(slot_id), converted_fid); } // Feature v2 should never have 2 features with the same feature name. for (const auto &feature : instance.feature()) { if (!spec_.fidv2_features_set.contains(feature.name())) continue; // This is a simple sample list. for (const auto &fid : feature.fid()) { func(spec_.fidv2_name_to_index.at(feature.name()), fid); } // this is a sequence feature list. 
// Emits one ragged tensor (row splits + values) per configured fid feature.
//
// The tensors are appended to `output->tensors` in two groups: first
// `n_ragged_tensors` split tensors, then `n_ragged_tensors` value tensors,
// both ordered by the feature index passed to the IterateFidFeatures callback.
class RegularRaggedTensorProcessor : public RaggedTensorProcessor {
 public:
  explicit RegularRaggedTensorProcessor(const InstanceParserSpec *spec)
      : RaggedTensorProcessor(spec) {}

  // Two passes over the batch: the first counts fids to build the row splits
  // (and thereby the size of each values tensor), the second writes the fids.
  Status ParseRaggedTensors(OpKernelContext *ctx,
                            absl::Span instances,
                            InstanceParser::Output *output) override {
    int batch_size = instances.size();
    std::vector::Vec> splits_vec;
    splits_vec.reserve(spec().n_ragged_tensors);
    // Allocate the split tensors: batch_size + 1 entries each, starting at 0.
    for (int i = 0; i < spec().n_ragged_tensors; ++i) {
      output->tensors.emplace_back();
      Tensor *t = &output->tensors.back();
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, {batch_size + 1}, t));
      auto vec = t->vec();
      vec(0) = 0;
      splits_vec.emplace_back(vec);
    }

    // Pass 1: `nums[idx]` is the running fid count of feature `idx` over all
    // instances seen so far, which is exactly the split after instance i.
    std::vector nums(spec().n_ragged_tensors);
    for (int i = 0; i < batch_size; ++i) {
      IterateFidFeatures(instances[i],
                         [&nums](int idx, int64_t fid) { nums[idx]++; });
      for (int j = 0; j < spec().n_ragged_tensors; ++j) {
        splits_vec[j](i + 1) = nums[j];
      }
    }

    // The last split of each feature is its total number of fids.
    std::vector::Vec> values_vec;
    for (int i = 0; i < spec().n_ragged_tensors; ++i) {
      output->tensors.emplace_back();
      Tensor *t = &output->tensors.back();
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(DT_INT64, {splits_vec[i](batch_size)}, t));
      values_vec.emplace_back(t->vec());
    }

    // Pass 2: reuse `nums` as the per-feature write cursor.
    std::fill(nums.begin(), nums.end(), 0);
    for (int i = 0; i < batch_size; ++i) {
      IterateFidFeatures(instances[i],
                         [&nums, &values_vec](int idx, int64_t fid) {
                           values_vec[idx](nums[idx]) = fid;
                           nums[idx]++;
                         });
    }
    return Status::OK();
  }
};
"ConcatRaggedTensorProcessor only support batch_size == 1"); } const Instance &instance = instances[0]; Tensor t; TF_RETURN_IF_ERROR( ctx->allocate_temp(DT_INT64, {spec().n_ragged_tensors + 1}, &t)); output->tensors.push_back(t); auto split = t.vec().setZero(); IterateFidFeatures(instance, [&split](int idx, int64_t fid) { ++split(idx + 1); }); for (int i = 1; i <= spec().n_ragged_tensors; ++i) { split(i) = split(i) + split(i - 1); } TF_RETURN_IF_ERROR( ctx->allocate_temp(DT_INT64, {split(spec().n_ragged_tensors)}, &t)); output->tensors.push_back(t); auto value = t.vec(); std::vector pos(spec().n_ragged_tensors); for (int i = 0; i < spec().n_ragged_tensors; ++i) { pos[i] = split(i); } IterateFidFeatures(instance, [&value, &pos](int idx, int64_t fid) { value(pos[idx]++) = fid; }); return Status::OK(); } }; } // namespace class InstanceParser::Impl { public: explicit Impl(const InstanceParserConfig &config) : spec_(config) { switch (config.fid_output_type) { case InstanceParserConfig::REGULAR: ragged_tensor_processor_ = std::make_unique(&spec_); break; case InstanceParserConfig::CONCAT: ragged_tensor_processor_ = std::make_unique(&spec_); break; } ue_compress_ = std::make_unique(); } Status Init() { return spec_.Init(); } Status Parse(OpKernelContext *ctx, absl::Span instances, Output *output) { output->tensors.clear(); TF_RETURN_IF_ERROR( ragged_tensor_processor_->ParseRaggedTensors(ctx, instances, output)); TF_RETURN_IF_ERROR(FillFloatFeatures(ctx, instances, output)); TF_RETURN_IF_ERROR(FillInt64Features(ctx, instances, output)); TF_RETURN_IF_ERROR(FillStringFeatures(ctx, instances, output)); TF_RETURN_IF_ERROR(ParseFloatTensors(ctx, instances, output)); TF_RETURN_IF_ERROR(ParseInt64Tensors(ctx, instances, output)); TF_RETURN_IF_ERROR(ParseStringTensors(ctx, instances, output)); return Status::OK(); } private: Status FillFloatFeatures(OpKernelContext *ctx, absl::Span instances, Output *output) { const int batch_size = instances.size(); std::vector::Matrix> 
// Appends one DT_FLOAT tensor per configured float feature to `output` and
// fills row i of each tensor from instance i.
//
// For every matching feature the value comes from, in order of preference:
// the qtz8-decompressed embedding, the raw float_value list when its length
// equals the configured dim, or zeros otherwise. Returns Internal when a
// decompressed embedding does not have the configured dim.
// NOTE(review): the tensor is allocated with GetBatched1DShape() but viewed
// as {batch_size, dim}; presumably callers only collapse the batch dim when
// batch_size == 1 — confirm.
Status FillFloatFeatures(OpKernelContext *ctx,
                         absl::Span instances,
                         Output *output) {
  const int batch_size = instances.size();
  std::vector::Matrix> values_mat;
  for (int i = 0; i < spec_.n_float_tensors; ++i) {
    output->tensors.emplace_back();
    Tensor *t = &output->tensors.back();
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_FLOAT, GetBatched1DShape(batch_size, spec_.float_feature_dims[i]),
        t));
    values_mat.emplace_back(
        t->shaped({batch_size, spec_.float_feature_dims[i]}));
    // To be safe, we initialize float tensors to zero by default.
    values_mat.back().setZero();
  }

  for (int i = 0; i < batch_size; ++i) {
    const Instance &instance = instances[i];
    for (const Feature &feature : instance.feature()) {
      if (spec_.float_features_set.contains(feature.name())) {
        std::vector embedding;
        // A true return value is treated as "this feature carried qtz8 data".
        bool ret = ue_compress_->decompress_embeddings(
            feature, &embedding, UECompressMethod::COMPRESS_QTZ8);
        int idx = spec_.float_feature_name_to_index.at(feature.name());
        if (ret) {
          // Process data with qtz8 compression.
          if (spec_.float_feature_dims[idx] != embedding.size()) {
            return errors::Internal(
                "Decompressed qtz8 data length doesn't match feature dim,",
                " feature dim: ", spec_.float_feature_dims[idx],
                ", uncompressed qtz8 size: ", embedding.size());
          }
          for (int j = 0; j < spec_.float_feature_dims[idx]; ++j) {
            values_mat[idx](i, j) = embedding[j];
          }
        } else if (spec_.float_feature_dims[idx] ==
                   feature.float_value_size()) {
          for (int j = 0; j < spec_.float_feature_dims[idx]; ++j) {
            values_mat[idx](i, j) = feature.float_value(j);
          }
        } else {
          // TODO(zouxuan) Set the default value to 0 for now. Xuan will make
          // an eventual fix for this later.
          for (int j = 0; j < spec_.float_feature_dims[idx]; ++j) {
            values_mat[idx](i, j) = 0;
          }
        }
      }
    }
  }
  return Status::OK();
}
for (int j = 0; j < spec_.int64_feature_dims[idx]; ++j) { values_mat[idx](i, j) = 0; } } } } } return Status::OK(); } Status FillStringFeatures(OpKernelContext *ctx, absl::Span instances, Output *output) { const int batch_size = instances.size(); std::vector::Matrix> values_mat; for (int i = 0; i < spec_.n_string_tensors; ++i) { output->tensors.emplace_back(); Tensor *t = &output->tensors.back(); TF_RETURN_IF_ERROR(ctx->allocate_temp( DT_STRING, GetBatched1DShape(batch_size, spec_.string_feature_dims[i]), t)); values_mat.emplace_back( t->shaped({batch_size, spec_.string_feature_dims[i]})); } for (int i = 0; i < batch_size; ++i) { const Instance &instance = instances[i]; for (const Feature &feature : instance.feature()) { if (spec_.string_features_set.contains(feature.name())) { int idx = spec_.string_feature_name_to_index.at(feature.name()); if (spec_.string_feature_dims[idx] == feature.bytes_value_size()) { for (int j = 0; j < spec_.string_feature_dims[idx]; ++j) { values_mat[idx](i, j) = feature.bytes_value(j); } } else { for (int j = 0; j < spec_.string_feature_dims[idx]; ++j) { values_mat[idx](i, j) = ""; } } } } } return Status::OK(); } Status ParseFloatTensors(OpKernelContext *ctx, absl::Span instances, Output *output) { const int batch_size = instances.size(); const auto *descriptor = ::idl::matrix::proto::LineId::GetDescriptor(); const auto *reflection = ::idl::matrix::proto::LineId::GetReflection(); for (size_t i = 0; i < spec_.misc_float_features.size(); ++i) { output->tensors.emplace_back(); Tensor *t = &output->tensors.back(); TF_RETURN_IF_ERROR(ctx->allocate_temp( DT_FLOAT, GetBatched1DShape(batch_size, spec_.misc_float_dims[i]), t)); auto mat = t->shaped({batch_size, spec_.misc_float_dims[i]}); // To be safe, we initialize float tensors to zero by default. 
mat.setZero(); const std::string &name = spec_.misc_float_features[i]; if (name == "label") { for (size_t j = 0; j < instances.size(); ++j) { const Instance &instance = instances[j]; int dim = spec_.misc_float_dims[i]; if (instance.label_size() < dim) { LOG_EVERY_N_SEC(ERROR, 60) << name << " Dim is smaller than expected " << instance.label_size() << " v.s. " << dim; dim = instance.label_size(); } for (int k = 0; k < dim; ++k) { mat(j, k) = instance.label(k); } } continue; } else if (name == "instance_weight") { int dim = spec_.misc_float_dims[i]; if (dim != 1) { LOG_EVERY_N_SEC(ERROR, 60) << name << " Dim is illegal, expected 1 " << " v.s. " << dim; dim = 1; } for (size_t j = 0; j < instances.size(); ++j) { const Instance &instance = instances[j]; mat(j, 0) = instance.has_instance_weight() ? instance.instance_weight() : 1.0f; } continue; } const auto *field = descriptor->FindFieldByName(name); if (field == nullptr) { return errors::NotFound(name + " not found in misc_float_features!"); } for (size_t j = 0; j < instances.size(); ++j) { const Instance &instance = instances[j]; if (field->is_repeated()) { int dim = spec_.misc_float_dims[i]; int field_size = reflection->FieldSize(instance.line_id(), field); if (field_size < dim) { LOG_EVERY_N_SEC(ERROR, 60) << name << " Dim is smaller than expected " << field_size << " v.s. 
" << dim; dim = field_size; } for (int k = 0; k < dim; ++k) { mat(j, k) = reflection->GetRepeatedFloat(instance.line_id(), field, k); } } else { mat(j, 0) = reflection->GetFloat(instance.line_id(), field); } } } return Status::OK(); } Status ParseInt64Tensors(OpKernelContext *ctx, absl::Span instances, Output *output) { const int batch_size = instances.size(); const auto *descriptor = ::idl::matrix::proto::LineId::GetDescriptor(); const auto *reflection = ::idl::matrix::proto::LineId::GetReflection(); for (size_t i = 0; i < spec_.misc_int64_features.size(); ++i) { output->tensors.emplace_back(); Tensor *t = &output->tensors.back(); TF_RETURN_IF_ERROR(ctx->allocate_temp( DT_INT64, GetBatched1DShape(batch_size, spec_.misc_int64_dims[i]), t)); auto mat = t->shaped({batch_size, spec_.misc_int64_dims[i]}); // To be safe, we initialize int64 tensors to zero by default. mat.setZero(); const std::string &name = spec_.misc_int64_features[i]; const auto *field = descriptor->FindFieldByName(name); if (field == nullptr) { return errors::NotFound(name + " not found in misc_int64_features!"); } if (field->is_repeated()) { for (int j = 0; j < batch_size; ++j) { int dim = spec_.misc_int64_dims[i]; const Instance &instance = instances[j]; const int field_size = reflection->FieldSize(instance.line_id(), field); if (field_size < dim) { LOG_EVERY_N_SEC(ERROR, 60) << name << " Dim is smaller than expected " << field_size << " v.s. 
" << dim; dim = field_size; } switch (field->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: for (int k = 0; k < dim; ++k) { mat(j, k) = reflection->GetRepeatedInt32(instance.line_id(), field, k); } break; case FieldDescriptor::CPPTYPE_UINT32: for (int k = 0; k < dim; ++k) { mat(j, k) = reflection->GetRepeatedUInt32(instance.line_id(), field, k); } break; case FieldDescriptor::CPPTYPE_INT64: for (int k = 0; k < dim; ++k) { mat(j, k) = reflection->GetRepeatedInt64(instance.line_id(), field, k); } break; case FieldDescriptor::CPPTYPE_UINT64: for (int k = 0; k < dim; ++k) { mat(j, k) = reflection->GetRepeatedUInt64(instance.line_id(), field, k); } break; default: return errors::InvalidArgument( name, " Data type not match, only int32/int64/float32 supported."); } } } else { for (int j = 0; j < batch_size; ++j) { const Instance &instance = instances[j]; switch (field->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: mat(j, 0) = reflection->GetInt32(instance.line_id(), field); break; case FieldDescriptor::CPPTYPE_UINT32: mat(j, 0) = reflection->GetUInt32(instance.line_id(), field); break; case FieldDescriptor::CPPTYPE_INT64: mat(j, 0) = reflection->GetInt64(instance.line_id(), field); break; case FieldDescriptor::CPPTYPE_UINT64: mat(j, 0) = reflection->GetUInt64(instance.line_id(), field); break; default: return errors::InvalidArgument( name, " Data type not match, only int32/int64/float32 supported."); } } } } return Status::OK(); } Status ParseStringTensors(OpKernelContext *ctx, absl::Span instances, Output *output) { const int batch_size = instances.size(); const auto *descriptor = ::idl::matrix::proto::LineId::GetDescriptor(); const auto *reflection = ::idl::matrix::proto::LineId::GetReflection(); for (size_t i = 0; i < spec_.misc_string_features.size(); ++i) { output->tensors.emplace_back(); Tensor *t = &output->tensors.back(); TF_RETURN_IF_ERROR(ctx->allocate_temp( DT_STRING, GetBatched1DShape(batch_size, spec_.misc_string_dims[i]), t)); auto mat = 
t->shaped({batch_size, spec_.misc_string_dims[i]}); const std::string &name = spec_.misc_string_features[i]; const auto *field = descriptor->FindFieldByName(name); if (field == nullptr) { return errors::NotFound(name + " not found in misc_string_features!"); } for (size_t j = 0; j < instances.size(); ++j) { const Instance &instance = instances[j]; if (field->is_repeated()) { int dim = spec_.misc_string_dims[i]; int field_size = reflection->FieldSize(instance.line_id(), field); if (field_size < dim) { LOG_EVERY_N_SEC(ERROR, 60) << name << " Dim is smaller than expected " << field_size << " v.s. " << dim; dim = field_size; } for (int k = 0; k < dim; ++k) { mat(j, k) = reflection->GetRepeatedString(instance.line_id(), field, k); } } else { mat(j, 0) = reflection->GetString(instance.line_id(), field); } } } return Status::OK(); } TensorShape GetBatched1DShape(int batch_size, int64 dim) { if (spec_.collapse_batch_dim) { return {dim}; } else { return {batch_size, dim}; } } InstanceParserSpec spec_; std::unique_ptr ragged_tensor_processor_; std::unique_ptr ue_compress_; }; InstanceParser::InstanceParser(const InstanceParserConfig &config) : impl_(std::make_unique(config)) {} InstanceParser::~InstanceParser() {} Status InstanceParser::Init() { return impl_->Init(); } Status InstanceParser::Parse(OpKernelContext *ctx, absl::Span instances, Output *output) const { return impl_->Parse(ctx, instances, output); } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/parse_instance_lib.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_PARSE_INSTANCE_LIB_H_ #define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_PARSE_INSTANCE_LIB_H_ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "idl/matrix/proto/proto_parser.pb.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace monolith_tf { // The config to instantiate ParseInstanceSpec. struct InstanceParserConfig { // Fid features. std::vector fidv1_features; std::vector fidv2_features; enum FidOutputType { // Each fid will have its own ragged tensor. REGULAR, // All fids will be outputted as a single ragged tensor. // Only available when collapse_batch_dim == True. CONCAT, }; FidOutputType fid_output_type = REGULAR; // Float features. std::vector float_features; std::vector float_feature_dims; // Int64 features. std::vector int64_features; std::vector int64_feature_dims; // String features. std::vector string_features; std::vector string_feature_dims; // LineId related features, including labels and others. std::vector misc_float_features; std::vector misc_float_dims; std::vector misc_int64_features; std::vector misc_int64_dims; std::vector misc_string_features; std::vector misc_string_dims; bool collapse_batch_dim = false; }; // A parser that is able to parse instance. // Must call Init() before used. 
// A parser that is able to parse instance.
// Must call Init() before used.
//
// The implementation is hidden behind `Impl` so that this header does not
// expose the parsing internals.
class InstanceParser {
 public:
  explicit InstanceParser(const InstanceParserConfig &config);
  ~InstanceParser();

  // Prepares the lookup structures derived from the config.
  Status Init();

  // Tensors produced by one Parse() call, in this order: ragged fid tensors,
  // float / int64 / string feature tensors, then the misc (LineId related)
  // float / int64 / string tensors.
  struct Output {
    std::vector tensors;
  };

  // Parses `instances` and replaces the contents of `tensors->tensors` with
  // the result. Tensors are allocated through `ctx`.
  Status Parse(OpKernelContext *ctx,
               absl::Span instances,
               Output *tensors) const;

 private:
  class Impl;
  std::unique_ptr impl_;
};
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { namespace monolith_tf { namespace { bool IsUnsetHandle(shape_inference::DimensionHandle handle) { return handle.Handle() == 0; } shape_inference::ShapeHandle GetBatched1D( shape_inference::InferenceContext *ctx, shape_inference::DimensionHandle batch_size, int dim) { if (IsUnsetHandle(batch_size)) { return ctx->Vector(dim); } else { return ctx->Matrix(batch_size, dim); } } Status SetParseInstanceShape(shape_inference::InferenceContext *ctx, shape_inference::DimensionHandle batch_size) { int offset = 0; // Ragged tensor int n; TF_RETURN_IF_ERROR(ctx->GetAttr("N", &n)); for (int i = 0; i < n; ++i) { if (IsUnsetHandle(batch_size)) { ctx->set_output(offset + i, ctx->Vector(2)); continue; } int batch_size_value = ctx->Value(batch_size); if (batch_size_value == shape_inference::InferenceContext::kUnknownDim) { ctx->set_output(offset + i, ctx->Vector(ctx->UnknownDim())); } else { ctx->set_output(offset + i, ctx->Vector(batch_size_value + 1)); } } offset += n; for (int i = 0; i < n; ++i) { ctx->set_output(offset + i, ctx->Vector(ctx->UnknownDim())); } offset += n; // float tensor int m; TF_RETURN_IF_ERROR(ctx->GetAttr("M", &m)); std::vector float_feature_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("float_feature_dims", &float_feature_dims)); for (int i = 0; i < m; ++i) { ctx->set_output(offset + i, GetBatched1D(ctx, batch_size, float_feature_dims[i])); } offset += m; // int64 tensor int o; TF_RETURN_IF_ERROR(ctx->GetAttr("O", &o)); std::vector int64_feature_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("int64_feature_dims", &int64_feature_dims)); for (int i = 0; i < o; ++i) { ctx->set_output(offset + i, GetBatched1D(ctx, batch_size, int64_feature_dims[i])); } offset += o; // string tensor int p; TF_RETURN_IF_ERROR(ctx->GetAttr("P", &p)); std::vector string_feature_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("string_feature_dims", &string_feature_dims)); for (int i = 0; i < 
p; ++i) { ctx->set_output(offset + i, GetBatched1D(ctx, batch_size, string_feature_dims[i])); } offset += p; // misc_feature_float int q; TF_RETURN_IF_ERROR(ctx->GetAttr("Q", &q)); std::vector misc_float_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("misc_float_dims", &misc_float_dims)); for (int i = 0; i < q; ++i) { ctx->set_output(offset + i, GetBatched1D(ctx, batch_size, misc_float_dims[i])); } offset += q; // misc_feature_int64 int r; TF_RETURN_IF_ERROR(ctx->GetAttr("R", &r)); std::vector misc_int64_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("misc_int64_dims", &misc_int64_dims)); for (int i = 0; i < r; ++i) { ctx->set_output(offset + i, GetBatched1D(ctx, batch_size, misc_int64_dims[i])); } offset += r; // misc_feature_string int s; TF_RETURN_IF_ERROR(ctx->GetAttr("S", &s)); std::vector misc_string_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("misc_string_dims", &misc_string_dims)); for (int i = 0; i < s; ++i) { ctx->set_output(offset + i, GetBatched1D(ctx, batch_size, misc_string_dims[i])); } offset += s; return Status::OK(); } // We use fid_features for FIDV1 key and str_features for FIDv2 keys. // In fid v1, we use the slot whitelist, and in fid v2, we use the // feature_name // whitelist. 
// Parses a tensor of serialized instances into feature tensors.
// Each output group is a list whose length is given by a count attr:
// N ragged (splits, values) pairs, M float, O int64, P string feature
// tensors and Q float, R int64, S string misc tensors.
REGISTER_OP("MonolithParseInstances")
    .Input("serialized: string")
    .Output("ragged_feature_splits: N * int64")
    .Output("ragged_feature_values: N * int64")
    .Output("float_feature_values: M * float32")
    .Output("int64_feature_values: O * int64")
    .Output("string_feature_values: P * string")
    .Output("misc_float_feature_values: Q * float32")
    .Output("misc_int64_feature_values: R * int64")
    .Output("misc_string_feature_values: S * string")
    .Attr("N: int >= 0")
    .Attr("M: int >= 0")
    .Attr("O: int >= 0")
    .Attr("P: int >= 0")
    .Attr("Q: int >= 0")
    .Attr("R: int >= 0")
    .Attr("S: int >= 0")
    .Attr("fidv1_features: list(int)")
    .Attr("fidv2_features: list(string)")
    .Attr("float_features: list(string)")
    .Attr("float_feature_dims: list(int)")
    .Attr("int64_features: list(string)")
    .Attr("int64_feature_dims: list(int)")
    .Attr("string_features: list(string)")
    .Attr("string_feature_dims: list(int)")
    .Attr("misc_float_features: list(string)")
    .Attr("misc_float_dims: list(int)")
    .Attr("misc_int64_features: list(string)")
    .Attr("misc_int64_dims: list(int)")
    .Attr("misc_string_features: list(string)")
    .Attr("misc_string_dims: list(int)")
    .SetDoNotOptimize()  // Source dataset ops must disable constant folding.
    // The batch size is taken from the first dimension of `serialized`.
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      shape_inference::DimensionHandle batch_size = ctx->Dim(ctx->input(0), 0);
      return SetParseInstanceShape(ctx, batch_size);
    });

// Variant of the parse op whose outputs form a single typed list `T`.
// All feature attrs default to empty lists, and the shape function leaves
// every output shape unset.
REGISTER_OP("MonolithRawParseInstance")
    .Attr("T: list(type)")
    .Input("serialized: string")
    .Output("tensors : T")
    .Attr("fidv1_features: list(int) = []")
    .Attr("fidv2_features: list(string) = []")
    .Attr("float_features: list(string) = []")
    .Attr("float_feature_dims: list(int) = []")
    .Attr("int64_features: list(string) = []")
    .Attr("int64_feature_dims: list(int) = []")
    .Attr("string_features: list(string) = []")
    .Attr("string_feature_dims: list(int) = []")
    .Attr("misc_float_features: list(string) = []")
    .Attr("misc_float_dims: list(int) = []")
    .Attr("misc_int64_features: list(string) = []")
    .Attr("misc_int64_dims: list(int) = []")
    .Attr("misc_string_features: list(string) = []")
    .Attr("misc_string_dims: list(int) = []")
    .Attr("collapse_batch_dim: bool = false")
    .Attr("fid_output_type: {'REGULAR', 'CONCAT'} = 'REGULAR'")
    .SetDoNotOptimize()  // Source dataset ops must disable constant folding.
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      return Status::OK();
    });
// See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/strings/str_cat.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" namespace tensorflow { using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using Instance = ::parser::proto::Instance; template <> std::string TypeNameVariant(const Example &value) { return "Example"; } template <> std::string DebugStringVariant(const Example &value) { return "Example DebugString"; } template <> bool DecodeVariant(std::string *buf, Example *value) { std::cout << "DecodeVariant - 1" << std::endl; value->ParseFromArray(buf->data(), buf->size()); return true; } template <> void EncodeVariant(const Example &value, std::string *buf) { value.SerializeToString(buf); } template <> bool DecodeVariant(VariantTensorData *data, Example *value) { return false; } template <> void EncodeVariant(const Example &value, VariantTensorData *data) {} template <> std::string TypeNameVariant(const ExampleBatch &value) { return "ExampleBatch"; } template <> std::string DebugStringVariant(const ExampleBatch &value) { return "ExampleBatch DebugString"; } template <> bool DecodeVariant(std::string *buf, ExampleBatch *value) { std::cout << "DecodeVariant - 1" << std::endl; value->ParseFromArray(buf->data(), buf->size()); return true; } template <> void EncodeVariant(const ExampleBatch &value, std::string *buf) { value.SerializeToString(buf); } template <> bool DecodeVariant(VariantTensorData *data, ExampleBatch *value) { return false; } template <> void EncodeVariant(const ExampleBatch &value, VariantTensorData *data) {} template <> std::string TypeNameVariant(const Instance &value) { return "Instance"; } template <> std::string DebugStringVariant(const Instance &value) { return "Instance DebugString"; } template <> bool DecodeVariant(std::string *buf, Instance *value) { std::cout << "DecodeVariant - 1" 
<< std::endl; value->ParseFromArray(buf->data(), buf->size()); return true; } template <> void EncodeVariant(const Instance &value, std::string *buf) { value.SerializeToString(buf); } template <> bool DecodeVariant(VariantTensorData *data, Instance *value) { return false; } template <> void EncodeVariant(const Instance &value, VariantTensorData *data) {} } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/pb_variant.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_PB_VARIANT_H_
#define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_PB_VARIANT_H_

#include "idl/matrix/proto/example.pb.h"
#include "idl/matrix/proto/proto_parser.pb.h"
#include "tensorflow/core/framework/variant.h"

// Declares explicit specializations of the TensorFlow Variant hooks
// (TypeNameVariant, DebugStringVariant, EncodeVariant, DecodeVariant) for three
// protobuf message types: ::monolith::io::proto::Example,
// ::monolith::io::proto::ExampleBatch and ::parser::proto::Instance.
//
// Each type gets two Encode/Decode pairs: one that works on a std::string
// buffer and one that works on a VariantTensorData.

namespace tensorflow {

// ---- ::monolith::io::proto::Example ----------------------------------------

// Type name reported for an Example stored in a Variant.
template <>
std::string TypeNameVariant<::monolith::io::proto::Example>(
    const ::monolith::io::proto::Example &value);

// Human-readable description of an Example stored in a Variant.
template <>
std::string DebugStringVariant<::monolith::io::proto::Example>(
    const ::monolith::io::proto::Example &value);

// Decodes `*buf` into `*value`; the bool result signals success.
template <>
bool DecodeVariant<::monolith::io::proto::Example>(
    std::string *buf, ::monolith::io::proto::Example *value);

// Encodes `value` into `*buf`.
template <>
void EncodeVariant<::monolith::io::proto::Example>(
    const ::monolith::io::proto::Example &value, std::string *buf);

// VariantTensorData flavour of decode.
template <>
bool DecodeVariant<::monolith::io::proto::Example>(
    VariantTensorData *data, ::monolith::io::proto::Example *value);

// VariantTensorData flavour of encode.
template <>
void EncodeVariant<::monolith::io::proto::Example>(
    const ::monolith::io::proto::Example &value, VariantTensorData *data);

// ---- ::monolith::io::proto::ExampleBatch -----------------------------------

// Type name reported for an ExampleBatch stored in a Variant.
template <>
std::string TypeNameVariant<::monolith::io::proto::ExampleBatch>(
    const ::monolith::io::proto::ExampleBatch &value);

// Human-readable description of an ExampleBatch stored in a Variant.
template <>
std::string DebugStringVariant<::monolith::io::proto::ExampleBatch>(
    const ::monolith::io::proto::ExampleBatch &value);

// Decodes `*buf` into `*value`; the bool result signals success.
template <>
bool DecodeVariant<::monolith::io::proto::ExampleBatch>(
    std::string *buf, ::monolith::io::proto::ExampleBatch *value);

// Encodes `value` into `*buf`.
template <>
void EncodeVariant<::monolith::io::proto::ExampleBatch>(
    const ::monolith::io::proto::ExampleBatch &value, std::string *buf);

// VariantTensorData flavour of decode.
template <>
bool DecodeVariant<::monolith::io::proto::ExampleBatch>(
    VariantTensorData *data, ::monolith::io::proto::ExampleBatch *value);

// VariantTensorData flavour of encode.
template <>
void EncodeVariant<::monolith::io::proto::ExampleBatch>(
    const ::monolith::io::proto::ExampleBatch &value, VariantTensorData *data);

// ---- ::parser::proto::Instance ---------------------------------------------

// Type name reported for an Instance stored in a Variant.
template <>
std::string TypeNameVariant<::parser::proto::Instance>(
    const ::parser::proto::Instance &value);

// Human-readable description of an Instance stored in a Variant.
template <>
std::string DebugStringVariant<::parser::proto::Instance>(
    const ::parser::proto::Instance &value);

// Decodes `*buf` into `*value`; the bool result signals success.
template <>
bool DecodeVariant<::parser::proto::Instance>(std::string *buf,
                                              ::parser::proto::Instance *value);

// Encodes `value` into `*buf`.
template <>
void EncodeVariant<::parser::proto::Instance>(
    const ::parser::proto::Instance &value, std::string *buf);

// VariantTensorData flavour of decode.
template <>
bool DecodeVariant<::parser::proto::Instance>(VariantTensorData *data,
                                              ::parser::proto::Instance *value);

// VariantTensorData flavour of encode.
template <>
void EncodeVariant<::parser::proto::Instance>(
    const ::parser::proto::Instance &value, VariantTensorData *data);

}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_PB_VARIANT_H_

================================================
FILE: monolith/native_training/data/training_instance/cc/reader_util.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/data/training_instance/cc/reader_util.h"

// NOTE(review): this constant is not referenced by anything else in this
// translation unit; its intended use is not visible here — confirm before
// removing.
const size_t FALLBACK_RESERVE_VALUE = 0xfefefefe;

namespace tensorflow {
namespace monolith_tf {

// Writes `info` into `j` under the keys "id" and "sorted_id".
//
// The name and signature follow nlohmann::json's `to_json` customization
// point, which lets a FeatureNameMapperIdInfo be assigned to a
// nlohmann::json value directly.
void to_json(nlohmann::json& j, const FeatureNameMapperIdInfo& info) {
  j["id"] = info.id;
  j["sorted_id"] = info.sorted_id;
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/reader_util.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_READER_UTIL_H_ #define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_READER_UTIL_H_ #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "tensorflow/core/platform/logging.h" #include "third_party/nlohmann/json.hpp" constexpr uint64_t fid_v1_mask = (1LL << 54) - 1; constexpr uint64_t fid_v2_mask = (1LL << 48) - 1; inline int slot_id_v1(uint64_t fid) { return fid >> 54; } inline int slot_id_v2(uint64_t fid) { return (fid >> 48) & (((int64_t)1 << 15) - 1); } inline uint64_t convert_fid_v1_to_v2(int slot, uint64_t fid) { uint64_t slot_long = slot; return ((fid & fid_v2_mask) | slot_long << 48); } inline uint64_t convert_fid_v1_to_v2(uint64_t fid) { uint64_t slot_long = fid >> 54; return ((fid & fid_v2_mask) | slot_long << 48); } inline uint64_t switch_slot_v1(uint64_t fid, uint64_t slot) { return (slot << 54) | (fid & fid_v1_mask); } inline uint64_t switch_slot_v2(uint64_t fid, uint64_t slot) { return (slot << 48) | (fid & fid_v2_mask); } inline int get_max_slot_number() { return 1 << 15; } namespace tensorflow { namespace monolith_tf { inline int64_t GetFidV1(int slot, int64_t signautre) { return ((uint64_t)slot << 54) | (signautre & fid_v1_mask); } inline int64_t GetFidV2(int slot, int64_t signature) { return ((uint64_t)slot << 48) | (signature & fid_v2_mask); } class FeaturePruningByteCounter { public: ~FeaturePruningByteCounter() { LOG(INFO) << absl::StrFormat("Finally %s", DebugString()); } void AddByteSize(uint64_t byte_size) { byte_size_ += byte_size; } void AddByteSizePruned(uint64_t byte_size) { byte_size_pruned_ += byte_size; } std::string DebugString() const { return absl::StrFormat( "read: %llu bytes (%s), after pruning: %llu bytes (%s)", byte_size_, PrettyBytes(byte_size_), byte_size_pruned_, PrettyBytes(byte_size_pruned_)); } 
private: static std::string PrettyBytes(uint64_t bytes) { const std::vector suffixes = {"B", "KB", "MB", "GB", "TB", "PB", "EB"}; int64_t s = 0; auto count = static_cast(bytes); while (count >= 1024 && s < suffixes.size()) { s++; count /= 1024; } if (count - std::floor(count) == 0.0) { return absl::StrFormat("%llu %s", static_cast(count), suffixes[s]); } else { return absl::StrFormat("%.2f %s", count, suffixes[s]); } } uint64_t byte_size_; uint64_t byte_size_pruned_; }; struct FeatureNameMapperIdInfo { int32_t id; int32_t sorted_id; }; void to_json(nlohmann::json& j, const FeatureNameMapperIdInfo& info); class FeatureNameMapper { public: FeatureNameMapper(FeatureNameMapper const&) = delete; void operator=(FeatureNameMapper const&) = delete; void TurnOn() { turned_on_ = true; } bool IsAvailable() const { return turned_on_ && initialized_; } bool RegisterValidIds(const std::vector>& valid_ids) { absl::WriterMutexLock l(&mu_); registered_feature_id_set_.insert(valid_ids.begin(), valid_ids.end()); if (id_to_name_.empty()) { return true; } absl::flat_hash_set> invalid_ids; for (std::pair p : registered_feature_id_set_) { if (!id_to_name_.contains(p.first) && !id_to_name_.contains(p.second)) { invalid_ids.insert(p); } } if (!invalid_ids.empty()) { nlohmann::json j; j["invalid_ids"] = invalid_ids; LOG(ERROR) << "ResisterValidIds: " << j.dump(); return false; } for (std::pair p : registered_feature_id_set_) { std::vector ids = {p.first, p.second}; for (int id : ids) { auto it = id_to_name_.find(id); if (it != id_to_name_.end()) { for (const std::string& name : it->second) { valid_id_to_name_[it->first].push_back(name); auto sorted_id = name_to_id_.at(name).sorted_id; valid_name_to_id_.insert({name, {it->first, sorted_id}}); } } } } return true; } bool RegisterValidNames(const std::vector& valid_names) { absl::WriterMutexLock l(&mu_); registered_feature_name_set_.insert(valid_names.begin(), valid_names.end()); if (name_to_id_.empty()) { return true; } std::unordered_set 
invalid_names; for (const std::string& name : registered_feature_name_set_) { if (!name_to_id_.contains(name)) { invalid_names.insert(name); } } if (!invalid_names.empty()) { nlohmann::json j; j["invalid_names"] = invalid_names; LOG(ERROR) << "ResisterValidNames: " << j.dump(); return false; } for (const std::string& name : registered_feature_name_set_) { auto it = name_to_id_.find(name); valid_name_to_id_.insert({it->first, it->second}); valid_id_to_name_[it->second.id].push_back(it->first); } return true; } bool SetMapping(const absl::flat_hash_map& name_to_id, const absl::flat_hash_map>& id_to_name) { absl::WriterMutexLock l(&mu_); int sorted_id = 0; for (auto& iter : name_to_id) { name_to_id_[iter.first] = {iter.second, ++sorted_id}; } id_to_name_ = id_to_name; if (name_to_id_.empty()) { return true; } std::unordered_set invalid_names; for (const std::string& name : registered_feature_name_set_) { if (!name_to_id_.contains(name)) { invalid_names.insert(name); } } absl::flat_hash_set> invalid_ids; for (std::pair p : registered_feature_id_set_) { if (!id_to_name_.contains(p.first) && !id_to_name_.contains(p.second)) { invalid_ids.insert(p); } } if (!invalid_names.empty() || !invalid_ids.empty()) { name_to_id_.clear(); id_to_name_.clear(); nlohmann::json j; j["invalid_names"] = invalid_names; j["invalid_ids"] = invalid_ids; LOG(ERROR) << "SetMapping: " << j.dump(); return false; } for (std::pair p : registered_feature_id_set_) { std::vector ids = {p.first, p.second}; for (int id : ids) { auto it = id_to_name_.find(id); if (it != id_to_name_.end()) { for (const std::string& name : it->second) { valid_id_to_name_[it->first].push_back(name); auto sorted_id = name_to_id_.at(name).sorted_id; valid_name_to_id_.insert({name, {it->first, sorted_id}}); } } } } for (const std::string& name : registered_feature_name_set_) { auto it = name_to_id_.find(name); valid_name_to_id_.insert({it->first, it->second}); valid_id_to_name_[it->second.id].push_back(it->first); } 
initialized_ = true; return true; } bool GetIdByName(const std::string& name, int32_t* id, int32_t* sorted_id = nullptr) { absl::ReaderMutexLock l(&mu_); auto it = valid_name_to_id_.find(name); if (it == valid_name_to_id_.end()) { return false; } *id = it->second.id; if (sorted_id) { *sorted_id = it->second.sorted_id; } return true; } std::string DebugString() { absl::ReaderMutexLock l(&mu_); nlohmann::json j = valid_name_to_id_; return j.dump(2); } FeatureNameMapper() : turned_on_(false), initialized_(false) {} private: std::atomic_bool turned_on_; std::atomic_bool initialized_; absl::flat_hash_map> id_to_name_; absl::flat_hash_map name_to_id_; absl::flat_hash_map> valid_id_to_name_; absl::flat_hash_map valid_name_to_id_; absl::flat_hash_set registered_feature_name_set_; absl::flat_hash_set> registered_feature_id_set_; absl::Mutex mu_; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_READER_UTIL_H_ ================================================ FILE: monolith/native_training/data/training_instance/cc/reader_util_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/data/training_instance/cc/reader_util.h"

// NOTE(review): in this copy of the file the system header name below and
// several template argument lists (on std::make_unique and
// absl::flat_hash_map) are absent — confirm against the upstream source.
#include
#include "gtest/gtest.h"

namespace tensorflow {
namespace monolith_tf {

const uint32_t kSlotBits = 15;
const uint32_t kFeatureV2Bits = 48;

// slot_id_v2 must recover the slot that was shifted into the fid, both for an
// ordinary value and for the largest value that fits in kSlotBits bits.
TEST(ReaderUtilTest, GetSlotID) {
  int64_t slot_id = 123;
  int64_t fid = slot_id << kFeatureV2Bits;
  int64_t slot_id_ret = slot_id_v2(fid);
  ASSERT_TRUE(slot_id_ret == slot_id);

  slot_id = (1 << kSlotBits) - 1;
  fid = slot_id << kFeatureV2Bits;
  slot_id_ret = slot_id_v2(fid);
  ASSERT_TRUE(slot_id_ret == slot_id);
}

// Order under test: register names, register ids, then set the mapping.
// Names 1 and 2 are registered by name, name 3 through id 3; name 4 is never
// registered and must not resolve.
TEST(ReaderUtilTest, FeatureNameMapperNormalCase1) {
  auto mapper = std::make_unique();
  ASSERT_FALSE(mapper->IsAvailable());
  mapper->TurnOn();
  ASSERT_FALSE(mapper->IsAvailable());

  std::string name1 = "slot_1", name2 = "slot_2", name3 = "slot_3",
              name4 = "slot_4";
  int id1 = -1, id2 = -1, id3 = -1, id4 = -1;

  ASSERT_TRUE(mapper->RegisterValidNames({name1, name2}));
  ASSERT_TRUE(mapper->RegisterValidIds({{3, 10003}}));

  absl::flat_hash_map m;
  absl::flat_hash_map> m2;
  m.insert({name1, 1});
  m.insert({name2, 2});
  m.insert({name3, 3});
  m.insert({name4, 4});
  m2.insert({1, {name1}});
  m2.insert({2, {name2}});
  m2.insert({3, {name3}});
  m2.insert({4, {name4}});
  ASSERT_TRUE(mapper->SetMapping(m, m2));

  ASSERT_TRUE(mapper->GetIdByName(name1, &id1));
  ASSERT_TRUE(mapper->GetIdByName(name2, &id2));
  ASSERT_TRUE(mapper->GetIdByName(name3, &id3));
  ASSERT_FALSE(mapper->GetIdByName(name4, &id4));
  ASSERT_EQ(id1, 1);
  ASSERT_EQ(id2, 2);
  ASSERT_EQ(id3, 3);
  LOG(INFO) << mapper->DebugString();
}

// Order under test: set the mapping first, then register names and ids.
// Before anything is registered no name resolves.
TEST(ReaderUtilTest, FeatureNameMapperNormalCase2) {
  auto mapper = std::make_unique();
  ASSERT_FALSE(mapper->IsAvailable());
  mapper->TurnOn();
  ASSERT_FALSE(mapper->IsAvailable());

  std::string name1 = "slot_1", name2 = "slot_2", name3 = "slot_3",
              name4 = "slot_4";
  int id1 = -1, id2 = -1, id3 = -1, id4 = -1;

  absl::flat_hash_map m;
  absl::flat_hash_map> m2;
  m.insert({name1, 1});
  m.insert({name2, 2});
  m.insert({name3, 3});
  m.insert({name4, 4});
  m2.insert({1, {name1}});
  m2.insert({2, {name2}});
  m2.insert({3, {name3}});
  m2.insert({4, {name4}});
  ASSERT_TRUE(mapper->SetMapping(m, m2));

  ASSERT_FALSE(mapper->GetIdByName(name1, &id1));
  ASSERT_FALSE(mapper->GetIdByName(name2, &id2));
  ASSERT_FALSE(mapper->GetIdByName(name3, &id3));
  ASSERT_FALSE(mapper->GetIdByName(name4, &id4));

  ASSERT_TRUE(mapper->RegisterValidNames({name1, name2}));
  ASSERT_TRUE(mapper->RegisterValidIds({{3, 10003}}));

  ASSERT_TRUE(mapper->GetIdByName(name1, &id1));
  ASSERT_TRUE(mapper->GetIdByName(name2, &id2));
  ASSERT_TRUE(mapper->GetIdByName(name3, &id3));
  ASSERT_FALSE(mapper->GetIdByName(name4, &id4));
  ASSERT_EQ(id1, 1);
  ASSERT_EQ(id2, 2);
  ASSERT_EQ(id3, 3);
  LOG(INFO) << mapper->DebugString();
}

// Order under test: register names, set the mapping, then register ids.
TEST(ReaderUtilTest, FeatureNameMapperCornerCase1) {
  auto mapper = std::make_unique();
  ASSERT_FALSE(mapper->IsAvailable());
  mapper->TurnOn();
  ASSERT_FALSE(mapper->IsAvailable());

  std::string name1 = "slot_1", name2 = "slot_2", name3 = "slot_3",
              name4 = "slot_4";
  int id1 = -1, id2 = -1, id3 = -1, id4 = -1;

  ASSERT_TRUE(mapper->RegisterValidNames({name1, name2}));

  absl::flat_hash_map m;
  absl::flat_hash_map> m2;
  m.insert({name1, 1});
  m.insert({name2, 2});
  m.insert({name3, 3});
  m.insert({name4, 4});
  m2.insert({1, {name1}});
  m2.insert({2, {name2}});
  m2.insert({3, {name3}});
  m2.insert({4, {name4}});
  ASSERT_TRUE(mapper->SetMapping(m, m2));

  ASSERT_TRUE(mapper->RegisterValidIds({{3, 10003}}));

  ASSERT_TRUE(mapper->GetIdByName(name1, &id1));
  ASSERT_TRUE(mapper->GetIdByName(name2, &id2));
  ASSERT_TRUE(mapper->GetIdByName(name3, &id3));
  ASSERT_FALSE(mapper->GetIdByName(name4, &id4));
  ASSERT_EQ(id1, 1);
  ASSERT_EQ(id2, 2);
  ASSERT_EQ(id3, 3);
  LOG(INFO) << mapper->DebugString();
}

// Order under test: register ids, set the mapping, then register names.
TEST(ReaderUtilTest, FeatureNameMapperCornerCase2) {
  auto mapper = std::make_unique();
  ASSERT_FALSE(mapper->IsAvailable());
  mapper->TurnOn();
  ASSERT_FALSE(mapper->IsAvailable());

  std::string name1 = "slot_1", name2 = "slot_2", name3 = "slot_3",
              name4 = "slot_4";
  int id1 = -1, id2 = -1, id3 = -1, id4 = -1;

  ASSERT_TRUE(mapper->RegisterValidIds({{3, 10003}}));

  absl::flat_hash_map m;
  absl::flat_hash_map> m2;
  m.insert({name1, 1});
  m.insert({name2, 2});
  m.insert({name3, 3});
  m.insert({name4, 4});
  m2.insert({1, {name1}});
  m2.insert({2, {name2}});
  m2.insert({3, {name3}});
  m2.insert({4, {name4}});
  ASSERT_TRUE(mapper->SetMapping(m, m2));

  ASSERT_TRUE(mapper->RegisterValidNames({name1, name2}));

  ASSERT_TRUE(mapper->GetIdByName(name1, &id1));
  ASSERT_TRUE(mapper->GetIdByName(name2, &id2));
  ASSERT_TRUE(mapper->GetIdByName(name3, &id3));
  ASSERT_FALSE(mapper->GetIdByName(name4, &id4));
  ASSERT_EQ(id1, 1);
  ASSERT_EQ(id2, 2);
  ASSERT_EQ(id3, 3);
  LOG(INFO) << mapper->DebugString();
}

// An id pair registered before the mapping, with neither id present in the
// mapping, must make SetMapping fail.
TEST(ReaderUtilTest, FeatureNameMapperCornerCase3) {
  auto mapper = std::make_unique();
  ASSERT_FALSE(mapper->IsAvailable());
  mapper->TurnOn();
  ASSERT_FALSE(mapper->IsAvailable());

  std::string name1 = "slot_1";

  ASSERT_TRUE(mapper->RegisterValidIds({{2, 10002}}));

  absl::flat_hash_map m;
  absl::flat_hash_map> m2;
  m.insert({name1, 1});
  m2.insert({1, {name1}});
  ASSERT_FALSE(mapper->SetMapping(m, m2));
  LOG(INFO) << mapper->DebugString();
}

// With the mapping already set, registering an id pair with neither id present
// in the mapping must fail.
TEST(ReaderUtilTest, FeatureNameMapperCornerCase4) {
  auto mapper = std::make_unique();
  ASSERT_FALSE(mapper->IsAvailable());
  mapper->TurnOn();
  ASSERT_FALSE(mapper->IsAvailable());

  std::string name1 = "slot_1";

  absl::flat_hash_map m;
  absl::flat_hash_map> m2;
  m.insert({name1, 1});
  m2.insert({1, {name1}});
  ASSERT_TRUE(mapper->SetMapping(m, m2));

  ASSERT_FALSE(mapper->RegisterValidIds({{2, 10002}}));
  LOG(INFO) << mapper->DebugString();
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/snappy_inputbuffer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include "glog/logging.h" #include "monolith/native_training/data/training_instance/cc/snappy_inputbuffer.h" namespace tensorflow { namespace io { ByteSnappyInputBuffer::ByteSnappyInputBuffer( RandomAccessFile* file, size_t input_buffer_bytes, // size of input_buffer_ size_t output_buffer_bytes // size of output_buffer_ ) : file_(file), input_buffer_capacity_(input_buffer_bytes), output_buffer_capacity_(output_buffer_bytes), bytes_read_(0) { cached_mem_pool_ = CachedMemPool::init(input_buffer_bytes); input_buffer_ = cached_mem_pool_->allocate(); output_buffer_ = cached_mem_pool_->allocate(); next_in_ = input_buffer_.get(); LOG_IF(ERROR, !ReadFromFile().ok()) << "Failed to read ahead from HDFS."; } ByteSnappyInputBuffer::~ByteSnappyInputBuffer() { cached_mem_pool_->deallocate(output_buffer_); cached_mem_pool_->deallocate(input_buffer_); } Status ByteSnappyInputBuffer::ReadNBytes(int64 bytes_to_read, 
tstring* result) { result->clear(); result->resize_uninitialized(bytes_to_read); char* result_ptr = result->mdata(); // Read as many bytes as possible from cache. size_t bytes_read = ReadBytesFromCache(bytes_to_read, result_ptr); bytes_to_read -= bytes_read; result_ptr += bytes_read; while (bytes_to_read > 0) { // Now that the cache is empty we need to inflate more data. TF_RETURN_IF_ERROR(Inflate()); bytes_read = ReadBytesFromCache(bytes_to_read, result_ptr); bytes_to_read -= bytes_read; result_ptr += bytes_read; } return Status::OK(); } int64 ByteSnappyInputBuffer::Tell() const { return bytes_read_; } Status ByteSnappyInputBuffer::Reset() { file_pos_ = 0; avail_in_ = 0; avail_out_ = 0; next_in_ = input_buffer_.get(); bytes_read_ = 0; return Status::OK(); } size_t ByteSnappyInputBuffer::ReadBytesFromCache(size_t bytes_to_read, char* result_ptr) { size_t can_read_bytes = std::min(bytes_to_read, avail_out_); if (can_read_bytes > 0) { memcpy(result_ptr, next_out_, can_read_bytes); next_out_ += can_read_bytes; avail_out_ -= can_read_bytes; bytes_read_ += can_read_bytes; } return can_read_bytes; } Status ByteSnappyInputBuffer::Inflate() { // Output buffer must have been cleared before uncompressing more input. DCHECK_EQ(avail_out_, 0); // Read origin length of a block. if (block_length_ == 0) { TF_RETURN_IF_ERROR(ReadBlockLength(&block_length_)); // Output buffer must be large enough to fit the uncompressed block. DCHECK_GE(output_buffer_capacity_, block_length_); } // Read length of a compressed chunk. uint32 compressed_chunk_length = 0; TF_RETURN_IF_ERROR(ReadBlockLength(&compressed_chunk_length)); // Read bytes to buffer a chunk if (avail_in_ < compressed_chunk_length) { TF_RETURN_IF_ERROR(ReadFromFile()); if (avail_in_ < compressed_chunk_length) { if (compressed_chunk_length > input_buffer_capacity_) { // TODO(gaofei.gf): increase buffer size dynamically return errors::ResourceExhausted( "Input buffer(size: ", input_buffer_capacity_, " bytes) too small. 
Should be larger ", "than ", compressed_chunk_length, " bytes."); } else { return errors::OutOfRange("EOF reached with incomplete tail bytes."); } } } // Uncompress a chunk size_t chunk_length = 0; if (!port::Snappy_GetUncompressedLength(next_in_, compressed_chunk_length, &chunk_length)) { return errors::DataLoss("Snappy_GetUncompressedLength failed"); } next_out_ = output_buffer_.get(); if (!port::Snappy_Uncompress(next_in_, compressed_chunk_length, next_out_)) { return errors::DataLoss("Snappy_Uncompress failed"); } next_in_ += compressed_chunk_length; avail_in_ -= compressed_chunk_length; avail_out_ += chunk_length; uncompressed_bytes_in_block_ += chunk_length; // Check a block is uncompressed if (uncompressed_bytes_in_block_ == block_length_) { block_length_ = 0; uncompressed_bytes_in_block_ = 0; } return Status::OK(); } Status ByteSnappyInputBuffer::ReadBlockLength(uint32* length) { *length = 0; size_t bytes_to_read = 4; while (bytes_to_read > 0) { if (avail_in_ == 0) { TF_RETURN_IF_ERROR(ReadFromFile()); } size_t readable = std::min(bytes_to_read, avail_in_); for (int i = 0; i < readable; i++) { // The "unsigned char" type cast is intentional to avoid implicit type // casting of the signed char to unsigned int during bitwise OR which // causes weird overflow errors. // Little endian *length = (*length << 8) | static_cast(next_in_[0]); bytes_to_read--; next_in_++; avail_in_--; } } return Status::OK(); } Status ByteSnappyInputBuffer::ReadFromFile() { int bytes_to_read = input_buffer_capacity_; char* read_offset = reinterpret_cast(input_buffer_.get()); // If there are unread bytes in the input stream we move them to the head // of the stream to maximize the space available to read new data into. // TODO(srbs): A circular buffer would be useful here. if (avail_in_ > 0) { size_t read_bytes = next_in_ - input_buffer_.get(); // Remove `read_bytes` from the head of the input stream. // Move unread bytes to the head of the input stream. 
if (read_bytes > 0) { memmove(input_buffer_.get(), next_in_, avail_in_); } bytes_to_read -= avail_in_; read_offset += avail_in_; } StringPiece data; // Try to read enough data to fill up input_buffer_. struct timeval t0; struct timeval t1; size_t old_size = data.size(); gettimeofday(&t0, NULL); Status s = Status(error::OUT_OF_RANGE, "Read less bytes than requested"); if (!reached_eof_) { read_round_++; s = file_->Read(file_pos_, bytes_to_read, &data, read_offset); } gettimeofday(&t1, NULL); int64_t elapsed = (t1.tv_sec - t0.tv_sec) * 1000000 + t1.tv_usec - t0.tv_usec; elapsed /= 1000; auto throughput = (elapsed == 0) ? 0.0 : (data.size() - old_size) / (1024 * elapsed); LOG_EVERY_N(INFO, 100) << "********************************At round: " << read_round_ << ", the expected read: " << bytes_to_read << " and the actual read is: " << (data.size() - old_size) / (1024.0 * 1024) << " MB at timestamp: " << elapsed << " ms with a bandwidth: " << throughput << " MBps. If out of range? " << errors::IsOutOfRange(s) << " .*********************************"; if (data.data() != read_offset) { memmove(read_offset, data.data(), data.size()); } // Since we moved unread data to the head of the input stream we can point // next_in to the head of the input stream. next_in_ = input_buffer_.get(); // Note: data.size() could be different from bytes_to_read. avail_in_ += data.size(); file_pos_ += data.size(); // Report failure if not EoF or normal reading. if (!s.ok() && !errors::IsOutOfRange(s)) { return s; } // We throw OutOfRange error iff no new data has been read from file. // Since we never check how much data is remaining in the file, it is // possible that on the last read there isn't enough data in the file to // fill up the buffer in which case file_->ReadNBytes would return an // OutOfRange error. 
if (data.empty()) { reached_eof_ = true; return errors::OutOfRange("EOF reached"); } if (errors::IsOutOfRange(s)) { reached_eof_ = true; return Status::OK(); } return s; } } // namespace io } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/snappy_inputbuffer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/

// Code is modified based on
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/io/snappy/snappy_inputbuffer.h

#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_SNAPPY_INPUTBUFFER_H_
#define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_SNAPPY_INPUTBUFFER_H_

// NOTE(review): in this copy of the file the system header name below and the
// template arguments of the two std::unique_ptr members are absent — confirm
// against the upstream source.
#include
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/types.h"

#include "cached_mem_pool.h"

namespace tensorflow {
namespace io {

using CachedMemPool = ::tensorflow::monolith_tf::CachedMemPool;

// An SnappyInputBuffer provides support for reading from a hdfs file compressed
// using snappy (https://github.com/google/snappy).
//
// A Hadoop snappy compressed file contains several compressed data blocks. The
// format of a compressed block is,
//   uint32_t                  Uncompressed Length
//   uint32_t                  Compressed Length
//   byte[compressed_length]   Compressed block
// See:
// https://github.com/apache/hadoop/blob/trunk/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-nativetask/src/main/native/src/codec/SnappyCodec.cc#L35
//
// A given instance of an SnappyInputBuffer is NOT safe for concurrent use
// by multiple threads
class ByteSnappyInputBuffer : public InputStreamInterface {
 public:
  // Create a SnappyInputBuffer for `file` with a buffer of size
  // `input_buffer_bytes` bytes for reading contents from `file` and another
  // buffer with size `output_buffer_bytes` for caching decompressed contents.
  // Does *not* take ownership of "file".
  //
  // The constructor performs an initial read from `file`; a failure of that
  // read is logged and not reported to the caller.
  ByteSnappyInputBuffer(RandomAccessFile* file, size_t input_buffer_bytes,
                        size_t output_buffer_bytes);

  ~ByteSnappyInputBuffer() override;

  // Reads bytes_to_read bytes into *result, overwriting *result.
  //
  // Return Status codes:
  // OK:
  //   If successful.
  // OUT_OF_RANGE:
  //   If there are not enough bytes to read before the end of the file.
  // DATA_LOSS:
  //   If uncompression failed or if the file is corrupted.
  // RESOURCE_EXHAUSTED:
  //   If input_buffer_ is smaller in size than a compressed block.
  // others:
  //   If reading from file failed.
  Status ReadNBytes(int64 bytes_to_read, tstring* result) override;

  // Returns the number of uncompressed bytes handed out so far.
  int64 Tell() const override;

  // Rewinds reading to the beginning of `file_`.
  Status Reset() override;

 private:
  // Reads data from `file_` and tries to fill up `input_buffer_` if enough
  // unread data is left in `file_`.
  //
  // Looks up `next_in_` to check how much data in `input_buffer_`
  // has already been read. The used data is removed and new data is added to
  // after any unread data in `input_buffer_`.
  // After this call `next_in` points to the start of `input_buffer_`
  // and `avail_in_` stores the number of readable bytes in
  // `input_buffer_`.
  //
  // Returns OutOfRange error if NO data could be read from file. Note that this
  // won't return an OutOfRange if there wasn't sufficient data in file to
  // completely fill up `input_buffer_`.
  Status ReadFromFile();

  // 1. Reads the uncompressed length of the next compressed block
  //    stored in the next 4 bytes at `next_in_`.
  // 2. Reads the compressed length of the next compressed block
  //    stored in the next 4 bytes at `next_in_`.
  // 3. Uncompresses the next compressed block and writes the output
  //    produced to the output_buffer_.
  //
  // Should be called only after the cached output has been consumed.
  Status Inflate();

  // Copies bytes starting at `next_out_` into `result` until either
  // `bytes_to_read` bytes have been copied or the cached output
  // (`avail_out_` bytes) is exhausted.
  // Returns the number of bytes copied and advances the `next_out_`
  // pointer to the next offset to read from.
  size_t ReadBytesFromCache(size_t bytes_to_read, char* result);

  // Reads a 4-byte length field at `next_in_` and stores it in `length`.
  // The bytes are combined most-significant byte first (big endian).
  // For each block, call this method for two times. The first one read the
  // uncompressed length, the second one read the compressed.
  Status ReadBlockLength(uint32* length);

  RandomAccessFile* file_;         // Not owned
  int64 file_pos_ = 0;             // Next position to read from in `file_`
  size_t input_buffer_capacity_;   // Size of `input_buffer_`.
                                   // Must be at least as big as the size of
                                   // the largest compressed block.
  size_t output_buffer_capacity_;  // Size of `output_buffer_`

  // Singleton memory pool.
  CachedMemPool* cached_mem_pool_;

  // Buffer for storing contents read from compressed file.
  // TODO(srbs): Consider using circular buffers. That would greatly simplify
  // the implementation.
  std::unique_ptr input_buffer_;

  // Buffer for storing inflated contents of `file_`.
  std::unique_ptr output_buffer_;

  // Next unread byte in `input_buffer_`.
  char* next_in_;

  // Next unread byte in `output_buffer_`
  char* next_out_;

  // Number of unread bytes available at `next_in_` in `input_buffer_`.
  size_t avail_in_ = 0;

  // Number of unread bytes available at `next_out_` in `output_buffer_`.
  size_t avail_out_ = 0;

  // Number of uncompressed bytes has been read.
  int64 bytes_read_ = 0;

  // States when uncompressing a block
  uint32 block_length_ = 0;                 // Uncompressed length of the
                                            // block being decoded; 0 between
                                            // blocks.
  uint32 uncompressed_bytes_in_block_ = 0;  // Bytes inflated so far for the
                                            // current block.
  bool reached_eof_ = false;  // Set once a read hit the end of `file_`; no
                              // further file reads are issued afterwards.
  uint32 read_round_ = 0;     // Number of file reads issued (used in logs).

  TF_DISALLOW_COPY_AND_ASSIGN(ByteSnappyInputBuffer);
};

}  // namespace io
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_SNAPPY_INPUTBUFFER_H_

================================================
FILE: monolith/native_training/data/training_instance/cc/ue_compress.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "monolith/native_training/data/training_instance/cc/ue_compress.h" namespace tensorflow { namespace monolith_tf { using matrix::compression::Float16; const char* UE_COMPRESS_FLAG = "UE_QTZ"; bool UECompress::compress_embeddings( ::idl::matrix::proto::Feature* feature_column, UECompressMethod compress_method) { if (compress_method == UECompressMethod::COMPRESS_QTZ8) { auto* bytes_value = feature_column->add_bytes_value(); int embedding_size = feature_column->float_value_size(); std::vector compress_input; compress_input.reserve(embedding_size); for (auto value : feature_column->float_value()) { compress_input.push_back(value); } std::string compress_out; bool ret = matrix::compression::compress_float_list_qtz8mm( (const char*)compress_input.data(), embedding_size * sizeof(float), &compress_out); if (!ret) { LOG(ERROR) << "compress_embeddings failed, feature_column name=%s" << feature_column->name().c_str(); return false; } *bytes_value = UE_COMPRESS_FLAG + compress_out; return true; } else { LOG(ERROR) << "compress_embeddings invalid compress method:%d" << compress_method; return false; } } bool UECompress::decompress_embeddings( const idl::matrix::proto::Feature& feature_column, std::vector* embedding, UECompressMethod compress_method) { if (compress_method == UECompressMethod::COMPRESS_QTZ8) { std::string compress_out; std::string bytes_value; for (auto& value : feature_column.bytes_value()) { if (value.find(UE_COMPRESS_FLAG) == 0) { bytes_value = value.substr(strlen(UE_COMPRESS_FLAG)); } } if (bytes_value.empty()) { return false; } bool 
ret = matrix::compression::decompress_float_list_qtz8mm( (const char*)bytes_value.data(), bytes_value.size(), &compress_out); if (!ret) { LOG(ERROR) << "decompress_embeddings failed, feature_column name=%s" << feature_column.name().c_str(); return false; } size_t embedding_size = bytes_value.size() - 2 * sizeof(Float16); const float* output = reinterpret_cast(compress_out.data()); embedding->clear(); for (size_t i = 0; i < embedding_size; ++i) { embedding->emplace_back(output[i]); } return true; } else { LOG(ERROR) << "decompress_embeddings invalid compress method:%d" << compress_method; return false; } } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/ue_compress.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_UE_COMPRESS_H_ #define MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_UE_COMPRESS_H_ #include "glog/logging.h" #include "idl/matrix/compression/compression.h" #include "idl/matrix/compression/float16.h" #include "idl/matrix/proto/proto_parser.pb.h" namespace tensorflow { namespace monolith_tf { enum UECompressMethod { COMPRESS_QTZ8 = 0 // 8bit 量化 }; class UECompress { public: UECompress() = default; virtual ~UECompress() = default; bool compress_embeddings(::idl::matrix::proto::Feature* feature_column, UECompressMethod compress_method); bool decompress_embeddings( const ::idl::matrix::proto::Feature& feature_column, std::vector* embedding, UECompressMethod compress_method); }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_DATA_TRAINING_INSTANCE_CC_UE_COMPRESS_H_ ================================================ FILE: monolith/native_training/data/training_instance/cc/ue_compress_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "ue_compress.h"

#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace tensorflow {
namespace monolith_tf {

using ::testing::FloatNear;
using ::testing::Pointwise;

// Round-trips a small embedding through 8-bit quantization and checks that
// every decompressed value is within 1e-2 of the original.
TEST(UECompressTest, Basic) {
  UECompress ue_compress;
  const std::vector<float> float_values = {1.1f, 0.1f, 3.1f, 5.1f,
                                           2.2f, 3.3f, 4.3f};

  ::idl::matrix::proto::Feature feature;
  feature.set_name("fc_test");
  feature.mutable_float_value()->Reserve(float_values.size());
  for (const float v : float_values) {
    feature.add_float_value(v);
  }

  // A failed compression must fail the test here instead of surfacing later
  // as a confusing decompression error.
  ASSERT_TRUE(ue_compress.compress_embeddings(
      &feature, UECompressMethod::COMPRESS_QTZ8));
  ASSERT_EQ(1, feature.bytes_value_size());

  // Drop the raw floats so decompression can only rely on the bytes payload.
  feature.clear_float_value();

  std::vector<float> embedding;
  ASSERT_TRUE(ue_compress.decompress_embeddings(
      feature, &embedding, UECompressMethod::COMPRESS_QTZ8));
  ASSERT_EQ(float_values.size(), embedding.size());
  // Pointwise already compares every element; no outer loop is needed.
  ASSERT_THAT(embedding, Pointwise(FloatNear(1e-2), float_values));
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/data/training_instance/cc/zstd_inputbuffer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "monolith/native_training/data/training_instance/cc/zstd_inputbuffer.h" #include #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/strcat.h" namespace tensorflow { namespace io { MonolithZstdInputStream::MonolithZstdInputStream( InputStreamInterface* input_stream, size_t input_buffer_bytes, size_t output_buffer_bytes, bool owns_input_stream) : owns_input_stream_(owns_input_stream), input_stream_(input_stream), input_buffer_(new char[input_buffer_bytes]), input_buffer_capacity_(input_buffer_bytes), output_buffer_(new char[output_buffer_bytes]), output_buffer_capacity_(output_buffer_bytes), bytes_read_(0) { InitZstdBuffer(); } MonolithZstdInputStream::MonolithZstdInputStream( InputStreamInterface* input_stream, size_t input_buffer_bytes, size_t output_buffer_bytes) : MonolithZstdInputStream(input_stream, input_buffer_bytes, output_buffer_bytes, false) {} MonolithZstdInputStream::~MonolithZstdInputStream() { ZSTD_freeDCtx(context_); if (owns_input_stream_) { delete input_stream_; } } void MonolithZstdInputStream::InitZstdBuffer() { context_ = ZSTD_createDCtx(); if (context_ == nullptr) { LOG(FATAL) << "Creation of context failed."; 
} next_in_byte_ = input_buffer_.get(); zstd_input_buffer_ = {next_in_byte_, 0, 0}; next_unread_byte_ = output_buffer_.get(); unread_bytes_ = 0; avail_in_ = 0; } Status MonolithZstdInputStream::Reset() { TF_RETURN_IF_ERROR(input_stream_->Reset()); ZSTD_DCtx_reset(context_, ZSTD_reset_session_only); InitZstdBuffer(); bytes_read_ = 0; return Status::OK(); } size_t MonolithZstdInputStream::ReadBytesFromCache(size_t bytes_to_read, tstring* result) { size_t can_read_bytes = std::min(bytes_to_read, unread_bytes_); if (can_read_bytes > 0) { tstring cached_result; cached_result.append(next_unread_byte_, can_read_bytes); result->append(next_unread_byte_, can_read_bytes); } next_unread_byte_ += can_read_bytes; unread_bytes_ -= can_read_bytes; bytes_read_ += can_read_bytes; return can_read_bytes; } Status MonolithZstdInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) { result->clear(); bytes_to_read -= ReadBytesFromCache(bytes_to_read, result); while (bytes_to_read > 0) { // No bytes should be left in the cache. CHECK_EQ(unread_bytes_, 0); // Now that the cache is empty we need to inflate more data. next_unread_byte_ = output_buffer_.get(); TF_RETURN_IF_ERROR(Inflate()); // If no progress was made by inflate, read more compressed data from the // input stream. if (unread_bytes_ == 0) { TF_RETURN_IF_ERROR(ReadFromStream()); if (avail_in_ == 0) { bytes_to_read = 0; } } else { bytes_to_read -= ReadBytesFromCache(bytes_to_read, result); } } return Status::OK(); } #if defined(TF_CORD_SUPPORT) Status MonolithZstdInputStream::ReadNBytes(int64 bytes_to_read, absl::Cord* result) { // TODO(frankchn): Optimize this instead of bouncing through the buffer. 
tstring buf; TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf)); result->Clear(); result->Append(buf.data()); return Status::OK(); } #endif Status MonolithZstdInputStream::Inflate() { ZSTD_outBuffer output = {next_unread_byte_, output_buffer_capacity_, 0}; last_return_ = ZSTD_decompressStream(context_, &output, &zstd_input_buffer_); if (ZSTD_isError(last_return_)) { string error_name = ZSTD_getErrorName(last_return_); string error_string = strings::StrCat("ZSTD_decompressStream: ", error_name); return errors::DataLoss(error_string); } avail_in_ = 0; unread_bytes_ = output.pos; return Status::OK(); } Status MonolithZstdInputStream::ReadFromStream() { size_t bytes_to_read = input_buffer_capacity_; char* read_location = input_buffer_.get(); // If there are unread bytes in the input stream we move them to the head // of the stream to maximize the space available to read new data into. if (avail_in_ > 0) { size_t read_bytes = next_in_byte_ - input_buffer_.get(); // Remove `read_bytes` from the head of the input stream. // Move unread bytes to the head of the input stream. if (read_bytes > 0) { memmove(input_buffer_.get(), next_in_byte_, avail_in_); } bytes_to_read -= avail_in_; read_location += avail_in_; } tstring data; Status s = input_stream_->ReadNBytes(bytes_to_read, &data); memcpy(read_location, data.data(), data.size()); // Note: data.size() could be different from bytes_to_read. avail_in_ += data.size(); zstd_input_buffer_.pos = 0; zstd_input_buffer_.size = data.size(); if (!s.ok() && !errors::IsOutOfRange(s)) { return s; } if (errors::IsOutOfRange(s)) { return Status::OK(); } return s; } int64 MonolithZstdInputStream::Tell() const { return bytes_read_; } } // namespace io } // namespace tensorflow ================================================ FILE: monolith/native_training/data/training_instance/cc/zstd_inputbuffer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Code is modified based on // https://github.com/tensorflow/tensorflow/compare/master...IAL32:tensorflow:ac/add-zstd-support #ifndef TENSORFLOW_MONOLITH_IO_ZSTD_ZSTD_INPUTSTREAM_H_ #define TENSORFLOW_MONOLITH_IO_ZSTD_ZSTD_INPUTSTREAM_H_ #define ZSTD_STATIC_LINKING_ONLY #include #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace io { class MonolithZstdInputStream : public InputStreamInterface { public: // Creates a MonolithZstdInputStream for `input_stream`. // // Takes ownership of `input_stream` iff `owns_input_stream` is true. 
MonolithZstdInputStream(InputStreamInterface* input_stream, size_t input_buffer_bytes, size_t output_buffer_bytes, bool owns_input_stream); // Equivalent to the previous constructor with owns_input_stream=false. MonolithZstdInputStream(InputStreamInterface* input_stream, size_t input_buffer_bytes, size_t output_buffer_bytes); ~MonolithZstdInputStream() override; // Reads bytes_to_read bytes into *result, overwriting *result. // // Return Status codes: // OK: If successful. // OUT_OF_RANGE: If there are not enough bytes to read before // the end of the stream. // ABORTED: If inflate() fails, we return the error code with the // error message in `z_stream_->msg`. // others: If reading from stream failed. Status ReadNBytes(int64 bytes_to_read, tstring* result) override; #if defined(TF_CORD_SUPPORT) Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override; #endif int64 Tell() const override; Status Reset() override; private: // Decompress the next chunk of data and place the data into the cache. Status Inflate(); Status ReadFromStream(); // There may be bytes leftover from last read. We read them so that we don't // lose them, and we optimize resources. size_t ReadBytesFromCache(size_t bytes_to_read, tstring* result); void InitZstdBuffer(); const bool owns_input_stream_; InputStreamInterface* input_stream_; std::unique_ptr input_buffer_; size_t input_buffer_capacity_; // Size of input_buffer_ char* next_in_byte_; // Next unread byte to decompress size_t avail_in_; // Number of bytes available to be decompressed ZSTD_inBuffer zstd_input_buffer_; std::unique_ptr output_buffer_; // Inflated buffer size_t output_buffer_capacity_; // Size of output_buffer_ char* next_unread_byte_; // Next unread byte in output_buffer_ // bytes left in the output_buffer_ not yet read. size_t unread_bytes_; ZSTD_DCtx* context_; // Specifies the number of decompressed bytes currently read. 
size_t bytes_read_; size_t last_return_; TF_DISALLOW_COPY_AND_ASSIGN(MonolithZstdInputStream); }; } // namespace io } // namespace tensorflow #endif // TENSORFLOW_MONOLITH_IO_ZSTD_ZSTD_INPUTSTREAM_H_ ================================================ FILE: monolith/native_training/data/training_instance/python/instance_dataset_op.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import logging import os from enum import Enum import tensorflow as tf from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import matching_files from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import convert from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.util.tf_export import tf_export from monolith.native_training.distribute import distributed_dataset from monolith.native_training.hooks import ckpt_hooks from monolith.native_training.runner_utils import RunnerConfig from monolith.native_training.runtime.ops import gen_monolith_ops instance_dataset_op = gen_monolith_ops class _PBInstanceDataset(dataset_ops.DatasetSource): def __init__(self, file_name, use_snappy=False, **kwargs): self._file_name = file_name 
self._use_snappy = use_snappy self._has_sort_id = kwargs.get('has_sort_id', True) self._kafka_dump = kwargs.get('kafka_dump', False) self._kafka_dump_prefix = kwargs.get('kafka_dump_prefix', False) variant_tensor = instance_dataset_op.instance_dataset( file_name=tf.convert_to_tensor(self._file_name, dtype=tf.string), use_snappy=tf.convert_to_tensor(self._use_snappy, dtype=tf.bool), has_sort_id=tf.convert_to_tensor(self._has_sort_id, dtype=tf.bool), kafka_dump=tf.convert_to_tensor(self._kafka_dump, dtype=tf.bool), kafka_dump_prefix=tf.convert_to_tensor(self._kafka_dump_prefix, dtype=tf.bool)) logging.info("Start init of the pb instance dataset base.") super(_PBInstanceDataset, self).__init__(variant_tensor) @property def element_spec(self): return tensor_spec.TensorSpec([], dtypes.string) class PBInstanceDatasetV2(dataset_ops.DatasetV2): """从标准输入/pb文件中读取序列化Instance, 不反序列化 Args: file_name (:obj:`str`): 文件名, 如果为空, 则从stdin读取数据 use_snappy (:obj:`str`): 输入文件是不否是snappy压缩的 has_sort_id (:obj:`bool`): 输入数据中是否带8 bytes前缀标识, 表明sort_id kafka_dump (:obj:`bool`): 输入数据中是否带8 bytes前缀标识, 表明kafka_dump kafka_dump_prefix (:obj:`bool`): 输入数据中是否带8 bytes前缀标识, 表明kafka_dump_prefix Raises: TypeError: 如果有任何参数与类型不匹配, 则抛TypeError ValueError: 如果有任何值与期望不匹配, 则抛ValueError """ def __init__(self, file_name, use_snappy=False, **kwargs): self._file_name = file_name self._use_snappy = use_snappy self._kwargs = kwargs if isinstance(file_name, str) and not file_name: # This is the special case that dataset uses stdin as the input. # In this case, we should diable the ckpt save/restore. 
ckpt_hooks.disable_iterator_save_restore() def creator_fn(): return _PBInstanceDataset(file_name, use_snappy, **self._kwargs) self._impl = creator_fn() variant_tensor = self._impl._variant_tensor logging.info("Start init of the pb instance dataset v2") super(PBInstanceDatasetV2, self).__init__(variant_tensor) logging.info("Finish init of the pb instance dataset v2") def _clone(self, file_name, use_snappy=False, **kwargs): _kwargs = self._kwargs.copy() _kwargs.update(kwargs) return PBInstanceDatasetV2(file_name or self._file_name, use_snappy or self._use_snappy, **_kwargs) @property def element_spec(self): return tensor_spec.TensorSpec([], dtypes.string) def _inputs(self): return [] #TODO(leqi.zou): We should rewrite this to make it more clear. def create_instance_dataset(files_list=None, use_snappy=False, expand_glob_path=False, cycle_length=4, num_parallel_calls=4, block_length=1, enable_sharding: bool = False, shard_index: int = None, shard_num: int = None, enable_dynamic_sharding=False, **kwargs): if files_list is None: # use stdin files_list = [""] if len( files_list ) == 1 and not expand_glob_path and not enable_sharding and not enable_dynamic_sharding: if len(files_list[0]) > 0 and not tf.io.gfile.exists(files_list[0]): logging.fatal('File not found: {}'.format(files_list[0])) return PBInstanceDatasetV2(file_name=files_list[0], use_snappy=use_snappy, **kwargs) map_func = lambda file_name: PBInstanceDatasetV2( file_name=file_name, use_snappy=use_snappy, **kwargs) if enable_dynamic_sharding: files_list = distributed_dataset.create_dynamic_sharding_dataset(files_list) return files_list.flat_map(map_func) elif not enable_sharding: if expand_glob_path: files_list = matching_files.MatchingFilesDataset(files_list) else: files_list = tf.data.Dataset.from_tensor_slices(files_list) else: # We should use only 1 pattern for the sharded hdfs reading. assert len(files_list) == 1 # List all the files via the list_files op. 
files_list = matching_files.MatchingFilesDataset(files_list) # Shard it according to the preallocated index. files_list = files_list.shard(shard_num, shard_index) logging.info("Shard the input files for shard {}/{}.".format( shard_index, shard_num)) use_snappy = True dataset = files_list.interleave(map_func=map_func, cycle_length=cycle_length, block_length=block_length, num_parallel_calls=num_parallel_calls, deterministic=False) return dataset PBInstanceDataset = PBInstanceDatasetV2 ================================================ FILE: monolith/native_training/data/training_instance/python/instance_dataset_op_test_stdin.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import tensorflow as tf from absl import logging from monolith.native_training.data.training_instance.python.instance_dataset_ops import PBInstanceDataset from monolith.native_training.data.training_instance.python.parse_instance_ops import parse_instances from tensorflow.python.framework import sparse_tensor FIDV1_FEATURES = [i for i in range(1, 10)] FIDV2_FEATURES = ["fc_360d_ml_convert_cid", "fc_360d_ml_convert_advertiser_id"] FLOAT_FEATURES = ["fc_muse_finish_rough_10168_uid_d128"] FLOAT_FEATURES_DIM = [128] INT64_FEATURES = ["fc_dense_external_action"] INT64_FEATURE_DIM = [1] def parse(serialized): return parse_instances(serialized, FIDV1_FEATURES, FIDV2_FEATURES, FLOAT_FEATURES, FLOAT_FEATURES_DIM, INT64_FEATURES, INT64_FEATURE_DIM) def testInstanceDataset(): # with self.session() as sess: with tf.compat.v1.Session() as sess: logging.warning("PBInstanceDatasetV2 process is Starting") dataset = PBInstanceDataset( file_name="", has_sort_id=True, kafka_dump_prefix=True, ) dataset = dataset.batch(32).map(parse) it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() logging.warning("PBInstanceDatasetV2 next process is Finished") elements = sess.run(element) logging.warning(element) logging.warning(elements["sample_rate"]) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() testInstanceDataset() ================================================ FILE: monolith/native_training/data/training_instance/python/instance_negative_gen_dataset_op_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import tensorflow as tf from absl import logging import numpy as np from collections import defaultdict from monolith.native_training.data.training_instance.python.instance_dataset_op import PBDataset, PbType, PBInstanceDataset, InstanceNegativeGenDataset from monolith.native_training.data.training_instance.python.parse_instance_ops import parse_variant_instances, parse_instances from monolith.native_training.data.training_instance.python.pb_datasource_ops import variant_dummy from tensorflow.python.framework import sparse_tensor FILE_NAME = 'monolith/native_training/data/training_instance/instance.pb' CHANNEL_SLOT = 357 GROUP_SLOTS = [200,201,202,203,204,205,206,210,211,212,213,214,215,\ 216,217,218,219,220,221,222,223,224,225,230,231,232,233,234,235,236,237,238,239,240,241,242] LABEL_FIELD = 'actions' LABEL_INDEX = 0 NEGATIVE_LABEL = -2 NEGATIVE_LABEL2 = -1 CHANNEL_FEATURE_NAME = "" GROUP_FEATURES_NAME = [] GID = 'gid' CHANNEL_SLOT_NAME = 'slot_' + str(CHANNEL_SLOT) GROUP_SLOT_NAME = 'slot_200' CHANNEL = 6435440280980561277 def parse1(pb_varient: tf.Tensor): FIDV1_FEATURES = [ 1, 3, 4, 5, 7, 8, 9, 31, 32, 33, 35, 36, 37, 38, 42, 44, 60, 61, 62, 63, 65, 66, 67, 68, 72, 74, 90, 91, 92, 93, 95, 120, \ 121, 122, 123, 125, 126, 128, 150, 151, 152, 153, 155, 156, 158, 180, 181, 182, 183, 185, 186, 188, 192, 193, 194, 200, 201, \ 202, 204, 206, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 230, 231, 232, 233, 234, 235, \ 236, 237, 238, 239, 240, 242, 357, 358, 359, 360, 361, 410, 411, 412, 413, 415, 416, 418, 422, 423, 424, 
446, 472, 475, 515, 516 ] return parse_variant_instances(pb_varient, FIDV1_FEATURES, misc_int64_features=[GID]) class InsNegativeDatasetTest(tf.test.TestCase): def testNegativeGen(self): with self.session() as sess: dataset = PBDataset(file_name=FILE_NAME, has_sort_id=True, kafka_dump=True, kafka_dump_prefix=False, input_pb_type=PbType.Instance, output_pb_type=PbType.Instance) dataset = dataset.negative_gen(neg_num=7, channel_slot=CHANNEL_SLOT, group_slots=GROUP_SLOTS, per_channel_sample=True, start_num=0, max_group_num_per_channel=10000, label_field=LABEL_FIELD, label_index=0, negative_label=NEGATIVE_LABEL, use_neg_ins=True) dataset = dataset.batch(8, drop_remainder=True).map(parse1) it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() count = 0 channel_res = [] group_res = [] label_res = [] while True: try: ret = sess.run(element) channel_res.append(ret[CHANNEL_SLOT_NAME].flat_values) group_res.append(ret[GROUP_SLOT_NAME].flat_values) label_res.append(ret[LABEL_FIELD]) count += 8 if count > 16: break except tf.errors.OutOfRangeError: logging.info("got eof") break for i in range(1, 8): self.assertEqual(channel_res[0][0], channel_res[0][i]) self.assertEqual(label_res[0][1], NEGATIVE_LABEL) def testRingBufferCache(self): with self.session() as sess: dataset = PBDataset(file_name=FILE_NAME, has_sort_id=True, kafka_dump=True, kafka_dump_prefix=False, input_pb_type=PbType.Instance, output_pb_type=PbType.Instance) max_group_num_per_channel = 2 dataset = dataset.negative_gen( neg_num=7, channel_slot=CHANNEL_SLOT, group_slots=GROUP_SLOTS, per_channel_sample=True, start_num=0, max_group_num_per_channel=max_group_num_per_channel, label_field=LABEL_FIELD, label_index=0, negative_label=NEGATIVE_LABEL, use_neg_ins=True) dataset = dataset.batch(8, drop_remainder=True).map(parse1) it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() count = 0 channel_res = [] group_res = [] label_res = [] gid_res = [] while True: try: ret = 
sess.run(element) channel_res.append(ret[CHANNEL_SLOT_NAME].flat_values) group_res.append(ret[GROUP_SLOT_NAME].flat_values) label_res.append(ret[LABEL_FIELD]) gid_res.append(ret[GID]) count += 8 if count > 1024: break except tf.errors.OutOfRangeError: logging.info("got eof") break res_by_channel = defaultdict(list) for i in range(100): channel = channel_res[i][0] res_by_channel[channel].append(i) valid_count = 0 for channel in res_by_channel: one_channel_res = res_by_channel[channel] if len(one_channel_res) <= max_group_num_per_channel: continue idx0 = one_channel_res[0] idx1 = one_channel_res[1] idx2 = one_channel_res[2] if gid_res[idx0][0] != gid_res[idx1][0] and gid_res[idx0][0] != gid_res[idx2][0] \ and gid_res[idx1][0] != gid_res[idx2][0]: for fid in group_res[idx2]: self.assertNotIn(fid, group_res[idx0]) valid_count += 1 logging.info('checkout count ' + str(valid_count)) def testIgnoreReaNegInstance(self): with self.session() as sess: dataset = PBDataset(file_name=FILE_NAME, has_sort_id=True, kafka_dump=True, kafka_dump_prefix=False, input_pb_type=PbType.Instance, output_pb_type=PbType.Instance) dataset = dataset.negative_gen(neg_num=7, channel_slot=CHANNEL_SLOT, group_slots=GROUP_SLOTS, per_channel_sample=True, start_num=0, max_group_num_per_channel=10000, label_field=LABEL_FIELD, label_index=0, negative_label=NEGATIVE_LABEL, use_neg_ins=True) dataset = InstanceNegativeGenDataset(input_dataset=dataset, neg_num=2, channel_slot=CHANNEL_SLOT, group_slots=GROUP_SLOTS, per_channel_sample=True, start_num=0, max_group_num_per_channel=10000, label_field=LABEL_FIELD, label_index=0, negative_label=NEGATIVE_LABEL2, use_neg_ins=False) dataset = dataset.batch(8, drop_remainder=True).map(parse1) it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() count = 0 channel_res = [] group_res = [] label_res = [] while True: try: ret = sess.run(element) label_res.append(ret[LABEL_FIELD]) count += 8 if count > 16: break except tf.errors.OutOfRangeError: 
logging.info("got eof") break self.assertEqual(label_res[0][1], NEGATIVE_LABEL2) def testUseNegInstance(self): with self.session() as sess: dataset = PBDataset(file_name=FILE_NAME, has_sort_id=True, kafka_dump=True, kafka_dump_prefix=False, input_pb_type=PbType.Instance, output_pb_type=PbType.Instance) dataset = dataset.negative_gen(neg_num=2, channel_slot=CHANNEL_SLOT, group_slots=GROUP_SLOTS, per_channel_sample=True, start_num=0, max_group_num_per_channel=10000, label_field=LABEL_FIELD, label_index=0, negative_label=NEGATIVE_LABEL, use_neg_ins=True) dataset = InstanceNegativeGenDataset(input_dataset=dataset, neg_num=2, channel_slot=CHANNEL_SLOT, group_slots=GROUP_SLOTS, per_channel_sample=True, start_num=0, max_group_num_per_channel=10000, label_field=LABEL_FIELD, label_index=0, negative_label=NEGATIVE_LABEL2, use_neg_ins=True) dataset = dataset.batch(8, drop_remainder=True).map(parse1) it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() count = 0 channel_res = [] group_res = [] label_res = [] while True: try: ret = sess.run(element) label_res.append(ret[LABEL_FIELD]) count += 8 if count > 16: break except tf.errors.OutOfRangeError: logging.info("got eof") break self.assertEqual(label_res[0][1], NEGATIVE_LABEL2) self.assertEqual(label_res[0][2], NEGATIVE_LABEL2) self.assertEqual(label_res[0][3], NEGATIVE_LABEL) self.assertEqual(label_res[0][4], NEGATIVE_LABEL) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/data/training_instance/python/parse_instance_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
def _as_list(values):
  """Returns `values` as a fresh list; `None` or any empty value gives `[]`.

  Copying means caller-owned sequences (and the list defaults declared on
  `parse_instances`) are never aliased, and tuples can be concatenated with
  lists further down.
  """
  return list(values or ())


def _parse_instance_impl(
    serialized: tf.Tensor, fidv1_features: List[int], fidv2_features: List[str],
    float_features: List[str], float_feature_dims: List[int],
    int64_features: List[str], int64_feature_dims: List[int],
    string_features: List[str], string_feature_dims: List[int],
    misc_float_features: List[str], misc_float_dims: List[int],
    misc_int64_features: List[str], misc_int64_dims: List[int],
    misc_string_features: List[str], misc_string_dims: List[int],
    cc_op: Callable):
  """Runs `cc_op` on `serialized` and packs its outputs into a feature dict.

  Every feature/dims argument may be `None`, which is treated as an empty
  list.

  Args:
    serialized: tensor handed to `cc_op` as its first argument.
    fidv1_features: slot ids; each becomes a ragged output keyed by
      `get_slot_feature_name(slot_id)`.
    fidv2_features: feature names; each becomes a ragged output keyed by name.
    float_features / float_feature_dims: names and dims forwarded to `cc_op`.
    int64_features / int64_feature_dims: names and dims forwarded to `cc_op`.
    string_features / string_feature_dims: names and dims forwarded to `cc_op`.
    misc_float_features / misc_float_dims: names and dims forwarded to `cc_op`.
    misc_int64_features / misc_int64_dims: names and dims forwarded to `cc_op`.
    misc_string_features / misc_string_dims: names and dims forwarded to
      `cc_op`.
    cc_op: parsing op; called with the keyword arguments shown below and
      expected to return the eight output groups unpacked below.

  Returns:
    Dict mapping feature name to tensor, in the order ragged (v1 then v2),
    float, int64, string, misc float, misc int64, misc string.
  """
  fidv1_features = _as_list(fidv1_features)
  fidv2_features = _as_list(fidv2_features)
  float_features = _as_list(float_features)
  float_feature_dims = _as_list(float_feature_dims)
  int64_features = _as_list(int64_features)
  int64_feature_dims = _as_list(int64_feature_dims)
  string_features = _as_list(string_features)
  string_feature_dims = _as_list(string_feature_dims)
  misc_float_features = _as_list(misc_float_features)
  misc_float_dims = _as_list(misc_float_dims)
  misc_int64_features = _as_list(misc_int64_features)
  misc_int64_dims = _as_list(misc_int64_dims)
  misc_string_features = _as_list(misc_string_features)
  misc_string_dims = _as_list(misc_string_dims)

  (ragged_feature_splits, ragged_feature_values, float_feature_values,
   int64_feature_values, string_feature_values, misc_float_feature_values,
   misc_int64_feature_values, misc_string_feature_values) = cc_op(
       serialized,
       N=(len(fidv1_features) + len(fidv2_features)),
       M=len(float_features),
       O=len(int64_features),
       P=len(string_features),
       Q=len(misc_float_features),
       R=len(misc_int64_features),
       S=len(misc_string_features),
       fidv1_features=fidv1_features,
       fidv2_features=fidv2_features,
       float_features=float_features,
       float_feature_dims=float_feature_dims,
       string_features=string_features,
       string_feature_dims=string_feature_dims,
       int64_features=int64_features,
       int64_feature_dims=int64_feature_dims,
       misc_float_features=misc_float_features,
       misc_float_dims=misc_float_dims,
       misc_int64_features=misc_int64_features,
       misc_int64_dims=misc_int64_dims,
       misc_string_features=misc_string_features,
       misc_string_dims=misc_string_dims,
   )

  ragged_keys = [get_slot_feature_name(slot_id) for slot_id in fidv1_features
                ] + fidv2_features
  ragged_values = []
  for values, row_splits in zip(ragged_feature_values, ragged_feature_splits):
    # TODO(zhuoran): Besides the "value" and "split" parsed from proto above,
    # precompute other two encodings "value_rowids" & "nrows" in Fountain also,
    # so that we could construct the ragged tensor with 4 precomputed encodings,
    # and would not need to recompute them later at training period again.
    # Currently, we just compute and cache value_rowids and nrows here:
    row_partition = RowPartition(
        row_splits, internal=_row_partition_factory_key
    ).with_precomputed_value_rowids().with_precomputed_nrows()
    ragged_values.append(tf.RaggedTensor(values, row_partition, internal=True))

  keys = (ragged_keys + float_features + int64_features + string_features +
          misc_float_features + misc_int64_features + misc_string_features)
  tensors = (ragged_values + float_feature_values + int64_feature_values +
             string_feature_values + misc_float_feature_values +
             misc_int64_feature_values + misc_string_feature_values)
  return dict(zip(keys, tensors))


def parse_instances2(serialized: tf.Tensor,
                     fidv1_features: List[int] = None,
                     fidv2_features: List[str] = None,
                     float_features: List[str] = None,
                     float_feature_dims: List[int] = None,
                     int64_features: List[str] = None,
                     int64_feature_dims: List[int] = None,
                     string_features: List[str] = None,
                     string_feature_dims: List[int] = None,
                     misc_float_features: List[str] = None,
                     misc_float_dims: List[int] = None,
                     misc_int64_features: List[str] = None,
                     misc_int64_dims: List[int] = None,
                     misc_string_features: List[str] = None,
                     misc_string_dims: List[int] = None):
  """Parses instances out of a tensor of serialized Instance protos.

  Args:
    serialized (:obj:`Tensor`): the input data.
    fidv1_features (:obj:`List[int]`): FIDv1 features are stored flat in an
      Instance, so they are selected by slot; a subset of slots may be given.
    fidv2_features (:obj:`List[str]`): FIDv2 features are stored in `feature`
      and selected by name; a subset of names may be given.
    float_features (:obj:`List[str]`): continuous features stored in `feature`,
      selected by name; a subset of names may be given.
    float_feature_dims (:obj:`List[int]`): dims of the continuous features;
      must have the same length as `float_features`.
    int64_features (:obj:`List[str]`): int64 (non-FID) features stored in
      `feature`, selected by name; a subset of names may be given.
    int64_feature_dims (:obj:`List[int]`): dims of the int64 features; must
      have the same length as `int64_features`.
    string_features (:obj:`List[str]`): string features stored in `feature`,
      selected by name; a subset of names may be given.
    string_feature_dims (:obj:`List[int]`): dims of the string features; must
      have the same length as `string_features`.
    misc_float_features (:obj:`List[str]`): float fields of LineId, selected by
      name; several may be given.
    misc_float_dims (:obj:`List[int]`): dims of the LineId float fields; must
      have the same length as `misc_float_features`.
    misc_int64_features (:obj:`List[str]`): int64 fields of LineId, selected by
      name; several may be given.
    misc_int64_dims (:obj:`List[int]`): dims of the LineId int64 fields; must
      have the same length as `misc_int64_features`.
    misc_string_features (:obj:`List[str]`): string fields of LineId, selected
      by name; several may be given.
    misc_string_dims (:obj:`List[int]`): dims of the LineId string fields; must
      have the same length as `misc_string_features`.

  Returns:
    Dict[str, Tensor] mapping each requested feature name to its parsed tensor.
  """
  return _parse_instance_impl(
      serialized, fidv1_features, fidv2_features, float_features,
      float_feature_dims, int64_features, int64_feature_dims, string_features,
      string_feature_dims, misc_float_features, misc_float_dims,
      misc_int64_features, misc_int64_dims, misc_string_features,
      misc_string_dims, parse_instance_ops.monolith_parse_instances)


def parse_instances(serialized: tf.Tensor,
                    fidv1_features: List[int] = None,
                    fidv2_features: List[str] = None,
                    float_features: List[str] = None,
                    float_feature_dims: List[int] = None,
                    int64_features: List[str] = None,
                    int64_feature_dims: List[int] = None,
                    string_features: List[str] = None,
                    string_feature_dims: List[int] = None,
                    misc_float_features: List[str] = ['sample_rate'],
                    misc_int64_features: List[str] = ['req_time', 'uid'],
                    misc_string_features: List[str] = None,
                    misc_repeated_float_features: List[str] = ['label'],
                    misc_repeated_float_dims: List[int] = None,
                    misc_repeated_int64_features: List[str] = None,
                    misc_repeated_int64_dims: List[int] = None,
                    misc_repeated_string_features: List[str] = None,
                    misc_repeated_string_dims: List[int] = None):
  """Parses instances out of a tensor of serialized Instance protos.

  This variant takes more parameters; prefer `parse_instances2`.

  Args:
    serialized (:obj:`Tensor`): the input data.
    fidv1_features (:obj:`List[int]`): FIDv1 features are stored flat in an
      Instance, so they are selected by slot; a subset of slots may be given.
    fidv2_features (:obj:`List[str]`): FIDv2 features are stored in `feature`
      and selected by name; a subset of names may be given.
    float_features (:obj:`List[str]`): continuous features stored in `feature`,
      selected by name; a subset of names may be given.
    float_feature_dims (:obj:`List[int]`): dims of the continuous features;
      must have the same length as `float_features`.
    int64_features (:obj:`List[str]`): int64 (non-FID) features stored in
      `feature`, selected by name; a subset of names may be given.
    int64_feature_dims (:obj:`List[int]`): dims of the int64 features; must
      have the same length as `int64_features`.
    string_features (:obj:`List[str]`): string features stored in `feature`,
      selected by name; a subset of names may be given.
    string_feature_dims (:obj:`List[int]`): dims of the string features; must
      have the same length as `string_features`.
    misc_float_features (:obj:`List[str]`): non-repeated float fields of
      LineId, selected by name; parsed with dim 1 and flattened to rank 1.
    misc_int64_features (:obj:`List[str]`): non-repeated int64 fields of
      LineId, selected by name; parsed with dim 1 and flattened to rank 1.
    misc_string_features (:obj:`List[str]`): non-repeated string fields of
      LineId, selected by name; parsed with dim 1 (not flattened).
    misc_repeated_float_features (:obj:`List[str]`): repeated float fields of
      LineId, selected by name.
    misc_repeated_float_dims (:obj:`List[int]`): dims of the repeated float
      fields, same length as `misc_repeated_float_features`; defaults to 1
      for every field.
    misc_repeated_int64_features (:obj:`List[str]`): repeated int64 fields of
      LineId, selected by name.
    misc_repeated_int64_dims (:obj:`List[int]`): dims of the repeated int64
      fields, same length as `misc_repeated_int64_features`; defaults to 1
      for every field.
    misc_repeated_string_features (:obj:`List[str]`): repeated string fields of
      LineId, selected by name.
    misc_repeated_string_dims (:obj:`List[int]`): dims of the repeated string
      fields, same length as `misc_repeated_string_features`; defaults to 1
      for every field.

  Returns:
    Dict[str, Tensor] mapping each requested feature name to its parsed tensor.
  """
  # Copy every name list: the signature defaults are shared list objects and
  # callers may hand in tuples, which cannot be concatenated with lists.
  scalar_float_names = _as_list(misc_float_features)
  scalar_int64_names = _as_list(misc_int64_features)
  scalar_string_names = _as_list(misc_string_features)
  repeated_float_names = _as_list(misc_repeated_float_features)
  repeated_int64_names = _as_list(misc_repeated_int64_features)
  repeated_string_names = _as_list(misc_repeated_string_features)

  # Missing or empty dims fall back to one value per repeated field.
  repeated_float_dims = (_as_list(misc_repeated_float_dims) or
                         [1] * len(repeated_float_names))
  repeated_int64_dims = (_as_list(misc_repeated_int64_dims) or
                         [1] * len(repeated_int64_names))
  repeated_string_dims = (_as_list(misc_repeated_string_dims) or
                          [1] * len(repeated_string_names))

  features = parse_instances2(
      serialized, fidv1_features, fidv2_features, float_features,
      float_feature_dims, int64_features, int64_feature_dims, string_features,
      string_feature_dims,
      scalar_float_names + repeated_float_names,
      [1] * len(scalar_float_names) + repeated_float_dims,
      scalar_int64_names + repeated_int64_names,
      [1] * len(scalar_int64_names) + repeated_int64_dims,
      scalar_string_names + repeated_string_names,
      [1] * len(scalar_string_names) + repeated_string_dims)

  # Non-repeated float/int64 fields come back with a trailing dim of 1;
  # flatten them to rank 1.
  for key in scalar_float_names + scalar_int64_names:
    features[key] = tf.reshape(features[key], [-1])
  return features
# Number of low bits left for the feature value below the slot id.
_FID_V1_SLOT_SHIFT = 54
_FID_V2_SLOT_SHIFT = 48


def make_fid_v1(slot_id, fid):
  """Builds a v1 FID: `slot_id` shifted left by 54 bits, OR-ed with `fid`."""
  slot_bits = slot_id << _FID_V1_SLOT_SHIFT
  return slot_bits | fid


def make_fid_v2(slot_id, fid):
  """Builds a v2 FID: `slot_id` shifted left by 48 bits, OR-ed with `fid`."""
  slot_bits = slot_id << _FID_V2_SLOT_SHIFT
  return slot_bits | fid


def get_test_fidv2():
  """Returns the ten v2 FIDs of slot 100 with values 0..9."""
  fids = []
  for value in range(10):
    fids.append(make_fid_v2(100, value))
  return fids


def generate_instance():
  """Builds the Instance proto used as the fixture for the parser tests."""
  inst = proto_parser_pb2.Instance()

  def new_feature(name):
    # Appends a named entry to `inst.feature` and returns it for filling in.
    feature = inst.feature.add()
    feature.name = name
    return feature

  # Flat v1 FIDs: slot i carries the value i.
  inst.fid.extend([make_fid_v1(slot, slot) for slot in range(10)])

  # Named features, added in the same order as before.
  new_feature("fidv2").fid.extend(get_test_fidv2())
  new_feature("ue").float_value.extend([float(i * 1e-5) for i in range(16)])
  new_feature("int64_feature").int64_value.append(100)
  new_feature("string_feature").bytes_value.append(b"test_string")

  inst.label.extend([1.1, 2.2, 3.3])

  line_id = inst.line_id
  line_id.uid = 110
  line_id.sample_rate = 0.5
  line_id.req_time = 64
  line_id.actions.extend([0, 100])
  line_id.user_id = "123"
  return inst
tf.RaggedTensor.from_row_splits(values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) d = {"slot_2": rt} assert rt._row_partition._value_rowids is None d = parser_utils.RaggedEncodingHelper.expand( d, with_precomputed_value_rowids=True) print(d) assert len(d) == 1 self.assertAllEqual(sess.run(d["slot_2"]["value_rowids"]), sess.run(rt_copy.value_rowids())) d = parser_utils.RaggedEncodingHelper.contract(d) assert len(d) == 1 self.assertAllEqual(sess.run(d["slot_2"]), sess.run(rt)) self.assertAllEqual(sess.run(d["slot_2"]._row_partition._value_rowids), sess.run(rt_copy.value_rowids())) class ParseInstancesTest(tf.test.TestCase): def testParseInstance(self): instance = generate_instance() body = instance.SerializeToString() with tf.compat.v1.Session() as sess: features = ops.parse_instances2( [body, body], fidv1_features=list(range(10)), fidv2_features=["fidv2"], float_features=["ue"], float_feature_dims=[16], int64_features=["int64_feature"], int64_feature_dims=[1], string_features=["string_feature"], string_feature_dims=[1], misc_float_features=["sample_rate", "label"], misc_float_dims=[1, 3], misc_int64_features=["uid", "actions"], misc_int64_dims=[1, 2], misc_string_features=["user_id"], misc_string_dims=[1]) features = sess.run(features) self.assertEqual( len([fidv1_key for fidv1_key in features if "slot" in fidv1_key]), 10) self.assertAllEqual( features["slot_1"], tf.compat.v1.ragged.constant_value([[make_fid_v2(1, 1)]] * 2)) self.assertAllEqual( features["fidv2"], tf.compat.v1.ragged.constant_value([get_test_fidv2()] * 2)) self.assertAllClose(features["int64_feature"], [[100]] * 2) self.assertAllEqual(features["string_feature"], [[b"test_string"]] * 2) self.assertAllClose(features["ue"], [[float(i * 1e-5) for i in range(16)]] * 2) self.assertAllClose(features["sample_rate"], [[0.5]] * 2) self.assertAllClose(features["label"], [[1.1, 2.2, 3.3]] * 2) self.assertAllEqual(features["uid"], [[110]] * 2) self.assertAllEqual(features["actions"], [[0, 100]] * 
2) self.assertAllEqual(features["user_id"], [["123"]] * 2) def testParseInstanceV1Only(self): instance = generate_instance() body = instance.SerializeToString() with tf.compat.v1.Session() as sess: features = ops.parse_instances2([body], fidv1_features=[1]) features = sess.run(features) self.assertAllEqual( features["slot_1"], tf.compat.v1.ragged.constant_value([[make_fid_v1(1, 1)]])) def testParseInstanceWithMissingFields(self): instance = generate_instance() body = instance.SerializeToString() with tf.compat.v1.Session() as sess: features = ops.parse_instances2( [body], fidv1_features=list(range(11)), fidv2_features=["fidv2", "fidv2_2"], float_features=["ue", "ue2"], float_feature_dims=[16, 8], int64_features=["int64_feature", "missing_int64_feature"], int64_feature_dims=[1, 10], string_features=["string_feature", "missing_string_feature"], string_feature_dims=[1, 10]) features = sess.run(features) # It should be an empty tensor for the last FID element self.assertAllEqual(features["slot_10"], tf.compat.v1.ragged.constant_value([[]])) self.assertAllEqual(features["fidv2_2"], tf.compat.v1.ragged.constant_value([[]])) # It should be an zero tensor for the second UE element self.assertAllEqual(features["ue2"], [[0 for i in range(8)]]) # It should be an zero tensor for the second int64 element self.assertAllEqual(features["missing_int64_feature"], [[0 for i in range(10)]]) self.assertAllEqual(features["missing_string_feature"], [["" for i in range(10)]]) class RawParseInstanceTest(tf.test.TestCase): def test_concat(self): serialized = [generate_instance().SerializeToString()] tensors = ops.monolith_raw_parse_instance(T=[tf.int64, tf.int64], serialized=serialized, fidv1_features=[0, 1], fidv2_features=["fidv2"], fid_output_type="CONCAT") with self.session() as sess: tensors = sess.run(tensors) self.assertAllEqual(tensors[0], [0, 1, 2, len(get_test_fidv2()) + 2]) self.assertAllEqual( tensors[1], [make_fid_v2(0, 0), make_fid_v2(1, 1)] + get_test_fidv2()) if __name__ == 
# FIFO of extra parse functions; drained by `advanced_parse`.
_extra_parse_steps = deque()


def add_extra_parse_step(parse_fn: Callable):
  """Queues `parse_fn` so that `advanced_parse` applies it to the features."""
  _extra_parse_steps.append(parse_fn)


class RaggedEncodingHelper:
  """Helper methods to precompute ragged encodings in input_fn, as a workaround

  Fundamentally, we should modify TensorFlow Dataset structure handler to
  compute provided encoding tensor in RowPartition of a RaggedTensor.
  """

  @staticmethod
  def expand(name_to_ragged_ids,
             with_precomputed_nrows=True,
             with_precomputed_value_rowids=False):
    """Replaces every RaggedTensor value by a dict of its encodings.

    Each RaggedTensor becomes a dict with the keys "values", "row_splits",
    "nrows" and "value_rowids"; the last two are `None` unless the matching
    flag is set. Values of any other type are passed through unchanged.
    """
    expanded = {}
    for name, tensor in name_to_ragged_ids.items():
      if not isinstance(tensor, tf.RaggedTensor):
        expanded[name] = tensor
        continue
      # Computed in this order (nrows first) to match the original.
      nrows = tensor.nrows() if with_precomputed_nrows else None
      value_rowids = (ragged_utils.fused_value_rowids(tensor)
                      if with_precomputed_value_rowids else None)
      expanded[name] = {
          "values": tensor.values,
          "row_splits": tensor.row_splits,
          "nrows": nrows,
          "value_rowids": value_rowids,
      }
    return expanded

  @staticmethod
  def contract(name_to_ragged_ids):
    """Rebuilds RaggedTensors from the dicts produced by `expand`.

    A value is rebuilt only if it is a dict holding both "values" and
    "row_splits"; everything else is passed through unchanged. When present,
    "nrows" and "value_rowids" are written into the new tensor's row
    partition, which must not already hold them.
    """
    contracted = {}
    for name, encoding in name_to_ragged_ids.items():
      is_expanded = (isinstance(encoding, dict) and "values" in encoding and
                     "row_splits" in encoding)
      if not is_expanded:
        contracted[name] = encoding
        continue

      ragged = tf.RaggedTensor.from_row_splits(encoding["values"],
                                               encoding["row_splits"],
                                               validate=False)
      partition = ragged._row_partition
      if "nrows" in encoding:
        assert partition._nrows is None, "Shouldn't override the exisiting nrows."
        partition._nrows = encoding["nrows"]
      if "value_rowids" in encoding:
        assert partition._value_rowids is None, "Shouldn't override the exisiting tensor."
        partition._value_rowids = encoding["value_rowids"]
      contracted[name] = ragged
    return contracted


def advanced_parse(features):
  """Applies every queued extra parse step to `features`, oldest first.

  The queue is emptied while running, so a later call without newly added
  steps returns its argument unchanged.
  """
  while _extra_parse_steps:
    features = _extra_parse_steps.popleft()(features)
  return features
def filter_by_fids(variant: tf.Tensor,
                   filter_fids: List[int] = None,
                   has_fids: List[int] = None,
                   select_fids: List[int] = None,
                   has_actions: List[int] = None):
  """Forwards `variant` and the four id lists to the `set_filter` op.

  Any list left as `None` (or empty) is passed to the op as `[]`.
  """
  id_lists = [
      ids if ids else []
      for ids in (filter_fids, has_fids, select_fids, has_actions)
  ]
  return pb_datasource_ops.set_filter(variant, *id_lists)


def filter_by_value(variant: tf.Tensor, field_name: str, op: str,
                    operand: float):
  """Forwards its arguments unchanged to the `value_filter` op."""
  value_filter_op = pb_datasource_ops.value_filter
  return value_filter_op(variant, field_name, op, operand)


def negative_sample(variant: tf.Tensor, drop_rate: float, label_index: int,
                    threshold: float):
  """Forwards its arguments unchanged to the `negative_sample` op."""
  negative_sample_op = pb_datasource_ops.negative_sample
  return negative_sample_op(variant, drop_rate, label_index, threshold)


def variant_dummy(variant: tf.Tensor):
  """Forwards `variant` to the `variant_dummy` op and returns its result."""
  dummy_op = pb_datasource_ops.variant_dummy
  return dummy_op(variant)
import tensorflow as tf ================================================ FILE: monolith/native_training/data/transform/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library") load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") load("@pip_deps//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) cc_proto_library( name = "transform_config_cc_proto", srcs = ["transform_config.proto"], ) py_proto_library( name = "transform_config_py_proto", srcs = ["transform_config.proto"], srcs_version = "PY2AND3", ) cc_library( name = "transforms", srcs = ["cc/transforms.cc"], hdrs = ["cc/transforms.h"], deps = [ ":transform_config_cc_proto", "//monolith/native_training/data/training_instance:instance_utils", "//monolith/native_training/data/kernels/internal:label_utils", "//monolith/native_training/data/kernels/internal:value_filter_by_line_id", "//monolith/native_training/data/kernels/internal:relational_utils", "//monolith/native_training/runtime/common:linalg_utils", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], ) py_library( name = "transforms_py", srcs = [ "transforms.py", ], deps = [ ":transform_config_py_proto", "//idl:example_py_proto", "//idl:proto_parser_py_proto", ], ) py_test( name = "transforms_test", srcs = ["transforms_test.py"], deps = [ ":transforms_py", ], ) exports_files([ "cc/transforms.cc", ]) ================================================ FILE: monolith/native_training/data/transform/cc/transforms.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/transform/cc/transforms.h" #include #include #include "absl/base/internal/cycleclock.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "glog/logging.h" #include "monolith/native_training/data/kernels/internal/label_utils.h" #include "monolith/native_training/data/kernels/internal/value_filter_by_line_id.h" #include "monolith/native_training/data/kernels/internal/relational_utils.h" #include "monolith/native_training/data/training_instance/cc/instance_utils.h" #include "monolith/native_training/runtime/common/linalg_utils.h" #include "tensorflow/core/platform/logging.h" #include "third_party/nlohmann/json.hpp" namespace tensorflow { namespace monolith_tf { using ::google::protobuf::RepeatedField; using ::idl::matrix::proto::LineId; using internal::LineIdValueFilter; using ::monolith::common::IsAlmostEqual; using ::monolith::io::proto::Example; using ::monolith::io::proto::ExampleBatch; using ::parser::proto::Instance; class LogEveryNSecState { public: bool ShouldLog(double seconds) { LossyIncrement(&counter_); const int64 now_cycles = absl::base_internal::CycleClock::Now(); int64 next_cycles = next_log_time_cycles_.load(std::memory_order_relaxed); do { if (now_cycles <= next_cycles) return false; } while (!next_log_time_cycles_.compare_exchange_weak( next_cycles, now_cycles + seconds * absl::base_internal::CycleClock::Frequency(), std::memory_order_relaxed, std::memory_order_relaxed)); return true; } uint32 counter() { return counter_.load(std::memory_order_relaxed); } private: // The 
following code behaves like AtomicStatsCounter::LossyAdd() for // speed since it is fine to lose occasional updates. // Returns old value of *counter. uint32 LossyIncrement(std::atomic* counter) { const uint32 value = counter->load(std::memory_order_relaxed); counter->store(value + 1, std::memory_order_relaxed); return value; } std::atomic counter_{0}; // Cycle count according to CycleClock that we should next log at. std::atomic next_log_time_cycles_{0}; }; class TransformSummary : public TransformInterface { public: explicit TransformSummary(std::unique_ptr transform, bool print_summary = false) : transform_(std::move(transform)), offset_(0), input_total_(0), output_total_(0) {} ~TransformSummary() override { LOG(INFO) << "Finally " << this->DebugString(); } std::string Name() override { return transform_->Name(); } std::string DebugString() { float rate = input_total_ == 0 ? 0 : static_cast(output_total_) / input_total_; return absl::StrFormat( "%s, input = %ld, output = %ld, retention rate = %.2f", transform_->Name(), input_total_, output_total_, rate); } void Transform(std::shared_ptr instance, std::vector>* output) override { offset_ = output->size(); input_total_ += 1; transform_->Transform(instance, output); output_total_ += output->size() - offset_; if (every_n_sec_state_.ShouldLog(60 * 5)) { LOG(INFO) << DebugString(); } } void Transform(std::shared_ptr example, std::vector>* output) override { offset_ = output->size(); input_total_ += 1; transform_->Transform(example, output); output_total_ += output->size() - offset_; if (every_n_sec_state_.ShouldLog(60 * 5)) { LOG(INFO) << DebugString(); } } private: std::unique_ptr transform_; int64_t offset_; int64_t input_total_; int64_t output_total_; LogEveryNSecState every_n_sec_state_; }; class Identity : public TransformInterface { public: std::string Name() override { return "Identity"; } void Transform(std::shared_ptr instance, std::vector>* output) override { output->push_back(instance); } void 
Transform(std::shared_ptr example, std::vector>* output) override { output->push_back(example); } }; class FilterByFid : public TransformInterface { public: explicit FilterByFid(FilterByFidConfig config) : config_(std::move(config)) { filter_fids_.insert(config_.filter_fids().begin(), config_.filter_fids().end()); has_fids_.insert(config_.has_fids().begin(), config_.has_fids().end()); select_fids_.insert(config_.select_fids().begin(), config_.select_fids().end()); req_time_min_ = 0; } std::string Name() override { return "FilterByFid"; } void Transform(std::shared_ptr instance, std::vector>* output) override { if (tensorflow::monolith_tf::IsInstanceOfInterest(*instance, filter_fids_, has_fids_, select_fids_, {}, req_time_min_, {})) { output->push_back(instance); } } void Transform(std::shared_ptr example, std::vector>* output) override { if (tensorflow::monolith_tf::IsInstanceOfInterest(*example, filter_fids_, has_fids_, select_fids_, {}, req_time_min_, {})) { output->push_back(example); } } private: std::set filter_fids_; std::set has_fids_; std::set select_fids_; int64_t req_time_min_; FilterByFidConfig config_; }; class FilterByAction : public TransformInterface { public: explicit FilterByAction(FilterByActionConfig config) : config_(std::move(config)) { has_actions_.insert(config_.has_actions().begin(), config_.has_actions().end()); } std::string Name() override { return "FilterByAction"; } void Transform(std::shared_ptr instance, std::vector>* output) override { if (tensorflow::monolith_tf::IsInstanceOfInterest(*instance, {}, {}, {}, has_actions_, 0, {})) { output->push_back(instance); } } void Transform(std::shared_ptr example, std::vector>* output) override { if (tensorflow::monolith_tf::IsInstanceOfInterest(*example, {}, {}, {}, has_actions_, 0, {})) { output->push_back(example); } } private: std::set has_actions_; FilterByActionConfig config_; }; class FilterByLabel : public TransformInterface { public: explicit FilterByLabel(FilterByLabelConfig config) : 
config_(std::move(config)) {} std::string Name() override { return "FilterByLabel"; } void Transform(std::shared_ptr instance, std::vector>* output) override { if (IsInstanceOfInterest(instance->label())) { output->push_back(instance); } } void Transform(std::shared_ptr example, std::vector>* output) override { if (IsInstanceOfInterest(example->label())) { output->push_back(example); } } private: bool IsInstanceOfInterest(const RepeatedField& labels) const { if (labels.size() < config_.thresholds_size()) { LOG_EVERY_N_SEC(ERROR, 60) << absl::StrFormat( "Label size(=%ld) should be >= label_threshold size(=%ld), please " "investigate!", labels.size(), config_.thresholds_size()); return false; } for (int i = 0; i < config_.thresholds_size(); ++i) { if (labels.Get(i) >= config_.thresholds(i)) { return true; } } return false; } FilterByLabelConfig config_; }; class FilterByValue : public TransformInterface { public: explicit FilterByValue(FilterByValueConfig config) : config_(std::move(config)) { field_name_ = config_.field_name(); op_ = config_.op(); float_operand_.insert(float_operand_.end(), config_.float_operand().begin(), config_.float_operand().end()); int_operand_.insert(int_operand_.end(), config_.int_operand().begin(), config_.int_operand().end()); string_operand_.insert(string_operand_.end(), config_.string_operand().begin(), config_.string_operand().end()); keep_empty_ = config_.keep_empty(); operand_filepath_ = config_.operand_filepath(); line_id_value_filter_ = std::make_unique( field_name_, op_, float_operand_, int_operand_, string_operand_, operand_filepath_, keep_empty_); } std::string Name() override { return "FilterByValue"; } void Transform(std::shared_ptr instance, std::vector>* output) override { if (IsInstanceOfInterest(instance->line_id())) { output->push_back(instance); } } void Transform(std::shared_ptr example, std::vector>* output) override { if (IsInstanceOfInterest(example->line_id())) { output->push_back(example); } } private: // 
TODO(huangruiteng): support value filter by feature bool IsInstanceOfInterest(const LineId& line_id) const { tensorflow::Env* env = tensorflow::Env::Default(); return line_id_value_filter_->IsInstanceOfInterest(env, line_id); } FilterByValueConfig config_; std::string field_name_; std::string op_; // gt, ge, eq, lt, le, neq, between bool keep_empty_ = false; std::string operand_filepath_; std::vector float_operand_; std::vector int_operand_; std::vector string_operand_; std::unique_ptr line_id_value_filter_; }; class AddLabel : public TransformInterface { public: explicit AddLabel(AddLabelConfig config) : config_(std::move(config)) { task_configs_.reserve(config_.task_label_configs_size()); for (const auto& t : config_.task_label_configs()) { std::set pos_actions, neg_actions; CHECK(!t.pos_actions().empty()); pos_actions.insert(t.pos_actions().begin(), t.pos_actions().end()); neg_actions.insert(t.neg_actions().begin(), t.neg_actions().end()); CHECK(!internal::HasIntersection(pos_actions, neg_actions)); float sample_rate = t.sample_rate(); CHECK_GE(sample_rate, 0); CHECK_LE(sample_rate, 1.0); task_configs_.push_back({pos_actions, neg_actions, sample_rate}); } for (size_t i = 0; i < task_configs_.size(); ++i) { LOG(INFO) << absl::StrFormat("Task #%d config: %s", i + 1, task_configs_[i].ToString()); } LOG(INFO) << absl::StrFormat("sample_rate = %.4f", config_.new_sample_rate()); std::size_t seed = std::chrono::system_clock::now().time_since_epoch().count(); random_generator_.seed(seed); random_neg_sample_ = std::uniform_real_distribution(0.0, 1.0); } std::string Name() override { return "AddLabel"; } void Transform(std::shared_ptr instance, std::vector>* output) override { DoAddLabel(instance->mutable_line_id(), instance->mutable_label()); output->push_back(instance); } void Transform(std::shared_ptr example, std::vector>* output) override { DoAddLabel(example->mutable_line_id(), example->mutable_label()); output->push_back(example); } private: void DoAddLabel(LineId* 
mutable_line_id, google::protobuf::RepeatedField* mutable_label) { std::set actions(mutable_line_id->actions().begin(), mutable_line_id->actions().end()); if (!mutable_label->empty() && mutable_label->Get(0) <= 0) { mutable_label->Set(0, internal::INVALID_LABEL); } for (const auto& t : task_configs_) { bool has_pos = internal::HasIntersection(actions, t.pos_actions); bool has_neg = internal::HasIntersection(actions, t.neg_actions); if (!t.neg_actions.empty()) { // If there is given neg_actions if (!has_pos && !has_neg) { mutable_label->Add(internal::INVALID_LABEL); } else if (has_pos) { // (has_pos && !has_neg) || (has_pos && has_neg) mutable_label->Add(internal::POSITIVE_LABEL); } else { // !has_pos && has_neg if (SelectedByNegativeSampling(t)) { mutable_label->Add(config_.negative_value()); } else { mutable_label->Add(internal::INVALID_LABEL); } } } else { // If there is no given neg_actions if (has_pos) { mutable_label->Add(internal::POSITIVE_LABEL); } else { if (SelectedByNegativeSampling(t)) { mutable_label->Add(config_.negative_value()); } else { mutable_label->Add(internal::INVALID_LABEL); } } } } mutable_line_id->set_sample_rate(config_.new_sample_rate()); } bool SelectedByNegativeSampling(const internal::TaskConfig& t) { return IsAlmostEqual(t.sample_rate, 1.0f) || random_neg_sample_(random_generator_) < t.sample_rate; } std::vector task_configs_; std::default_random_engine random_generator_; std::uniform_real_distribution random_neg_sample_; AddLabelConfig config_; }; class LogicalOrTransform : public TransformInterface { public: LogicalOrTransform(std::unique_ptr t1, std::unique_ptr t2) : t1_(std::move(t1)), t2_(std::move(t2)) {} std::string Name() override { return absl::StrFormat("(%s or %s)", t1_->Name(), t2_->Name()); } void Transform(std::shared_ptr instance, std::vector>* output) override { std::vector> intermediates; t1_->Transform(instance, &intermediates); t2_->Transform(instance, &intermediates); if (!intermediates.empty()) { 
CHECK_LE(intermediates.size(), 2); output->push_back(intermediates.front()); } } void Transform(std::shared_ptr example, std::vector>* output) override { std::vector> intermediates; t1_->Transform(example, &intermediates); t2_->Transform(example, &intermediates); if (!intermediates.empty()) { CHECK_LE(intermediates.size(), 2); output->push_back(intermediates.front()); } } private: std::unique_ptr t1_; std::unique_ptr t2_; }; class CombinedTransform : public TransformInterface { public: CombinedTransform(std::unique_ptr t1, std::unique_ptr t2) : t1_(std::move(t1)), t2_(std::move(t2)) {} std::string Name() override { return absl::StrFormat("(%s and %s)", t1_->Name(), t2_->Name()); } void Transform(std::shared_ptr instance, std::vector>* output) override { std::vector> intermediates; t1_->Transform(instance, &intermediates); for (const auto& intermediate : intermediates) { t2_->Transform(intermediate, output); } } void Transform(std::shared_ptr example, std::vector>* output) override { std::vector> intermediates; t1_->Transform(example, &intermediates); for (const auto& intermediate : intermediates) { t2_->Transform(intermediate, output); } } private: std::unique_ptr t1_; std::unique_ptr t2_; }; std::unique_ptr NewTransformSummary( std::unique_ptr transform, bool print_summary) { return std::make_unique(std::move(transform), print_summary); } std::unique_ptr NewIdentity() { return std::make_unique(); } std::unique_ptr NewFilterByFid(FilterByFidConfig config) { return std::make_unique(std::move(config)); } std::unique_ptr NewFilterByAction( FilterByActionConfig config) { return std::make_unique(std::move(config)); } std::unique_ptr NewFilterByLabel( FilterByLabelConfig config) { return std::make_unique(std::move(config)); } std::unique_ptr NewAddLabel(AddLabelConfig config) { return std::make_unique(std::move(config)); } std::unique_ptr NewFilterByValue( FilterByValueConfig config) { return std::make_unique(std::move(config)); } std::unique_ptr CombineTransforms( 
std::unique_ptr t1, std::unique_ptr t2) { return std::make_unique(std::move(t1), std::move(t2)); } std::unique_ptr CombineLogicalOrTransforms( std::unique_ptr t1, std::unique_ptr t2) { return std::make_unique(std::move(t1), std::move(t2)); } std::unique_ptr NewTransformFromBasicConfig( BasicTransformConfig config) { std::string name; std::unique_ptr transform = nullptr; switch (config.type_case()) { case (BasicTransformConfig::kFilterByFid): transform = NewFilterByFid(std::move(*config.mutable_filter_by_fid())); break; case (BasicTransformConfig::kFilterByAction): transform = NewFilterByAction(std::move(*config.mutable_filter_by_action())); break; case (BasicTransformConfig::kFilterByLabel): transform = NewFilterByLabel(std::move(*config.mutable_filter_by_label())); break; case (BasicTransformConfig::kAddLabel): transform = NewAddLabel(std::move(*config.mutable_add_label())); break; case (BasicTransformConfig::kFilterByValue): transform = NewFilterByValue(std::move(*config.mutable_filter_by_value())); break; default: throw std::invalid_argument(absl::StrFormat( "transform is not implemented yet. 
%s", config.ShortDebugString())); } return NewTransformSummary(std::move(transform)); } std::unique_ptr NewTransformFromConfig( const TransformConfig& config) { std::unique_ptr transform = nullptr; for (const auto& c : config.configs()) { std::unique_ptr t; if (c.has_basic_config()) { t = NewTransformFromBasicConfig(c.basic_config()); } else if (c.has_logical_or_config()) { std::unique_ptr t1 = NewTransformFromBasicConfig(c.logical_or_config().x()); std::unique_ptr t2 = NewTransformFromBasicConfig(c.logical_or_config().y()); t = CombineLogicalOrTransforms(std::move(t1), std::move(t2)); } AssignOrCombine(&transform, std::move(t), CombineTransforms); } if (transform == nullptr) { transform = NewIdentity(); } return NewTransformSummary(std::move(transform), true); } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/data/transform/cc/transforms.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_DATA_TRANSFORM_CC_TRANSFORMS_H_ #define MONOLITH_NATIVE_TRAINING_DATA_TRANSFORM_CC_TRANSFORMS_H_ #include #include #include "idl/matrix/proto/example.pb.h" #include "idl/matrix/proto/proto_parser.pb.h" #include "monolith/native_training/data/transform/transform_config.pb.h" namespace tensorflow { namespace monolith_tf { using monolith::native_training::data::AddLabelConfig; using monolith::native_training::data::BasicTransformConfig; using monolith::native_training::data::FilterByActionConfig; using monolith::native_training::data::FilterByFidConfig; using monolith::native_training::data::FilterByLabelConfig; using monolith::native_training::data::FilterByValueConfig; using monolith::native_training::data::LogicalOrTransformConfig; using monolith::native_training::data::TransformConfig; using monolith::native_training::data::TransformConfig_OneTransformConfig; class TransformInterface { public: virtual ~TransformInterface() = default; virtual std::string Name() = 0; virtual void Transform( std::shared_ptr<::parser::proto::Instance>, std::vector>*) = 0; virtual void Transform( std::shared_ptr<::monolith::io::proto::Example>, std::vector>*) = 0; }; std::unique_ptr NewTransformSummary( std::unique_ptr transform, bool print_summary = false); std::unique_ptr NewIdentity(); std::unique_ptr NewFilterByFid(FilterByFidConfig config); std::unique_ptr NewFilterByAction( FilterByActionConfig config); std::unique_ptr NewFilterByLabel( FilterByLabelConfig config); std::unique_ptr NewAddLabel(AddLabelConfig config); std::unique_ptr NewFilterByValue( FilterByValueConfig config); std::unique_ptr CombineTransforms( std::unique_ptr t1, std::unique_ptr t2); std::unique_ptr NewTransformFromBasicConfig( BasicTransformConfig config); std::unique_ptr NewTransformFromConfig( const TransformConfig& config); template void AssignOrCombine(T* t1, T t2, F combine_fn) { if (*t1 == nullptr) { *t1 = std::move(t2); return; } *t1 = combine_fn(std::move(*t1), 
std::move(t2)); } } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_DATA_TRANSFORM_CC_TRANSFORMS_H_ ================================================ FILE: monolith/native_training/data/transform/transform_config.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; package monolith.native_training.data; message FilterByFidConfig { repeated uint64 has_fids = 1; repeated uint64 filter_fids = 2; repeated uint64 select_fids = 3; } message FilterByActionConfig { repeated int32 has_actions = 1; } message AddLabelConfig { message TaskLabelConfig { repeated int32 pos_actions = 1; repeated int32 neg_actions = 2; optional float sample_rate = 3 [default = 1.0]; } repeated TaskLabelConfig task_label_configs = 1; optional float negative_value = 2 [default = 0.0]; optional float new_sample_rate = 3 [default = 1.0]; } message FilterByLabelConfig { repeated float thresholds = 1; } message FilterByValueConfig { required string field_name = 1; required string op = 2; repeated float float_operand = 3; repeated int64 int_operand = 4; repeated string string_operand = 5; optional bool keep_empty = 6 [default = false]; optional string operand_filepath = 7 [default = ""]; } message BasicTransformConfig { oneof type { FilterByFidConfig filter_by_fid = 1; FilterByActionConfig filter_by_action = 2; FilterByLabelConfig filter_by_label = 3; 
AddLabelConfig add_label = 4; FilterByValueConfig filter_by_value = 5; } } message LogicalOrTransformConfig { required BasicTransformConfig x = 1; required BasicTransformConfig y = 2; } message TransformConfig { message OneTransformConfig { oneof type { BasicTransformConfig basic_config = 1; LogicalOrTransformConfig logical_or_config = 2; } } repeated OneTransformConfig configs = 1; } ================================================ FILE: monolith/native_training/data/transform/transforms.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc from typing import List, Union from monolith.native_training.data.transform import transform_config_pb2 from idl.matrix.proto.line_id_pb2 import LineId class Transform(abc.ABC): @abc.abstractmethod def as_proto(self) -> transform_config_pb2.TransformConfig(): pass @abc.abstractmethod def _is_leaf_node(self) -> bool: pass class Compose(Transform): """Composes several transforms together. Args: transforms (list of ``Transform`` objects): list of transforms to compose. 
Example: >>> transforms.Compose([ >>> transforms.FilterByFid(has_fids=[1]), >>> transforms.FilterByLabel(thresholds=[-100]), >>> ]) """ def __init__(self, transforms: List[Transform]): assert all(isinstance(t, Transform) for t in transforms) self.transforms = transforms def as_proto(self) -> transform_config_pb2.TransformConfig(): config = transform_config_pb2.TransformConfig() for t in self.transforms: config.MergeFrom(t.as_proto()) return config def _is_leaf_node(self) -> bool: return False class FilterByFid(Transform): def __init__(self, has_fids: List[int] = None, filter_fids: List[int] = None, select_fids: List[int] = None): self.has_fids = has_fids self.filter_fids = filter_fids self.select_fids = select_fids def as_proto(self) -> transform_config_pb2.TransformConfig(): config = transform_config_pb2.TransformConfig() transform = config.configs.add() transform.basic_config.filter_by_fid.has_fids.extend(self.has_fids) transform.basic_config.filter_by_fid.filter_fids.extend(self.filter_fids) transform.basic_config.filter_by_fid.select_fids.extend(self.select_fids) return config def _is_leaf_node(self) -> bool: return True class FilterByAction(Transform): def __init__(self, has_actions: List[int] = None): self.has_actions = has_actions def as_proto(self) -> transform_config_pb2.TransformConfig(): config = transform_config_pb2.TransformConfig() transform = config.configs.add() transform.basic_config.filter_by_action.has_actions.extend(self.has_actions) return config def _is_leaf_node(self) -> bool: return True class FilterByLabel(Transform): def __init__(self, thresholds=List[float]): self.thresholds = thresholds def as_proto(self) -> transform_config_pb2.TransformConfig(): config = transform_config_pb2.TransformConfig() transform = config.configs.add() transform.basic_config.filter_by_label.thresholds.extend(self.thresholds) return config def _is_leaf_node(self) -> bool: return True class FilterByValue(Transform): def __init__( self, field_name: str, op: str, 
operand: Union[float, int, str, List[float], List[int], List[str]], keep_empty: bool = False, ): assert op in { 'gt', 'ge', 'eq', 'lt', 'le', 'neq', 'between', 'in', 'not-in', 'all', 'any', 'diff', 'startswith', 'endswith' } fields = LineId.DESCRIPTOR.fields_by_name assert field_name in fields assert operand is not None field = fields[field_name] string_operand = [] if field.has_options: assert op in {'all', 'any', 'diff'} assert field.cpp_type in { field.CPPTYPE_INT32, field.CPPTYPE_INT64, field.CPPTYPE_UINT32, field.CPPTYPE_UINT64 } if not isinstance(operand, (list, tuple)): assert isinstance(operand, int) int_operand, float_operand = [operand], [] else: assert all(isinstance(o, int) for o in operand) int_operand, float_operand = list(operand), [] elif field.cpp_type in {field.CPPTYPE_DOUBLE, field.CPPTYPE_FLOAT}: if op == 'between': assert all(isinstance(o, (int, float)) for o in operand) int_operand, float_operand = [], [float(o) for o in operand] else: int_operand, float_operand = [], [float(operand)] elif field.cpp_type in { field.CPPTYPE_INT32, field.CPPTYPE_INT64, field.CPPTYPE_UINT32, field.CPPTYPE_UINT64 }: if op in {'in', 'not-in', 'between'}: assert all(isinstance(o, int) for o in operand) int_operand, float_operand = list(operand), [] else: int_operand, float_operand = [int(operand)], [] elif field.cpp_type == field.CPPTYPE_STRING: int_operand, float_operand = [], [] if isinstance(operand, str): string_operand.append(operand) elif isinstance(operand, (list, tuple)): assert all(isinstance(o, str) for o in operand) string_operand.extend(operand) else: raise RuntimeError("params error!") else: raise RuntimeError("params error!") self.field_name = field_name self.op = op self.float_operand = float_operand self.int_operand = int_operand self.string_operand = string_operand self.keep_empty = keep_empty def as_proto(self) -> transform_config_pb2.TransformConfig(): config = transform_config_pb2.TransformConfig() transform = config.configs.add() 
transform.basic_config.filter_by_value.field_name = self.field_name transform.basic_config.filter_by_value.op = self.op transform.basic_config.filter_by_value.float_operand.extend( self.float_operand) transform.basic_config.filter_by_value.int_operand.extend(self.int_operand) transform.basic_config.filter_by_value.string_operand.extend( self.string_operand) transform.basic_config.filter_by_value.keep_empty = self.keep_empty return config def _is_leaf_node(self) -> bool: return True class AddLabel(Transform): def __init__(self, config: str, negative_value: float, new_sample_rate: float): self.config = config self.negative_value = negative_value self.new_sample_rate = new_sample_rate def as_proto(self) -> transform_config_pb2.TransformConfig(): config = transform_config_pb2.TransformConfig() transform = config.configs.add() transform.basic_config.add_label.negative_value = self.negative_value transform.basic_config.add_label.new_sample_rate = self.new_sample_rate for task in self.config.split(';'): # skip empty parts, e.g. 
config = '1,2:3:1.0;' if len(task) == 0: continue task_label_config = transform.basic_config.add_label.task_label_configs.add( ) pos_actions, neg_actions, sample_rate = task.split(':') pos_actions_list = [ int(pos) for pos in pos_actions.split(',') if len(pos) > 0 ] neg_actions_list = [ int(neg) for neg in neg_actions.split(',') if len(neg) > 0 ] task_label_config.pos_actions.extend(pos_actions_list) task_label_config.neg_actions.extend(neg_actions_list) task_label_config.sample_rate = float(sample_rate) return config def _is_leaf_node(self) -> bool: return True class LogicalOr(Transform): def __init__(self, x: Transform, y: Transform): self.x = x self.y = y assert x._is_leaf_node() and y._is_leaf_node() def as_proto(self) -> transform_config_pb2.TransformConfig(): config = transform_config_pb2.TransformConfig() transform = config.configs.add() transform.logical_or_config.x.CopyFrom( self.x.as_proto().configs[0].basic_config) transform.logical_or_config.y.CopyFrom( self.y.as_proto().configs[0].basic_config) return config def _is_leaf_node(self) -> bool: return False ================================================ FILE: monolith/native_training/data/transform/transforms_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import unittest from absl import app, logging from monolith.native_training.data.transform import transforms class TransformsTest(unittest.TestCase): def test_filter_by_fid(self): proto = transforms.FilterByFid(has_fids=[1], filter_fids=[2, 3], select_fids=None).as_proto() logging.info(proto) def test_filter_by_action(self): proto = transforms.FilterByAction(has_actions=[4]).as_proto() logging.info(proto) def test_filter_by_label(self): proto = transforms.FilterByLabel(thresholds=[-100, -100]).as_proto() logging.info(proto) def test_add_label(self): proto = transforms.AddLabel(config='1,2:3:1.0;4::0.5', negative_value=0.0, new_sample_rate=0.3).as_proto() logging.info(proto) def test_logical_or(self): proto = transforms.LogicalOr( x=transforms.FilterByAction(has_actions=[1, 2]), y=transforms.FilterByFid(has_fids=[10000000])).as_proto() logging.info(proto) def test_compose(self): transform = transforms.Compose([ transforms.FilterByFid(has_fids=[1], filter_fids=[2, 3], select_fids=None), transforms.FilterByLabel(thresholds=[-100, -100]), transforms.AddLabel(config='1,2:3:1.0;4::0.5', negative_value=0.0, new_sample_rate=0.3), transforms.LogicalOr(x=transforms.FilterByAction(has_actions=[1, 2]), y=transforms.FilterByFid(has_fids=[10000000])) ]) logging.info(transform.as_proto()) def main(_): logging.set_verbosity(logging.INFO) unittest.main() if __name__ == '__main__': app.run(main) ================================================ FILE: monolith/native_training/data/transform_dataset_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import time from absl import logging import os import uuid import struct import tensorflow as tf import tempfile from typing import List, BinaryIO from idl.matrix.proto import proto_parser_pb2, example_pb2 from monolith.native_training.data.datasets import PBDataset, PbType from monolith.native_training.data.parsers import (parse_instances, parse_examples) from monolith.native_training.data.transform import transforms fid_v1_mask = (1 << 54) - 1 fid_v2_mask = (1 << 48) - 1 def get_fid_v1(slot: int, signautre: int): return (slot << 54) | (signautre & fid_v1_mask) def get_fid_v2(slot: int, signature: int): return (slot << 48) | (signature & fid_v2_mask) def mock_instance_line_id(index: int, instance, actions: List[int]): instance.line_id.user_id = "test_{}".format(uuid.uuid4()) instance.line_id.uid = 100 instance.line_id.read_count = 0 if 20 <= index < 40 else 1 instance.line_id.video_play_time = 0.0 if 20 <= index < 30 else 1.0 instance.line_id.req_time = int(time.time()) instance.line_id.sample_rate = 0.5 instance.line_id.actions.extend(actions) def generate_instance_or_example(variant_type: str, index: int, labels: List[int], actions: List[int], fid_v1_list: List[int] = None): assert variant_type in {"instance", "example"} if variant_type == "instance": instance = proto_parser_pb2.Instance() instance.fid.extend(fid_v1_list if fid_v1_list else []) else: instance = example_pb2.Example() named_feature = instance.named_feature.add() named_feature.name = "fc_slot_1" named_feature.feature.fid_v1_list.value.extend(fid_v1_list) instance.label.extend(labels) 
mock_instance_line_id(index, instance, actions) return instance def write_instance_into_file(file: BinaryIO, instance): sort_id = str(instance.line_id.user_id) file.write(struct.pack(' flask.Flask: app = flask.Flask("Monolith_Debugging_Server") worker = DebuggingWorker(FLAGS.model_dir) @app.route("/debugging/variables", methods=["POST"]) def fetch_variables(): try: data = request.get_data() data = json.loads(data) logging.info("Fetch variables req: %s" % data) result = worker.fetch_variables(data.get("variable_names", [])) resp = {STATUS: SUCCESS, MSG: json.dumps(result)} except: resp = {STATUS: FAIL, MSG: traceback.format_exc()} logging.info("Fetch variables resp: %s" % resp) return resp @app.route("/debugging/features", methods=["POST"]) def fetch_features(): try: data = request.get_data() data = json.loads(data) logging.info("Fetch features req: %s" % data) feature_names = data.get("feature_names", []) feature_ids = data.get("feature_ids", []) if len(feature_names) != len(feature_ids): raise Exception( "Size of feature names [%s] and size of feature ids [%s] must be equal." % (len(feature_names), len(feature_ids))) result = worker.fetch_features(feature_names, feature_ids) resp = {STATUS: SUCCESS, MSG: json.dumps(result)} except: resp = {STATUS: FAIL, MSG: traceback.format_exc()} logging.info("Fetch features resp: %s" % resp) return resp return app def main(_): env_utils.setup_hdfs_env() server_app = create_app() server_app.run(host=FLAGS.host, port=FLAGS.port) if __name__ == "__main__": app.run(main) ================================================ FILE: monolith/native_training/demo.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """bazel run -c opt monolith/native_training:demo -- --num_ps=0""" from absl import app from absl import flags from absl import logging import tensorflow as tf from monolith.native_training import cpu_training from monolith.native_training.model import TestFFMModel from monolith.native_training.model_export.export_context import ExportMode FLAGS = flags.FLAGS flags.DEFINE_integer( "num_ps", default=0, help=( "Number of parameter servers. 0 means no parameter server. Everything " "runs on the single local server.")) flags.DEFINE_string( "model_dir", default=None, help="Directory where model parameters, graph, etc are saved.") def main(_): params = TestFFMModel.params() params.name = "test_ffm_model" params.train.per_replica_batch_size = 64 params.serving.export_when_saving = True params.serving.export_mode = ExportMode.DISTRIBUTED cpu_training.local_train(params, num_ps=FLAGS.num_ps, model_dir=FLAGS.model_dir, steps=100, save_checkpoints_steps=50) if __name__ == '__main__': logging.set_verbosity(logging.INFO) tf.compat.v1.disable_eager_execution() app.run(main) ================================================ FILE: monolith/native_training/dense_reload_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import re import tensorflow as tf from absl import logging, flags import numpy as np from collections import defaultdict from typing import Dict, List, Optional, Union, Set, Iterator, Tuple from tensorflow.python.training.saver import Saver from tensorflow.python.client.session import Session from tensorflow.python.training.py_checkpoint_reader import NewCheckpointReader, CheckpointReader from monolith.native_training.basic_restore_hook import CheckpointRestorerListener from monolith.native_training.model_export.export_context import is_exporting CUSTOM_RESTORE_OP = 'custom_restore_op' CustomRestoreListenerKey = 'CustomRestoreListener' # [TODO](fitz) this may not cover all situation PAT = re.compile(r"^.+/part_(\d+)(/.*)?$") DensePat = re.compile( r'''^(.*/)?(\w+(?:_\d+)?(?:\d+)?)/(bias|kernel|trainable_kernel_norm)(.*)''' ) FLAGS = flags.FLAGS # for those name cannot convert auto, ue use re-express to convert them _NameMapping = { re.compile('c_dot/mlp(.*)?$'): 'c_dot/compress_tower{}', re.compile('^dcn/kernel_(\d+)/trainable_norm(.*)?$'): 'kernel_{}_trainable_norm{}', re.compile('^dcn/kernel_(\d+)(.*)?$'): 'kernel_{}{}', re.compile('^dcn/bias_(\d+)/trainable_norm(.*)?$'): 'bias_{}_trainable_norm{}', re.compile('^dcn/bias_(\d+)(.*)?$'): 'bias_{}{}', } def add_mapping_rules(rules: Dict[str, str]): global _NameMapping _NameMapping.update({re.compile(pat): fmt for pat, fmt in rules.items()}) def node_name(name: str): name = name.strip().rstrip('/') if name.startswith('^'): name = name[1:] if ':' in name: frist, second = name.rsplit(':', 1) if 
second.isdigit(): name = frist return name def get_new_name(name: str): selected = [] name = node_name(name) for term in name.split('/'): if term not in selected: selected.append(term) return '/'.join(selected) def get_guess_name(name: str): for pat, fmt in _NameMapping.items(): matched = pat.match(name) if matched: print(pat, matched.groups()) guess_name = fmt.format(*matched.groups()) return guess_name return name def split_name(name: str) -> int: out = [] for i in range(-1, -len(name), -1): x = name[i] if x.isdigit(): out.append(x) else: break if out: out.reverse() return name[0:len(name) - len(out)], int(''.join(out)) else: return name, 0 def calc_reorder_info(names: List[str], is_ordered: bool = True) -> Tuple[bool, str]: assert names is not None and len(names) > 0, str(names) if not is_ordered: names = names.copy() names.sort(key=lambda x: split_name(x)[1]) _, start = split_name(names[0]) base_name, end = split_name(names[-1]) if start in {0, 1} and end - start == len(names) - 1 and len(names) == 1: return False, 'dense_' if base_name == 'dense' else base_name else: return True, 'dense_' if base_name == 'dense' else base_name def get_full_prefix(short_prefix: str, prefix_set: Set[str]) -> str: out = None for p in prefix_set: if p.endswith(short_prefix): if out is None: out = p elif len(out) < len(p): out = p return out or short_prefix def update_var_name_mapping_for_dense( var_name_mapping: Dict[str, str]) -> Dict[str, str]: dense_layers, prefixs = defaultdict(list), defaultdict(set) for name, origin in var_name_mapping.items(): matched = DensePat.match(name) if matched: prefix = matched.group(1).rstrip('/') if matched.group(1) else '' dense_name = matched.group(2) dense_local_var = matched.group(3) surfix = matched.group(4).lstrip('/') key = f'{dense_name}/{dense_local_var}' dense_layers[key].append( (prefix, dense_name, dense_local_var, surfix, origin)) if dense_local_var == 'bias': prefixs[key].add(prefix) dense_layers_refactor = defaultdict(list) for key, 
terms_list in dense_layers.items(): for terms in terms_list: prefix, dense_name, dense_local_var, surfix, origin = terms prefix_set = prefixs[f'{dense_name}/bias'] prefix = get_full_prefix(prefix, prefix_set) if prefix not in dense_layers_refactor: dense_layers_refactor[prefix] = defaultdict(list) dense_layers_refactor[prefix][dense_name].append( (prefix, dense_name, dense_local_var, surfix, origin)) for bias_prefix, layers_vars in dense_layers_refactor.items(): dense_names = list(layers_vars) if len(dense_names) > 1: dense_names.sort(key=lambda x: split_name(x)[1]) need_reorder, base = calc_reorder_info(dense_names) for i, dense_name in enumerate(dense_names): new_dense_name = f'{base}{i}' if need_reorder else dense_name for var_terms in layers_vars[dense_name]: prefix, local_var_name, origin = var_terms[0], var_terms[2], var_terms[ -1] if prefix == '' or bias_prefix.endswith(prefix): new_name = '/'.join( [bias_prefix, new_dense_name, local_var_name, var_terms[3]]).rstrip('/') if local_var_name == 'bias': var_name_mapping[new_name] = origin else: if new_name not in var_name_mapping: var_name_mapping[new_name] = origin # note: this may introduce problem, we deal with it at calc_feed_dict for layers_vars in dense_layers_refactor.values(): for var_terms_list in layers_vars.values(): for var_terms in var_terms_list: new_name = '/'.join(var_terms[0:-1]).rstrip('/') if new_name not in var_name_mapping: var_name_mapping[new_name] = var_terms[-1] class CustomRestoreListener(CheckpointRestorerListener): def __init__(self, alias_map: Dict[str, str] = None, clear_nn: bool = False, continue_training: bool = False, model_dir: str = None, enable_alias_map_auto_gen: bool = None): self._alias_map = alias_map self._clear_nn = clear_nn self._continue_training = continue_training self.model_dir = model_dir self.ckpt_name = None self.enable_alias_map_auto_gen = True if enable_alias_map_auto_gen is None else enable_alias_map_auto_gen def begin(self): 
logging.info('CustomRestoreListener begin ...') if is_exporting(): return checkpoint_state = None try: checkpoint_state = tf.train.get_checkpoint_state( checkpoint_dir=self.model_dir) self.ckpt_name = checkpoint_state.model_checkpoint_path except Exception as e: return if checkpoint_state is None: return graph: tf.Graph = tf.compat.v1.get_default_graph() variables = graph.get_collection('variables') if self._clear_nn: assert self._alias_map is None if self.model_dir: flag_file = os.path.join(self.model_dir, 'clear_nn') if tf.io.gfile.exists(flag_file): logging.info( f'the clear nn flag_file exists, skip clear, {flag_file}') return init_op = tf.compat.v1.global_variables_initializer() setattr(init_op, 'model_dir', self.model_dir) if self._continue_training: gs_var = tf.compat.v1.train.get_or_create_global_step(graph=graph) ph = tf.compat.v1.placeholder(dtype=gs_var.dtype, shape=gs_var.shape, name="global_step_ph") update_gs_op = gs_var.assign(value=ph) graph.add_to_collection(CUSTOM_RESTORE_OP, ([init_op, update_gs_op], [ph], None)) else: graph.add_to_collection(CUSTOM_RESTORE_OP, ([init_op], [None], None)) elif self._need_build_custom_init_graph(variables): assign_ops, placeholders = [], [] for variable in variables: # [TODO](fitz) usually after getting variables from collection, # the variable name is tensor like, with the surfix ':0', add test to ensure it var_name = node_name(variable.name) ph = tf.compat.v1.placeholder(dtype=variable.dtype, shape=variable.shape, name=var_name) # (fitz) since tf name scope mechanism, '_\d' may add as a suffix, # as a result we record the origin_name variable name ph.origin_name = var_name assign_op = variable.assign(value=ph) assign_ops.append(assign_op) placeholders.append(ph) init_op = tf.group(assign_ops) graph.add_to_collection(CUSTOM_RESTORE_OP, ([init_op], placeholders, self._alias_map)) else: logging.info("nothing to do in CustomRestoreListener") def _need_build_custom_init_graph(self, variables: List[tf.Variable]) -> 
bool: assert self._clear_nn == False # for compat, this may not cover all satuation if not self._alias_map and self.enable_alias_map_auto_gen: # 1) load variable name from ckpt ckpt: CheckpointReader = NewCheckpointReader(self.ckpt_name) all_old_var_names = set(ckpt.get_variable_to_dtype_map().keys()) # 2) check if need alias reload, if we can find all variables in ckpt, no alias reload required cnt = 0 pat = re.compile(r"/part_\d+") for variable in variables: expected_saved_varibale_name = node_name(''.join( pat.split(variable.name))) if expected_saved_varibale_name in all_old_var_names: cnt += 1 if len(variables) == cnt: logging.info("The ckpt is compatable, no need alias reload") return False logging.info( "The ckpt is incompatable, begin to generate alias reload automatical ..." ) logging.info(f'all_old_var_names = {all_old_var_names}') # 3) try to convert old variable name to new one var_name_mapping = {} for name in all_old_var_names: var_name_mapping[get_new_name(name)] = name logging.info(f'var_name_mapping = {var_name_mapping}') update_var_name_mapping_for_dense(var_name_mapping) logging.info(f'var_name_mapping after update = {var_name_mapping}') # 4) generate alias_map alias_map = {} miss_dense_names = defaultdict(list) miss_dense_map = {} for variable in variables: expected_saved_varibale_name = node_name(''.join( pat.split(variable.name))) var_name = node_name(variable.name) if var_name_mapping.get(expected_saved_varibale_name) == None: # record needed info to deal with None value dense variables matched = DensePat.match(expected_saved_varibale_name) if matched: prefix = matched.group(1).rstrip('/') if matched.group(1) else '' dense_name = matched.group(2) dense_local_var = matched.group(3) surfix = matched.group(4).lstrip('/') key = f'{prefix}/{dense_local_var}/{surfix}' miss_dense_names[key].append(dense_name) alias_map[var_name] = var_name_mapping.get(expected_saved_varibale_name) logging.info(f'miss_dense_names : {miss_dense_names}') for k, v in 
miss_dense_names.items(): if len(v) >= 1: v.sort() sub_value = int(v[0].split('_')[1]) insert_pos = k[:k.rfind('/')].rfind('/') if k[-1] == '/': k = k.rstrip('/') for i, name in enumerate(v): old_name = k[:insert_pos + 1] + v[i] + '/' + k[insert_pos + 1:] new_name = k[:insert_pos + 1] + v[i].split('_')[0] + '_' + str( int(v[i].split('_')[1]) - sub_value) + '/' + k[insert_pos + 1:] miss_dense_map[old_name] = new_name logging.info(f'miss_dense_map : {miss_dense_map}') # 5) check whether alias_map is validate none_values = {name for name, value in alias_map.items() if value is None} logging.info(f'none_values = {none_values}') for name in none_values: expected_saved_varibale_name = node_name(''.join(pat.split(name))) guess_name = get_guess_name(expected_saved_varibale_name) if guess_name == expected_saved_varibale_name: guess_name = miss_dense_map.get(name) if guess_name in var_name_mapping: alias_map[name] = var_name_mapping[guess_name] else: logging.info(f'the guess_name = {guess_name} with name = {name}') logging.warning( 'The ckpt is incompatable, but cannot alias reload automatical, pls spectify an alias_map' ) logging.info(f'alias_map = {alias_map}') return False # logging.info(f"The ckpt is incompatable, begin to generate alias reload automatical done!") logging.info( f"The ckpt is incompatable, begin to generate alias reload automatical done! 
def infer_variable_name(names: List[str]) -> Set[str]:
  """Drop every ``/part_<digits>`` segment from each name.

  Names that differ only in their ``/part_N`` segments collapse to a single
  entry, e.g. ``'w/part_0:0'`` and ``'w/part_1:0'`` both become ``'w:0'``.

  Args:
    names: variable (or tensor) names, possibly containing ``/part_N``.

  Returns:
    The set of distinct names with all ``/part_N`` segments removed. Names
    without such a segment are returned unchanged.
  """
  part_pat = re.compile(r'/part_\d+')
  return {part_pat.sub('', name) for name in names}
matched.group(2) new_groups[key].append(new_name) if len(new_groups) > 1: denses = sorted(new_groups, key=lambda x: split_name(x)[1]) new_name_list = new_groups[denses[0]] if len(new_name_list) == 1: new_name = new_name_list[0] result[ph_dict[new_name]] = old_tensor continue # sort the partitioned sub variable by partition index # .+/part_xx/.*, we extract the last part_xx, and sort accrodingly new_name_list = sorted(new_name_list, key=lambda x: int(PAT.match(x).group(1))) # get the first dim for placeholder as splits splits = [ph_dict[name].shape[0] for name in new_name_list] # construct indices_or_sections for numpy.split function indices_or_sections = [0] * (len(splits) - 1) for i, val in enumerate(splits): if i == 0: indices_or_sections[i] = val elif i == len(splits) - 1: break else: indices_or_sections[i] = indices_or_sections[i - 1] + val # split old_tensor into partition, sub_tensors = np.split(old_tensor, indices_or_sections, axis=0) for name, tensor in zip(new_name_list, sub_tensors): result[ph_dict[name]] = tensor return result ================================================ FILE: monolith/native_training/dense_reload_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os from typing import List import tensorflow as tf from tensorflow.keras.initializers import Ones, GlorotNormal from monolith.native_training.dense_reload_utils import infer_variable_name, calc_feed_dict, \ CustomRestoreListener from tensorflow.python.training.py_checkpoint_reader import NewCheckpointReader, CheckpointReader class DenseReloadUtilsTest(tf.test.TestCase): @classmethod def setUpClass(cls): with tf.Graph().as_default(): global_step = tf.compat.v1.train.get_or_create_global_step() partitioner = tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=100) partition_var = tf.compat.v1.get_variable(name='partition', shape=(1280, 512), dtype=tf.float32, partitioner=partitioner, initializer=GlorotNormal()) var = tf.compat.v1.get_variable(name='small_var', shape=(10, 5), dtype=tf.float32, initializer=Ones()) # initialize all of the variables init = tf.compat.v1.global_variables_initializer() saver = tf.compat.v1.train.Saver([partition_var, var, global_step]) with tf.compat.v1.Session() as sses: sses.run(init) saver.save(sses, save_path=f"{os.getcwd()}/ckpt/test", global_step=global_step) @classmethod def tearDownClass(cls): if tf.io.gfile.exists(path=f"{os.getcwd()}/ckpt"): tf.io.gfile.rmtree(path=f"{os.getcwd()}/ckpt") def test_infer_variable_name(self): with tf.Graph().as_default(): global_step = tf.compat.v1.train.get_or_create_global_step() partitioner = tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=100) partition_var = tf.compat.v1.get_variable(name='partition', shape=(1280, 512), dtype=tf.float32, partitioner=partitioner, initializer=GlorotNormal()) var = tf.compat.v1.get_variable(name='small_var', shape=(10, 5), dtype=tf.float32, initializer=Ones()) names = [part.name for part in partition_var._get_variable_list()] self.assertEqual(var.name, 'small_var:0') self.assertSetEqual(infer_variable_name(names), {f'{partition_var.name}:0'}) def test_calc_feed_dict(self): with 
tf.Graph().as_default(): global_step = tf.compat.v1.train.get_or_create_global_step() partitioner = tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=100) partition_var = tf.compat.v1.get_variable(name='partition2', shape=(1280, 512), dtype=tf.float32, partitioner=partitioner, initializer=GlorotNormal()) var = tf.compat.v1.get_variable(name='small_var2', shape=(10, 5), dtype=tf.float32, initializer=Ones()) alias_map = {'small_var2': 'small_var', 'global_step': 'global_step'} var_ph = tf.compat.v1.placeholder(dtype=tf.float32, shape=(10, 5)) var_ph.origin_name = 'small_var2' gs_ph = tf.compat.v1.placeholder(dtype=tf.int64) gs_ph.origin_name = 'global_step' placeholders = [var_ph, gs_ph] for part in partition_var._get_variable_list(): if part.name.endswith(':0'): var_name = part.name[0:-2] else: var_name = part.name alias_map[var_name] = 'partition' ph = tf.compat.v1.placeholder(dtype=part.dtype, shape=part.shape) ph.origin_name = var_name placeholders.append(ph) ckpt: CheckpointReader = NewCheckpointReader(f"{os.getcwd()}/ckpt/test-0") ph_dict = calc_feed_dict(ckpt, alias_map=alias_map, placeholders=placeholders) self.assertEqual(len(ph_dict), len(alias_map)) for part in partition_var._get_variable_list(): if part.name.endswith(':0'): var_name = part.name[0:-2] else: var_name = part.name vph = None for ph in ph_dict: if ph.origin_name == var_name: vph = ph break assert vph is not None self.assertEqual(part.shape, vph.shape) self.assertEqual(part.shape, ph_dict[vph].shape) def test_alias_map_listener(self): with tf.Graph().as_default(): global_step = tf.compat.v1.train.get_or_create_global_step() partitioner = tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=100) partition_var = tf.compat.v1.get_variable(name='partition2', shape=(1280, 512), dtype=tf.float32, partitioner=partitioner, initializer=GlorotNormal()) var = tf.compat.v1.get_variable(name='small_var2', shape=(10, 5), dtype=tf.float32, 
initializer=Ones()) alias_map = {'small_var2': 'small_var', 'global_step': 'global_step'} var_ph = tf.compat.v1.placeholder(dtype=tf.float32, shape=(10, 5)) var_ph.origin_name = 'small_var2' gs_ph = tf.compat.v1.placeholder(dtype=tf.int64) gs_ph.origin_name = 'global_step' placeholders = [var_ph, gs_ph] for part in partition_var._get_variable_list(): if part.name.endswith(':0'): var_name = part.name[0:-2] else: var_name = part.name alias_map[var_name] = 'partition' ph = tf.compat.v1.placeholder(dtype=part.dtype, shape=part.shape) ph.origin_name = var_name placeholders.append(ph) listener = CustomRestoreListener(alias_map=alias_map, model_dir=f"{os.getcwd()}/ckpt") listener.begin() def test_clear_nn_listener(self): with tf.Graph().as_default(): global_step = tf.compat.v1.train.get_or_create_global_step() partitioner = tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=100) partition_var = tf.compat.v1.get_variable(name='partition2', shape=(1280, 512), dtype=tf.float32, partitioner=partitioner, initializer=GlorotNormal()) var = tf.compat.v1.get_variable(name='small_var2', shape=(10, 5), dtype=tf.float32, initializer=Ones()) listener = CustomRestoreListener(clear_nn=True, model_dir=f"{os.getcwd()}/ckpt") listener.begin() if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/device_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def get_visible_gpus(local_rank, processes_per_gpu=1):
  """Visible GPU devices string for session config.

  Args:
    local_rank: the process local rank, for example `hvd.local_rank()`.
      Expected to be non-negative.
    processes_per_gpu: the integer number of processes per gpu (>= 1).

  Returns:
    String compatible with session_config.gpu_options.visible_device_list,
    for example, "2" indicates TensorFlow session will map the physical gpu:2
    into the TensorFlow virtual string-specified device "GPU:0".
  """
  # TODO: processes_per_gpu :float < 0 allows str of gpus.
  assert isinstance(processes_per_gpu, int) and processes_per_gpu >= 1
  # Floor division keeps the computation in integers instead of going through
  # a float quotient.
  return str(int(local_rank // processes_per_gpu))
d = tf.DeviceSpec.from_string(device_name) if (d.device_type == "GPU" and not _GPU_PLACEMENT_ALLOWED) or not d.device_type: # If GPU is illegally assigned, or, device type is empty, # Merge with the _default_device while keep the assigned job,task,replica return d.make_merged_spec(_default_device).to_string() # Don't override the assigned and allowed device string return device_name def skip_device(op: tf.Operation) -> bool: # Enforce commonly-used summary op on CPU. return op.type.startswith("Write") or op.type.endswith("Summary") or \ (op.type == "Const" and op.get_attr("dtype") == tf.string) def default_device_fn(op: tf.Operation): """Default device_fn for Estimator RunConfig.""" return _default_device.to_string() if skip_device(op) else _device_rule(op.device) @contextlib.contextmanager def maybe_device_if_allowed(device_name): """ Monolith disallows soft device placement for training. This is an insurance when default_device_fn is missed/not-enforced in Estimator Runconfig. """ dev = _device_rule(device_name) with tf.device(dev): yield class _FakeNodeDef(object): """A fake NodeDef for _FakeOperation.""" __slots__ = ["op", "name"] def __init__(self): self.op = "" self.name = "" class _FakeOp(object): """A helper class to determine the current device. Supports only the type and device set/get methods needed to run the graph's _apply_device_function method. 
""" def __init__(self): self._device = "" self.type = "FakeOpPyObj" self.name = "" self.node_def = _FakeNodeDef() @property def device(self): return self._device def _set_device(self, device): self._device = ops._device_string(device) # pylint: disable=protected-access def _set_device_from_string(self, device_str): self._device = device_str def within_placement_context_of(device_name): """Check if the current placement context is .""" fake_op = _FakeOp() ops.get_default_graph()._apply_device_functions(fake_op) return tf.DeviceSpec.from_string(fake_op.device).device_type == device_name.upper() def get_device_fn(cluster=None, task=None) -> Callable: is_mpi_mode = True if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ else False is_ps_mode = True if FLAGS.num_ps > 0 else False device = 'GPU' if FLAGS.enable_gpu_training or _GPU_PLACEMENT_ALLOWED else 'CPU' if is_mpi_mode and is_ps_mode and FLAGS.enable_sync_training: rank = get_mpi_rank() job = 'chief' if rank == 0 else 'worker' task = rank if rank == 0 else rank - 1 device_spec = tf.DeviceSpec.from_string(f'/job:{job}/replica:0/task:{task}') else: device_spec = tf.DeviceSpec.from_string(f'/device:{device}:0') def _device_fn(op: tf.Operation) -> str: if skip_device(op): return device_spec.make_merged_spec(_default_device).to_string() if op.device: cur_dev = tf.DeviceSpec.from_string(op.device) return device_spec.make_merged_spec(cur_dev).to_string() else: try: op.get_attr('_class') return op.device except: return device_spec.to_string() if FLAGS.enable_sync_training: assert is_mpi_mode, 'sync training must running under mpi mode' if is_ps_mode: return _device_fn else: return default_device_fn else: if FLAGS.is_local or cluster is None or task is None: return None if task['type']: worker_device = f"/job:{task['type']}/task:{task['index']}" else: worker_device = '/job:worker' return tf.compat.v1.train.replica_device_setter( ps_tasks=FLAGS.num_ps, worker_device=worker_device, merge_devices=True, 
def model_device_fn(op: tf.Operation) -> str:
  """Device function that targets ``GPU:0`` or ``CPU:0`` for model ops.

  Args:
    op: the operation being placed.

  Returns:
    * the default CPU device string for ops matched by ``skip_device``;
    * otherwise ``/device:GPU:0`` when ``FLAGS.enable_gpu_training`` or the
      module-level GPU switch is on (``/device:CPU:0`` when not), merged with
      any device already set on ``op``;
    * ``op.device`` unchanged (empty at that point) when the op has no device
      but does carry a ``_class`` attr.
  """
  if skip_device(op):
    return _default_device.to_string()

  device = 'GPU' if FLAGS.enable_gpu_training or _GPU_PLACEMENT_ALLOWED else 'CPU'
  device_spec = tf.DeviceSpec.from_string(f'/device:{device}:0')

  if op.device:
    cur_dev = tf.DeviceSpec.from_string(op.device)
    return device_spec.make_merged_spec(cur_dev).to_string()

  try:
    # Only used as a probe for the '_class' attr; the value is not needed.
    op.get_attr('_class')
    return op.device
  except Exception:
    # Deliberately broader than ValueError: op-like objects without a
    # get_attr method (such as _FakeOp) must fall back to the target device
    # too. Unlike a bare except, this lets KeyboardInterrupt and SystemExit
    # propagate.
    return device_spec.to_string()
from absl import logging import tensorflow as tf from monolith.native_training import device_utils class DeviceUtilsTest(tf.test.TestCase): def test_basic(self): with tf.Graph().as_default() as g, g.device(device_utils.default_device_fn): a = tf.constant(1) self.assertEqual(a.device, "/device:CPU:0") def test_cpu_only(self): device_utils.disable_gpu_training() with tf.Graph().as_default() as g, g.device(device_utils.default_device_fn): with tf.device("/device:GPU:0"): a = tf.constant(1) self.assertEqual(a.device, "/device:CPU:0") def test_str_context(self): device_utils.enable_gpu_training() with tf.Graph().as_default() as g, g.device(device_utils.default_device_fn): a = tf.constant(1) with tf.device("GPU:0"): b = tf.constant(1) c = tf.constant(1) self.assertEqual(a.device, "/device:CPU:0") self.assertEqual(b.device, "/device:GPU:0") self.assertEqual(c.device, "/device:CPU:0") def test_str_nested_contexts(self): device_utils.enable_gpu_training() with tf.Graph().as_default() as g, g.device(device_utils.default_device_fn): a = tf.constant(1) with tf.device("CPU:0"): b = tf.constant(1) with tf.device("GPU:0"): c = tf.constant(1) with tf.device("GPU:1"): d = tf.constant(1) self.assertEqual(a.device, "/device:CPU:0") self.assertEqual(b.device, "/device:CPU:0") self.assertEqual(c.device, "/device:GPU:0") self.assertEqual(d.device, "/device:GPU:1") def test_cpu_device_merge(self): # For example, in async training case, we have device job and task string. 
device_utils.disable_gpu_training() with tf.Graph().as_default() as g, g.device(device_utils.default_device_fn): with tf.device("/job:my_ps/task:0"): a = tf.constant(1) with tf.device("GPU:0"): assert not device_utils.within_placement_context_of("GPU") assert device_utils.within_placement_context_of("CPU") b = tf.constant(1) self.assertEqual(a.device, "/job:my_ps/task:0/device:CPU:0") self.assertEqual(b.device, "/job:my_ps/task:0/device:CPU:0") def test_gpu_device_merge(self): device_utils.enable_gpu_training() with tf.Graph().as_default() as g, g.device(device_utils.default_device_fn): with tf.device("/job:worker/task:0"): with tf.device("/job:ps"): a = tf.constant(1) with tf.device("GPU:0"): b = tf.constant(1) with device_utils.maybe_device_if_allowed("GPU:1"): assert device_utils.within_placement_context_of("GPU") assert not device_utils.within_placement_context_of("CPU") c = tf.constant(1) self.assertEqual(a.device, "/job:ps/task:0/device:CPU:0") self.assertEqual(b.device, "/job:worker/task:0/device:GPU:0") self.assertEqual(c.device, "/job:worker/task:0/device:GPU:1") def test_process_gpu_map(self): self.assertEqual( device_utils.get_visible_gpus(local_rank=2, processes_per_gpu=1), "2") self.assertEqual( device_utils.get_visible_gpus(local_rank=1, processes_per_gpu=2), "0") self.assertEqual( device_utils.get_visible_gpus(local_rank=2, processes_per_gpu=2), "1") self.assertEqual( device_utils.get_visible_gpus(local_rank=3, processes_per_gpu=2), "1") if __name__ == "__main__": logging.set_verbosity(logging.INFO) tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/distribute/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_library", "py_test") package( default_visibility = [ "//monolith/native_training:__subpackages__", "//monolith/native_training/data/training_instance:__subpackages__", ], ) py_library( name = "str_queue", srcs = 
["str_queue.py"], ) py_test( name = "str_queue_test", srcs = ["str_queue_test.py"], deps = [":str_queue"], ) py_library( name = "distributed_dataset", srcs = ["distributed_dataset.py"], deps = [ ":str_queue", "//monolith/native_training:native_task_context", "//monolith/native_training:utils", "//monolith/native_training/hooks:session_hooks", ], ) py_test( name = "distributed_dataset_test", srcs = ["distributed_dataset_test.py"], deps = [ ":distributed_dataset", ], ) ================================================ FILE: monolith/native_training/distribute/distributed_dataset.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def create_dynamic_sharding_dataset(
    glob_patterns: List[str],
    name="dynamic_sharding_dataset") -> tf.data.Dataset:
  """Builds a dataset of filenames that are handed out one at a time.

  Two `StrQueue`s are created:
    * glob_patterns_queue - holds the glob patterns (e.g. /some/path/*).
    * filenames_queue - holds concrete filenames (e.g. /some/path/data0).

  Two queues are used because the full list of filenames may be too long to
  fit into memory, so patterns are expanded on demand: whenever
  filenames_queue runs empty it pops one pattern and enqueues its matches.

  Args:
    glob_patterns: glob patterns used to seed glob_patterns_queue.
    name: name scope under which the ops are created.

  Returns:
    A dataset whose elements are scalar string tensors (one filename each).
    Each element is produced by running the dequeue op through
    `session_hooks.get_current_session()`.
  """
  with tf.name_scope(name):
    device = utils.ps_device(0) if native_task_context.get().num_ps > 0 else ""
    # Queues on ps 0 or host if no ps.
    with tf.device(device):
      pattern_queue = str_queue.StrQueue(initial_elements=glob_patterns,
                                         name="glob_patterns_queue")

      @tf.function
      def glob_pattern():
        # We are in critical section already (filenames_queue shares
        # pattern_queue's critical section), so the raw, unguarded dequeue is
        # used here.
        pattern, out_of_range = pattern_queue._raw_dequeue()
        if not out_of_range:
          filenames = tf.io.matching_files(pattern)
        else:
          # Placeholder value returned together with out_of_range == True.
          filenames = tf.constant([""])
        return filenames, out_of_range

      filenames_queue = str_queue.StrQueue(
          critical_section=pattern_queue.critical_section,
          auto_enqueue_fn=glob_pattern,
          name="filenames_queue")
      dequeued_filename = filenames_queue.dequeue()

    def filename_generator():
      # Runs inside tf.py_function: pull the next filename through the
      # session registered in session_hooks.
      filename_bytes, out_of_range = session_hooks.get_current_session().run(
          dequeued_filename)
      if out_of_range:
        # Ends the otherwise endless dataset once both queues are drained.
        raise StopIteration()
      return filename_bytes.decode()

    # Infinite source; its elements are ignored by the map function below.
    dummy_dataset = tf.data.Dataset.from_tensors(0).repeat()
    # Instead of map, we directly instantiate the MapDataset
    # because we don't want to keep preserve_cardinality.
    filename_dataset = dataset_ops.MapDataset(
        dummy_dataset,
        lambda _: tf.py_function(
            func=filename_generator, inp=[], Tout=tf.string),
        preserve_cardinality=False)
    return filename_dataset
e_1.txt In each file, it will be some like (this is coming from a_0.txt) a.0.0 a.0.1 """ for c in range(97, 102): for i in range(2): with tf.io.gfile.GFile( os.path.join(files_dir, '{}_{}.txt'.format(chr(c), i)), 'w+') as f: f.write('\n'.join(['{}.{}.{}'.format(chr(c), i, j) for j in range(2)])) class DynamicShardingDatasetTest(tf.test.TestCase): def setUp(self): super().setUp() self.test_dir = os.environ["TEST_TMPDIR"] self.data_dir = os.path.join(self.test_dir, 'test_data') if not tf.io.gfile.exists(self.data_dir): tf.io.gfile.makedirs(self.data_dir) gen_test_files(self.data_dir) self.glob_patterns = [ os.path.join(self.data_dir, basename) for basename in ['a_*.txt', 'b_*.txt', 'c_*.txt', 'd_*.txt', 'e_*.txt'] ] def get_test_session(self): return tf.compat.v1.train.SingularMonitoredSession( hooks=[session_hooks.SetCurrentSessionHook()]) def testBasic(self): ds = distributed_dataset.create_dynamic_sharding_dataset(self.glob_patterns) it = tf.compat.v1.data.make_one_shot_iterator(ds) element = it.get_next() with self.get_test_session() as sess: names = [] for i in range(10): names.append(sess.run(element)) expected = [] for i in range(97, 102): for j in range(2): expected.append( os.path.join(self.data_dir, "{}_{}.txt".format(chr(i), j))) self.assertAllEqual(names, expected) def testEof(self): ds = distributed_dataset.create_dynamic_sharding_dataset([]) it = tf.compat.v1.data.make_one_shot_iterator(ds) v = tf.Variable(0) element = it.get_next() with tf.control_dependencies([element]): add_op = v.assign_add(1) with self.get_test_session() as sess: with self.assertRaises(tf.errors.OutOfRangeError): sess.run(add_op) # Make sure v is not changed self.assertAllEqual(sess.run(v), 0) def testWithOtherDataset(self): filename_dataset = distributed_dataset.create_dynamic_sharding_dataset( self.glob_patterns) dataset = filename_dataset.flat_map(tf.data.TextLineDataset) it = tf.compat.v1.data.make_one_shot_iterator(dataset) element = it.get_next() with self.get_test_session() 
class StrQueue:
  """A FIFO queue of strings whose state lives in `tf.Variable`s.

  The queue is backed by a fixed-size string variable (`Queue`) plus two
  integer variables: `Offset` (index of the next element to dequeue) and
  `Size` (index one past the last valid element). Public operations are
  executed through a `tf.CriticalSection`.

  When the queue is exhausted, `dequeue` does not raise; it returns an empty
  string together with an out-of-range flag set to True.
  """

  def __init__(self,
               initial_elements=None,
               critical_section=None,
               auto_enqueue_fn=None,
               capacity=100000,
               name="StrQueue"):
    """Creates the queue variables and the op that seeds the queue.

    Args:
      initial_elements - strings enqueued when the queue is initialized.
        None is treated as an empty list.
      critical_section - if not None, queue will use this as a critical
        section instead of creating new one.
      auto_enqueue_fn - when queue is empty, we will use this enqueue op to
        fill the queue. Should be a callable returns 2 tensors: 1-D string
        tensor represents strings to be enqueued and 0-D bool tensor to
        indicate if it is out of range.
      capacity - length of the backing string variable.
      name - name scope for the created ops; also used in the capacity
        assertion message.
    """
    with tf.name_scope(name) as scope:
      self._name = name
      self._auto_enqueue_fn = auto_enqueue_fn
      self._capacity = capacity
      self._cs = critical_section or tf.CriticalSection(
          name="CriticalSection", shared_name=scope + "/CriticalSection")
      self._arr = tf.Variable(initial_value=tf.constant_initializer("")(
          shape=[self._capacity],
          dtype=tf.string,
      ),
                              trainable=False,
                              name="Queue")
      self._offset = tf.Variable(0, trainable=False, name="Offset")
      self._arr_size = tf.Variable(0, trainable=False, name="Size")
      if initial_elements is None:
        initial_elements = []
      # Here we use a dummy var to init queue: its initial value depends on
      # the three variable initializers and then on the enqueue of
      # `initial_elements`, so initializing this variable seeds the queue.
      with tf.control_dependencies([
          self._arr.initializer,
          self._offset.initializer,
          self._arr_size.initializer,
      ]):
        with tf.control_dependencies([self.enqueue_many(initial_elements)]):
          var_for_init_value = tf.constant(0)
      self._var_for_init = tf.Variable(initial_value=var_for_init_value,
                                       trainable=False,
                                       name="VarForInit")

  @property
  def critical_section(self):
    """The critical section guarding this queue (may be shared)."""
    return self._cs

  def enqueue_many(self, elements: tf.Tensor, name=None):
    """Appends `elements` to the queue inside the critical section.

    Args:
      elements - value convertible to a string tensor.
      name - optional name passed to the critical section execution.

    Returns:
      Whatever `CriticalSection.execute` returns for the enqueue function.
    """
    elements = tf.convert_to_tensor(elements, tf.string)
    return self._cs.execute(lambda: self._raw_enqueue_many(elements), name=name)

  def dequeue(self, name=None):
    """Dequeues an element inside the critical section.

    Returns 2 elements: element & a bool indicating if we're out of range.
    When out of range is True the element is an empty string."""
    return self._cs.execute(self._raw_dequeue, name=name)

  @tf.function
  def _raw_enqueue_many(self, elements: tf.Tensor):
    """Enqueues without taking the critical section; callers must hold it."""
    size = tf.size(elements)
    # Number of elements still pending (not yet dequeued).
    old_arr_size = self._arr_size - self._offset
    new_arr_size = old_arr_size + size
    tf.debugging.Assert(new_arr_size <= self._capacity, [
        self._name, " excceeds capacity ", new_arr_size, " v.s. ",
        self._capacity
    ])
    # Compact: move the pending elements to the front of the array, then
    # write the new elements right after them and reset the read offset.
    self._arr[0:old_arr_size].assign(self._arr[self._offset:self._arr_size])
    self._arr[old_arr_size:old_arr_size + size].assign(elements)
    self._offset.assign(0)
    self._arr_size.assign(new_arr_size)

  @tf.function
  def _raw_dequeue(self):
    """Dequeues without taking the critical section; callers must hold it.

    Returns:
      (element, out_of_range): ("", True) when nothing is left, otherwise the
      next element and False.
    """
    tf.debugging.Assert(self._offset <= self._arr_size, [
        "Offset should always be less than or equal to arr_size.",
        "This may indicate an internal error. offset: ", self._offset,
        " arr_size: ", self._arr_size
    ])
    if self._auto_enqueue_fn is not None:
      # Keep refilling while the queue is empty; an auto-enqueue call may
      # yield zero elements, so loop until something arrives or the source
      # reports it is out of range.
      while tf.math.equal(self._offset, self._arr_size):
        elements, out_of_range = self._auto_enqueue_fn()
        elements = tf.convert_to_tensor(elements)
        if out_of_range:
          break
        self._raw_enqueue_many(elements)
    if tf.math.equal(self._offset, self._arr_size):
      return "", True
    else:
      single_element = self._arr[self._offset]
      self._offset.assign_add(1)
      return single_element, False
class QueueTest(tf.test.TestCase):
  """Graph-mode tests for `str_queue.StrQueue`."""

  def testBasic(self):
    """Elements come back in the order they were enqueued."""
    q = str_queue.StrQueue()
    enqueue_op = q.enqueue_many(["test1", "test2"])
    dequeue = q.dequeue()
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(enqueue_op)
    self.assertEqual(self.evaluate(dequeue)[0].decode(), "test1")
    self.assertEqual(self.evaluate(dequeue)[0].decode(), "test2")

  def testInit(self):
    """`initial_elements` are available right after variable init."""
    q = str_queue.StrQueue(initial_elements=["test1"])
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertEqual(self.evaluate(q.dequeue())[0].decode(), "test1")

  def testOutOfRange(self):
    """Dequeuing from an empty queue sets the out-of-range flag."""
    q = str_queue.StrQueue()
    dequeue = q.dequeue()
    self.evaluate(tf.compat.v1.global_variables_initializer())
    _, out_of_range = self.evaluate(dequeue)
    self.assertEqual(out_of_range, True)

  def testAutoEnqueue(self):
    """The queue refills itself through `auto_enqueue_fn` until exhausted."""
    v = tf.Variable([0])
    self.evaluate(tf.compat.v1.global_variables_initializer())

    # The unused dataset / one-shot iterator that used to be built here was
    # dead code and has been dropped.

    @tf.function
    def auto_enqueue():
      # Produces "1", then "2", then reports out of range.
      new_v = v.assign_add([1])
      if new_v > 2:
        return tf.constant([""]), True
      return tf.as_string(v), False

    q = str_queue.StrQueue(auto_enqueue_fn=auto_enqueue)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertEqual(self.evaluate(q.dequeue())[0].decode(), "1")
    self.assertEqual(self.evaluate(q.dequeue())[0].decode(), "2")
    self.assertEqual(self.evaluate(q.dequeue())[1], True)
    # Simulating in the distributed training, multiple dequeues will be called.
    self.assertEqual(self.evaluate(q.dequeue())[1], True)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import re import hashlib import collections import contextlib import itertools from contextlib import nullcontext import copy import os import sys from collections import defaultdict, namedtuple from typing import Callable, DefaultDict, Dict, Iterable, List, Tuple, Optional, NewType from absl import flags from absl import logging import tensorflow as tf from tensorflow.python.framework.ops import Tensor from tensorflow.python.types.core import Value from monolith.native_training import distribution_ops from monolith.native_training import distributed_serving_ops from monolith.native_training import hash_table_ops from monolith.native_training import logging_ops from monolith.native_training import multi_type_hash_table from monolith.native_training import multi_hash_table_ops from monolith.native_training import native_task_context from monolith.native_training import tensor_utils from monolith.native_training import utils from monolith.native_training import entry from monolith.native_training.hash_table_utils import infer_dim_size from monolith.native_training.model_export import export_context from monolith.native_training.data.parsers import sharding_sparse_fids from idl.matrix.proto.example_pb2 import FeatureConfigs, FeatureConfig, PoolingType, OutType, OutConfig import monolith.native_training.embedding_combiners as embedding_combiners from monolith.native_training.data.parsers import 
class DistributedHashTable(hash_table_ops.BaseHashTable):
  """A hash table sharded over `ps_num` sub-tables.

  An id is routed to the sub-table whose index equals `id mod ps_num`. Update
  methods do not mutate this object; they return a shallow copy that points
  at the updated sub-tables.
  """

  def __init__(
      self, ps_num, config: entry.HashTableConfigInstance,
      hash_table_factory: Callable[[int, entry.HashTableConfigInstance],
                                   hash_table_ops.BaseHashTable]):
    """Creates one sub-table per shard.

    Args:
      ps_num: number of shards.
      config: table config; its learning rate tensor is overwritten for each
        shard before the factory is called.
      hash_table_factory: called as `hash_table_factory(shard, config)`.
    """
    self._ps_num = ps_num
    self._hash_tables = []
    # Build learning rate tensor on worker side
    worker_learning_rate = config.call_learning_rate_fns()
    for shard in range(self._ps_num):
      placement = (nullcontext() if export_context.is_exporting_standalone()
                   else ps_device(shard))
      with placement:
        # Send learning rate tensor to ps
        config.set_learning_rate_tensor(tf.identity(worker_learning_rate))
        self._hash_tables.append(hash_table_factory(shard, config))
    self._input_lookup_tensors = {}
    self._output_lookup_tensors = set()

  @property
  def dim_size(self):
    """Dim size as reported by the first sub-table."""
    return self._hash_tables[0].dim_size

  # Once `lookup` is edited, remember to edit `apply_gradients` too.
  def lookup(self, ids: tf.Tensor, use_multi_threads=False) -> tf.Tensor:
    """Looks up embeddings for `ids`.

    The ids are de-duplicated, split per shard, looked up on every shard and
    mapped back onto the original (possibly repeated) `ids`.
    `use_multi_threads` is accepted but not used by this implementation.
    """
    deduped_ids, _ = tf.unique(ids)
    shard_of_id = tf.math.floormod(deduped_ids, self._ps_num)
    ids_per_shard = distribution_ops.split_by_indices(shard_of_id, deduped_ids,
                                                      self._ps_num)
    embeddings_per_shard = []
    for shard in range(self._ps_num):
      with nullcontext() if export_context.is_exporting_standalone(
      ) else ps_device(shard), tf.name_scope("ps_{}".format(shard)):
        sub_table = self._hash_tables[shard]
        shard_ids = ids_per_shard[shard]
        shard_embeddings = sub_table.lookup(shard_ids)
        # Remember which shard produced each per-shard lookup tensor.
        self._input_lookup_tensors[shard_embeddings] = shard
        embeddings_per_shard.append(shard_embeddings)
    merged = distribution_ops.map_id_to_embedding(ids_per_shard,
                                                  embeddings_per_shard, ids)
    self._output_lookup_tensors.add(merged)
    return merged

  def _update(self, method_name: str, ids: tf.Tensor, values: tf.Tensor,
              req_time: tf.Tensor) -> "DistributedHashTable":
    """Splits (ids, values) per shard and calls `method_name` on each shard."""
    shard_of_id = tf.math.floormod(ids, self._ps_num)
    ids_per_shard = distribution_ops.split_by_indices(shard_of_id, ids,
                                                      self._ps_num)
    values_per_shard = distribution_ops.split_by_indices(
        shard_of_id, values, self._ps_num)
    new_tables = []
    for shard in range(self._ps_num):
      with ps_device(shard):
        shard_ids = ids_per_shard[shard]
        shard_values = values_per_shard[shard]
        shard_method = getattr(self._hash_tables[shard], method_name)
        new_tables.append(shard_method(shard_ids, shard_values, req_time))
    return self._copy_with_new_tables(new_tables)

  def assign(self,
             ids: tf.Tensor,
             values: tf.Tensor,
             req_time: tf.Tensor = None) -> "DistributedHashTable":
    """Calls `assign` on every shard; `req_time` defaults to int64 0."""
    if req_time is None:
      req_time = tf.constant(0, dtype=tf.int64)
    return self._update("assign", ids, values, req_time)

  def assign_add(self,
                 ids: tf.Tensor,
                 values: tf.Tensor,
                 req_time: tf.Tensor = None) -> "DistributedHashTable":
    """Calls `assign_add` on every shard; `req_time` defaults to int64 0."""
    if req_time is None:
      req_time = tf.constant(0, dtype=tf.int64)
    return self._update("assign_add", ids, values, req_time)

  def apply_gradients(self,
                      ids: tf.Tensor,
                      grads: tf.Tensor,
                      global_step: tf.Tensor,
                      req_time: tf.Tensor = None) -> "DistributedHashTable":
    """Applies `grads` for `ids`, using the same routing as `lookup`."""
    if req_time is None:
      req_time = tf.constant(0, dtype=tf.int64)
    deduped_ids, _ = tf.unique(ids)
    shard_of_id = tf.math.floormod(deduped_ids, self._ps_num)
    ids_per_shard = distribution_ops.split_by_indices(shard_of_id, deduped_ids,
                                                      self._ps_num)
    grads_per_shard = distribution_ops.map_id_to_embedding_gradient_back_prop(
        ids_per_shard, ids, grads)
    new_tables = []
    for shard in range(self._ps_num):
      with ps_device(shard), tf.name_scope("ps_{}".format(shard)):
        # TODO(leqi.zou): Think of the meaning of dedup here
        updated = self._hash_tables[shard].apply_gradients(
            ids_per_shard[shard],
            grads_per_shard[shard],
            global_step=global_step,
            enable_dedup=False,
            req_time=req_time)
        new_tables.append(updated)
    return self._copy_with_new_tables(new_tables)

  def as_op(self, name=None) -> tf.Operation:
    """Returns a no-op that depends on `as_op()` of every sub-table."""
    name = name or "dht_ao"
    sub_table_ops = [table.as_op() for table in self._hash_tables]
    with tf.control_dependencies(sub_table_ops):
      done = tf.no_op(name=("{}/done".format(name)))
    return done

  def _copy_with_new_tables(
      self, new_tables: List[tf.Tensor]) -> "DistributedHashTable":
    """Shallow-copies self and swaps in `new_tables` as the sub-tables."""
    clone = copy.copy(self)
    clone.__dict__["_hash_tables"] = new_tables
    return clone
range(self._num_ps): if export_context.is_exporting_distributed(): ps_graph = export_context.get_current_export_ctx().sub_graph(f"ps_{i}") with ps_graph.as_default(): table = table_factory(i, slot_to_config) self._tables.append(table) # Build lookup graph on the PS side remote_lookup_input = { k: tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) for k in slot_to_config } remote_lookup_output = table.lookup(remote_lookup_input) export_context.get_current_export_ctx().add_signature( ps_graph, 'lookup', remote_lookup_input, remote_lookup_output) if support_raw_api(table): raw_remote_lookup_input = { "id": tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)), "id_split": tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)), } raw_remote_lookup_output = { "flat_emb": table.raw_lookup( tf.RaggedTensor.from_row_splits( raw_remote_lookup_input["id"], raw_remote_lookup_input["id_split"], validate=False)) } export_context.get_current_export_ctx().add_signature( ps_graph, 'raw_lookup', raw_remote_lookup_input, raw_remote_lookup_output) elif export_context.is_exporting_standalone(): self._tables.append(table_factory(i, slot_to_config)) else: with ps_device(i): # Send learning rate tensor to ps # TODO(leqi.zou): Here we can do some optimization to optimize raw hash table. 
slot_to_learning_rate_tensor_on_ps = tensor_utils.unpack_tensors( tensor_utils.get_keyed_shape(slot_to_learning_rate_tensor), packed_slot_to_learning_rate_tensor) for slot, config in slot_to_config.items(): config.set_learning_rate_tensor( slot_to_learning_rate_tensor_on_ps[slot]) self._tables.append(table_factory(i, slot_to_config)) self._table_support_raw_api &= support_raw_api(self._tables[-1]) def lookup(self, slot_to_id: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: with tf.name_scope("dmtht_lu"): def emit_lookup_timer_ops(interval): if not export_context.is_exporting(): return [ logging_ops.emit_timer( "embedding_lookup", tf.cast(interval, tf.float32), tags={ "model_name": native_task_context.get().model_name, "ps": str(i) }) ] return [] if self._table_support_raw_api and not self.transfer_float16: table_0 = self._tables[0] dims = table_0.get_table_dim_sizes() ragged_id = table_0.get_ragged_id(slot_to_id) result = distribution_ops.unique_key_with_value_and_offset( ragged_id, dims) index = tf.math.floormod(result.unique_key.values, self._num_ps) splitted_ids, splitted_pos = distribution_ops.ragged_split_by_indices( index, result.unique_key, self._num_ps) filled_buffers = [] interval_ops = [] for i in range(self._num_ps): table: multi_hash_table_ops.RawMultiTypeHashTable = self._tables[i] splitted_id = splitted_ids[i] (splitted_id_values,), send_ts = logging_ops.tensors_timestamp( [splitted_id.values]) splitted_id = tf.RaggedTensor.from_row_splits(splitted_id_values, splitted_id.row_splits, validate=False) if export_context.is_exporting_distributed(): flat_emb, = remote_predict( ["id", "id_split"], [splitted_id.values, splitted_id.row_splits], ["flat_emb"], task=i, old_model_name="ps_{}".format(i), model_name= f"{native_task_context.get().model_name or ''}:ps_{i}", model_version=-1, max_rpc_deadline_millis=self._max_rpc_deadline_millis, output_types=[tf.float32], signature_name="raw_lookup") else: with nullcontext() if export_context.is_exporting_standalone( ) 
else ps_device(i): flat_emb = table.raw_lookup(splitted_id) (flat_emb,), end_ts = logging_ops.tensors_timestamp([flat_emb]) interval_ops.extend(emit_lookup_timer_ops(end_ts - send_ts)) filled_buffers.append( distribution_ops.fill_with_offset_map(splitted_pos[i], flat_emb, result.value_offset, result.value_buffer, dims)) with tf.control_dependencies(interval_ops): flat_emb = distribution_ops.finalize_shared_tensor(filled_buffers, dtype=tf.float32, shape=[None]) emb = table_0.get_embeddings(ragged_id, flat_emb) polished_emb = {} # Remove unpresented keys and make emb shape known if input shape is known. for k, v in emb.items(): if k in slot_to_id: id = slot_to_id[k] if id.shape[0]: v = tf.reshape(v, shape=[id.shape[0], v.shape[1]]) polished_emb[k] = v return polished_emb else: sharded_slot_to_id: Dict[int, Dict[ str, tf.Tensor]] = collections.defaultdict(dict) slot_to_split_ids = {} for slot in slot_to_id: id = slot_to_id[slot] unique_id, idx = tf.unique(id) index = tf.math.floormod(unique_id, self._num_ps) split_ids = distribution_ops.split_by_indices(index, unique_id, self._num_ps) slot_to_split_ids[slot] = split_ids for i in range(self._num_ps): sharded_slot_to_id[i][slot] = split_ids[i] sharded_slot_to_embedding: Dict[int, Dict[str, tf.Tensor]] = {} if export_context.is_exporting_distributed(): slot_names = sorted(slot_to_split_ids.keys()) slot_to_dim = [ infer_dim_size(self._slot_to_config[slot].table_config) for slot in slot_names ] for i in range(self._num_ps): per_ps_slot_to_id = sharded_slot_to_id[i] # Remote call from Entry to PS # TODO(leqi.zou): Consider a better way to get model name. 
results = remote_predict( slot_names, [per_ps_slot_to_id[slot] for slot in slot_names], slot_names, task=i, old_model_name="ps_{}".format(i), model_name= f"{native_task_context.get().model_name or ''}:ps_{i}", model_version=-1, max_rpc_deadline_millis=self._max_rpc_deadline_millis, output_types=[tf.float32] * len(slot_names), signature_name="lookup") sharded_slot_to_embedding[i] = { slot_names[j]: tf.reshape(results[j], [-1, slot_to_dim[j]]) for j in range(len(slot_names)) } else: for i in range(self._num_ps): per_ps_slot_to_id = sharded_slot_to_id[i] packed_id = tensor_utils.pack_tensors(per_ps_slot_to_id) packed_id, send_ts = logging_ops.tensors_timestamp(packed_id) with nullcontext() if export_context.is_exporting_standalone( ) else ps_device(i): slot_to_id_on_ps = tensor_utils.unpack_tensors( tensor_utils.get_keyed_shape(per_ps_slot_to_id), packed_id) slot_to_embedding_on_ps = self._tables[i].lookup(slot_to_id_on_ps) packed_embedding = tensor_utils.pack_tensors( slot_to_embedding_on_ps) if self.transfer_float16: packed_embedding = (tf.cast( packed_embedding[0], dtype=tf.float16, name='{}_send_{}_CastToFloat16'.format( packed_embedding[0].op.name, i)), packed_embedding[1]) packed_embedding, recv_ts = logging_ops.tensors_timestamp( packed_embedding) interval = recv_ts - send_ts with tf.control_dependencies(emit_lookup_timer_ops(interval)): packed_embedding = tf.identity_n(packed_embedding) if self.transfer_float16: packed_embedding = (tf.cast( packed_embedding[0], dtype=tf.float32, name='{}_recv_{}_CastToFloat32'.format( packed_embedding[0].op.name, i)), packed_embedding[1]) slot_to_embedding = tensor_utils.unpack_tensors( tensor_utils.get_keyed_shape(slot_to_embedding_on_ps), packed_embedding) sharded_slot_to_embedding[i] = slot_to_embedding slot_to_split_embeddings = {} for slot in slot_to_id: slot_to_split_embeddings[slot] = [ sharded_slot_to_embedding[i][slot] for i in range(self._num_ps) ] slot_to_embedding = {} for slot in slot_to_id: 
def reinitialize(
    self, slot: str,
    ids: tf.Tensor) -> Tuple[DistributedMultiTypeHashTable, tf.Tensor]:
  """Reinitializes `ids` of `slot` on every shard.

  Only available when all underlying tables support the raw API.

  Returns:
    A copy of this table pointing at the updated shards, and the per-shard
    status tensors concatenated along axis 0.

  Raises:
    NotImplementedError: if the underlying tables do not support the raw API.
  """
  if not self._table_support_raw_api:
    raise NotImplementedError(
        "DistributedMultiTypeHashTable dost not support reinitialize!")

  with tf.name_scope("dmtht_reinit"):
    # Route each id to shard `id mod num_ps`.
    shard_of_id = tf.math.floormod(ids, self._num_ps)
    ids_per_shard = distribution_ops.split_by_indices(shard_of_id, ids,
                                                      self._num_ps)
    updated_tables = []
    shard_statuses = []
    for shard in range(self._num_ps):
      updated_table, shard_status = self._tables[shard].reinitialize(
          slot, ids_per_shard[shard])
      updated_tables.append(updated_table)
      shard_statuses.append(shard_status)
    return self._copy_with_new_table(updated_tables), tf.concat(
        shard_statuses, axis=0)
def get_table_dim_sizes(self):
  """Returns the table dim sizes as reported by the first shard.

  This class never assigns `self._cc`, so the previous `self._cc.dims`
  would raise AttributeError. `lookup` and `apply_gradients` already obtain
  the dims through `self._tables[0].get_table_dim_sizes()`; do the same here.
  """
  return self._tables[0].get_table_dim_sizes()
Dict[Partition, Dict[TableName, FidEmbGradPair]]) AssignData = NewType("UpdateData", Dict[Partition, Dict[TableName, FidEmbPair]]) TableFactory = NewType( "TableFactory", Callable[[Partition, Dict[TableName, entry.HashTableConfigInstance]], multi_type_hash_table.BaseMultiTypeHashTable]) FeatureInfo = namedtuple('FeatureInfo', 'slice_dims combiner sub_table') from monolith.native_training.distributed_ps_sync import enable_custom_optimized_hvd, enable_hvd_fid_g2g, \ enable_hvd_fwd_g2g, enable_hvd_bwd_g2g, enable_bps, enable_bps_fid, enable_bps_fwd, enable_bps_bwd, \ enable_bps_bwd_cast, enable_bps_bwd_fake_cast, enable_bps_fwd_gdr, enable_bps_fwd_gdr_g2g, \ enable_bps_bwd_gdr, enable_bps_bwd_gdr_g2g class PartitionedHashTable(object): # Allow pipelined graph execution. _local_queue_hooks: List[prefetch_queue.EnqueueHook | prefetch_queue.AsyncPushHook] _native_multi_hash_table_fake_table = "native_multi_hash_table_fake_table" @classmethod def gen_feature_configs( cls, num_ps: int, feature_name_to_config: Dict[str, entry.HashTableConfigInstance], layout_configs: Dict[str, OutConfig], feature_to_unmerged_slice_dims: Dict[str, List[int]], feature_to_combiner: Dict[str, embedding_combiners.Combiner], use_native_multi_hash_table: bool, transfer_float16: bool = False, unique: Callable = None, enable_gpu_emb: bool = False, use_gpu: bool = False, ): _num_ps: int = 1 if num_ps == 0 else num_ps _use_native_multi_hash_table = use_native_multi_hash_table and not transfer_float16 # feature/slot -> sub_hashtable_name if _use_native_multi_hash_table: _sub_table_name_to_config, feature_to_sub_table = cls.no_merge_feature_config( feature_name_to_config, use_same_table=not enable_gpu_emb) else: _sub_table_name_to_config, feature_to_sub_table = cls.merge_feature_config( feature_name_to_config) feature_info: Dict[str, FeatureInfo] = {} for feature_name, sub_table in feature_to_sub_table.items(): feature_info[feature_name] = FeatureInfo( feature_to_unmerged_slice_dims[feature_name], 
feature_to_combiner[feature_name], sub_table) feature_configs = FeatureConfigs() # fill feature config for feature_configs for feature_name, info in feature_info.items(): combiner = info.combiner if isinstance(combiner, embedding_combiners.ReduceSum): pooling_type = PoolingType.SUM elif isinstance(combiner, embedding_combiners.ReduceMean): pooling_type = PoolingType.MEAN elif isinstance(combiner, embedding_combiners.FirstN): pooling_type = PoolingType.FIRSTN else: raise Exception("pooling_type error!") max_sequence_length = combiner.max_seq_length fc = FeatureConfig(table=info.sub_table, pooling_type=pooling_type, max_sequence_length=max_sequence_length) fc.slice_dims.extend(info.slice_dims) feature_configs.feature_configs[feature_name].CopyFrom(fc) for out_name, oc in layout_configs.items(): feature_configs.out_configs[out_name].CopyFrom(oc) return ShardingSparseFidsOpParams(_num_ps, _use_native_multi_hash_table, unique, transfer_float16, _sub_table_name_to_config, feature_configs, enable_gpu_emb, use_gpu) @classmethod def no_merge_feature_config( cls, feature_name_to_config: Dict[str, entry.HashTableConfigInstance], use_same_table: bool): sub_table_name_to_config, feature_to_sub_table = {}, {} for feature_name in sorted(feature_name_to_config): feature_to_sub_table[ feature_name] = cls._native_multi_hash_table_fake_table if use_same_table else feature_name return feature_name_to_config, feature_to_sub_table @classmethod def merge_feature_config( cls, feature_name_to_config: Dict[str, entry.HashTableConfigInstance]): # create merged config config_to_feature_name_list: Dict[ str, List[entry.HashTableConfigInstance]] = defaultdict(list) for feature_name in sorted(feature_name_to_config): config = feature_name_to_config[feature_name] config_to_feature_name_list[str(config)].append(feature_name) sub_table_name_to_config, feature_to_sub_table = {}, {} # merged config for config_str, feature_name_list in config_to_feature_name_list.items(): _, sub_table_name = 
get_sub_table_name(feature_name_list)
      # Shallow copy of the first feature's config.
      # NOTE(review): a shallow copy presumably shares `extra_restore_names`
      # with the source config, so the append below may also be visible on
      # the original -- confirm against HashTableConfigInstance.
      sub_table_config = copy.copy(feature_name_to_config[feature_name_list[0]])
      # replace "fc_slot_*" to "slot_*"
      old_feature_name_list = [
          feature_name[3:]
          if re.match("^fc_slot_[0-9]*$", feature_name) else feature_name
          for feature_name in feature_name_list
      ]
      _, old_sub_table_name = get_sub_table_name(old_feature_name_list)
      # Record the name derived from the old feature names as an extra
      # restore name when it differs from the current one.
      if old_sub_table_name != sub_table_name:
        sub_table_config.extra_restore_names.append(old_sub_table_name)
      sub_table_name_to_config[sub_table_name] = sub_table_config
      for feature_name in feature_name_list:
        feature_to_sub_table[feature_name] = sub_table_name
    return sub_table_name_to_config, feature_to_sub_table

  def __init__(self,
               num_ps: int,
               table_factory: TableFactory,
               use_native_multi_hash_table: bool,
               max_rpc_deadline_millis: int = 30,
               queue_configs: Dict[str, int] = None,
               parser_ctx=None):
    """Creates the per-partition sub tables.

    Args:
      num_ps: only used to decide local mode (`num_ps == 0`); the partition
        count itself is read from `parser_ctx`.
      table_factory: called as `table_factory(index, sub_table_name_to_config)`
        to build each sub table.
      use_native_multi_hash_table: only consulted when exporting a distributed
        model, to decide whether a 'raw_lookup' signature is added.
      max_rpc_deadline_millis: stored for the remote-predict lookups.
      queue_configs: stored as-is; `None` becomes an empty dict.
      parser_ctx: source of `sharding_sparse_fids_op_params`; defaults to
        `get_default_parser_ctx()`.
    """
    self._local_ps: bool = True if num_ps == 0 else False
    self._max_rpc_deadline_millis = max_rpc_deadline_millis
    self._queue_configs = queue_configs or {}
    if parser_ctx is None:
      parser_ctx = get_default_parser_ctx()
    self._inner_data_type = parser_ctx.parser_type
    assert parser_ctx.sharding_sparse_fids_op_params is not None
    # Everything below is copied from the op params carried by the parser ctx.
    self._num_ps = parser_ctx.sharding_sparse_fids_op_params.num_ps
    self._use_native_multi_hash_table = parser_ctx.sharding_sparse_fids_op_params.use_native_multi_hash_table
    self._unique = parser_ctx.sharding_sparse_fids_op_params.unique
    self.transfer_float16 = parser_ctx.sharding_sparse_fids_op_params.transfer_float16
    self._sub_table_name_to_config = parser_ctx.sharding_sparse_fids_op_params.sub_table_name_to_config
    self._feature_configs = parser_ctx.sharding_sparse_fids_op_params.feature_configs
    self._enable_gpu_emb = parser_ctx.sharding_sparse_fids_op_params.enable_gpu_emb
    self._use_gpu = parser_ctx.sharding_sparse_fids_op_params.use_gpu
    sub_table_to_learning_rate_tensor = {
        sub_table_name: config.call_learning_rate_fns()
        for sub_table_name, config in self._sub_table_name_to_config.items()
    }
    packed_sub_table_to_learning_rate_tensor = tensor_utils.pack_tensors(
        sub_table_to_learning_rate_tensor)
    if self._use_native_multi_hash_table:
      self._sub_table_names = [self._native_multi_hash_table_fake_table]
    else:
      self._sub_table_names = sorted(self._sub_table_name_to_config)
    if self._enable_gpu_emb:
      # GPU embedding mode: a single local table; the shard index is this
      # process's rank and the rank count must equal the configured count.
      self._shard_num = self._num_ps
      if enable_bps:
        import byteps.tensorflow as bps
        assert bps.size() == self._shard_num
        self._index = bps.rank()
      else:
        assert hvd.size() == self._shard_num
        self._index = hvd.rank()
      self._table = table_factory(self._index, self._sub_table_name_to_config)
      self._output_dims = self._table.get_table_dim_sizes()
      self._dependency_ops = []
      # Count the distinct tables and features in the feature configs.
      table_name_list = []
      feautres_name_list = []
      for feature_name, cfg in self._feature_configs.feature_configs.items():
        if cfg.table not in table_name_list:
          table_name_list.append(cfg.table)
        if feature_name not in feautres_name_list:
          feautres_name_list.append(feature_name)
      table_name_list.sort()
      self._num_table = len(table_name_list)
      self._num_feature = len(feautres_name_list)
    else:
      self._tables = []
      for i in range(self._num_ps):
        if export_context.is_exporting_distributed():
          ps_graph = export_context.get_current_export_ctx().sub_graph(
              f"ps_{i}")
          with ps_graph.as_default():
            table = table_factory(i, self._sub_table_name_to_config)
            self._tables.append(table)
            # Build lookup graph on the PS side
            remote_lookup_input = {
                sub_table_name:
                    tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,))
                for sub_table_name in self._sub_table_name_to_config
            }
            remote_lookup_output = table.lookup(remote_lookup_input)
            export_context.get_current_export_ctx().add_signature(
                ps_graph, 'lookup', remote_lookup_input, remote_lookup_output)
            # NOTE(review): this tests the constructor argument, not
            # `self._use_native_multi_hash_table` -- confirm this is intended.
            if use_native_multi_hash_table:
              raw_remote_lookup_input = {
                  "id":
                      tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)),
                  "id_split":
                      tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)),
              }
              raw_remote_lookup_output = {
                  "flat_emb":
                      table.raw_lookup(
                          tf.RaggedTensor.from_row_splits(
raw_remote_lookup_input["id"], raw_remote_lookup_input["id_split"], validate=False)) } export_context.get_current_export_ctx().add_signature( ps_graph, 'raw_lookup', raw_remote_lookup_input, raw_remote_lookup_output) elif export_context.is_exporting_standalone(): self._tables.append(table_factory(i, self._sub_table_name_to_config)) else: with nullcontext() if self._local_ps else ps_device(i): # Send learning rate tensor to ps sub_table_to_learning_rate_tensor_on_ps = tensor_utils.unpack_tensors( tensor_utils.get_keyed_shape(sub_table_to_learning_rate_tensor), packed_sub_table_to_learning_rate_tensor) for sub_table_name, config in self._sub_table_name_to_config.items( ): config.set_learning_rate_tensor( sub_table_to_learning_rate_tensor_on_ps[sub_table_name]) self._tables.append(table_factory(i, self._sub_table_name_to_config)) @property def slot_mapping(self): """Returns slot mapping.""" return self._slot_mapping def _native_hash_table_lookup_raw(self, lookup_data_on_wk: LookupData, lookup_data_on_wk_row_split: LookupData): ps_idx_to_multi_type_resp = {} def emit_lookup_timer_ops(i, interval): return [ logging_ops.emit_timer( "embedding_lookup", tf.cast(interval, tf.float32), tags={ "model_name": native_task_context.get().model_name, "ps": str(i) }) ] interval_ops = [] for i in range(self._num_ps): table: multi_hash_table_ops.RawMultiTypeHashTable = self._tables[i] (splitted_id_values,), send_ts = logging_ops.tensors_timestamp( [lookup_data_on_wk[i][self._native_multi_hash_table_fake_table]]) splitted_id = tf.RaggedTensor.from_row_splits( splitted_id_values, lookup_data_on_wk_row_split[i][ self._native_multi_hash_table_fake_table], validate=False) is_standalone = export_context.is_exporting_standalone() or self._local_ps with nullcontext() if is_standalone else ps_device(i): flat_emb = table.raw_lookup(splitted_id) (flat_emb,), end_ts = logging_ops.tensors_timestamp([flat_emb]) interval_ops.extend(emit_lookup_timer_ops(i, end_ts - send_ts)) 
ps_idx_to_multi_type_resp[i] = { self._native_multi_hash_table_fake_table: flat_emb } ret = {} with tf.control_dependencies(interval_ops): for i, sub_item in ps_idx_to_multi_type_resp.items(): ret[i] = {} for tname, ts in sub_item.items(): ret[i][tname] = tf.identity(ts) return ret def _lookup_raw(self, lookup_data_on_wk: LookupData): ps_idx_to_multi_type_resp = {} def emit_lookup_timer_ops(i, interval): return [ logging_ops.emit_timer( "embedding_lookup", tf.cast(interval, tf.float32), tags={ "model_name": native_task_context.get().model_name, "ps": str(i) }) ] interval_ops = [] for i in range(self._num_ps): # sub_table_name -> fids tensor multi_type_query: Dict[str, tf.Tensor] = lookup_data_on_wk[i] # Note: this is a python object, use it outside this device context required no data transfer multi_type_query_shape: Dict[ str, List[int]] = tensor_utils.get_keyed_shape(multi_type_query) # to reduce the number of rpc (send/recv ops) call, we pack fids by concat packed_fids_on_worker, send_ts = logging_ops.tensors_timestamp( tensor_utils.pack_tensors(multi_type_query)) is_standalone = export_context.is_exporting_standalone() or self._local_ps with nullcontext() if is_standalone else ps_device(i): packed_fids_on_ps = packed_fids_on_worker # data transfer in logic unpacked_multi_type_query = tensor_utils.unpack_tensors( multi_type_query_shape, packed_fids_on_ps) multi_type_resp_on_ps = self._tables[i].lookup( unpacked_multi_type_query) # Note: this is a python object, use it outside this device context required no data transfer multi_type_resp_shape = tensor_utils.get_keyed_shape( multi_type_resp_on_ps) # to reduce rpc (send/recv ops), we pack embeddings by concat packed_embeddings, emb_sizes = tensor_utils.pack_tensors( multi_type_resp_on_ps) if self.transfer_float16: packed_embeddings_fp16 = tf.cast( packed_embeddings, dtype=tf.float16, name='{}_send_{}_CastToFloat16'.format(packed_embeddings.op.name, i)) packed_embedding_on_ps = (packed_embeddings_fp16, emb_sizes) 
else: packed_embedding_on_ps = (packed_embeddings, emb_sizes) packed_embedding_on_worker, end_ts = logging_ops.tensors_timestamp( packed_embedding_on_ps) # data transfer in logic interval_ops.extend(emit_lookup_timer_ops(i, end_ts - send_ts)) # on worker, uppack if self.transfer_float16: packed_embeddings_fp16_recv, emb_sizes_recv = packed_embedding_on_worker packed_embeddings_fp32 = tf.cast(packed_embeddings_fp16_recv, dtype=tf.float32, name='{}_recv_{}_CastToFloat32'.format( packed_embeddings_fp16.op.name, i)) packed_embedding_on_worker = (packed_embeddings_fp32, emb_sizes) multi_type_resp_on_worker = tensor_utils.unpack_tensors( multi_type_resp_shape, packed_embedding_on_worker) ps_idx_to_multi_type_resp[i] = multi_type_resp_on_worker ret = {} with tf.control_dependencies(interval_ops): for i, sub_item in ps_idx_to_multi_type_resp.items(): ret[i] = {} for tname, ts in sub_item.items(): ret[i][tname] = tf.identity(ts) return ret def _native_hash_table_lookup_with_remote_predict( self, lookup_data_on_wk: LookupData, lookup_data_on_wk_row_split: LookupData): ps_idx_to_multi_type_resp = {} for i in range(self._num_ps): flat_emb, = remote_predict( ["id", "id_split"], [ lookup_data_on_wk[i][self._native_multi_hash_table_fake_table], lookup_data_on_wk_row_split[i][ self._native_multi_hash_table_fake_table] ], ["flat_emb"], task=i, old_model_name="ps_{}".format(i), model_name=f"{native_task_context.get().model_name or ''}:ps_{i}", model_version=-1, max_rpc_deadline_millis=self._max_rpc_deadline_millis, output_types=[tf.float32], signature_name="raw_lookup") ps_idx_to_multi_type_resp[i] = { self._native_multi_hash_table_fake_table: flat_emb } return ps_idx_to_multi_type_resp def _lookup_with_remote_predict(self, lookup_data_on_wk: LookupData): sub_table_to_dim = [ infer_dim_size( self._sub_table_name_to_config[sub_table_name].table_config) for sub_table_name in self._sub_table_names ] ps_idx_to_multi_type_resp = {} for i in range(self._num_ps): multi_type_query = 
lookup_data_on_wk[i] # Remote call from Entry to PS # TODO(leqi.zou): Consider a better way to get model name. results = remote_predict( input_tensor_alias=self._sub_table_names, input_tensors=[ multi_type_query[sub_table_name] for sub_table_name in self._sub_table_names ], output_tensor_alias=self._sub_table_names, task=i, old_model_name="ps_{}".format(i), model_name=f"{native_task_context.get().model_name or ''}:ps_{i}", model_version=-1, max_rpc_deadline_millis=self._max_rpc_deadline_millis, output_types=[tf.float32] * len(self._sub_table_names), signature_name="lookup") ps_idx_to_multi_type_resp[i] = { sub_table_name: tf.reshape(results[j], [-1, sub_table_to_dim[j]]) for j, sub_table_name in enumerate(self._sub_table_names) } return ps_idx_to_multi_type_resp def lookup( self, features: Dict[str, tf.Tensor], auxiliary_bundle: Dict[str, tf.Tensor] = None, ret_fused_layout_callable_fn=False, ret_lookup_callable_fn=False, embedding_prefetch_capacity=0, ) -> Dict[str, Union[tf.Tensor, List[tf.Tensor]]]: if self._enable_gpu_emb: ret = self._lookup_gpu(features, auxiliary_bundle) if ret_fused_layout_callable_fn or ret_lookup_callable_fn: def lookup_callable_fn(auxiliary_bundle_, features_): return ret return lookup_callable_fn else: return ret with tf.name_scope("pht_lookup"): if ParserCtx.sharding_sparse_fids_sparse_features_key in features: #assert False, "not support, please use sharding_sparse_fids_with_context before call lookup" # only support for cpu training(without dsworker) shards, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, \ feature_size, fid_size, emb_size, shards_row_split, shards_row_split_size, \ fid_list_emb_row_lenth, fid_list_table_row_length, fid_list_shard_row_lenth = \ sharding_sparse_fids( features[ParserCtx.sharding_sparse_fids_sparse_features_key], ps_num=self._num_ps, feature_cfgs=self._feature_configs, unique=self._unique(), input_type=self._inner_data_type, parallel_flag=0) else: sharding_features = 
ParserCtx.sharding_sparse_fids_features_parse_from_features( features) shards, fid_offset, feature_offset, nfl_offset, batch_size, nfl_size, \ feature_size, fid_size, emb_size, shards_row_split, shards_row_split_size = \ sharding_features.get("shards"), sharding_features.get("fid_offset") ,\ sharding_features.get("feature_offset"), sharding_features.get("nfl_offset") ,\ sharding_features.get("batch_size"), sharding_features.get("nfl_size", None), \ sharding_features.get("feature_size", None), sharding_features.get("fid_size", None), \ sharding_features.get("emb_size", None), sharding_features.get("shards_row_split", None), \ sharding_features.get("shards_row_split_size", None) if auxiliary_bundle is None: auxiliary_bundle = {} auxiliary_bundle['__sharding_sparse_fids__fid_offset'] = fid_offset auxiliary_bundle[ '__sharding_sparse_fids__feature_offset'] = feature_offset auxiliary_bundle['__sharding_sparse_fids__nfl_offset'] = nfl_offset auxiliary_bundle['__sharding_sparse_fids__batch_size'] = tf.identity( batch_size, name="batch_size") if nfl_size is not None: auxiliary_bundle['__sharding_sparse_fids__nfl_size'] = nfl_size if feature_size is not None: auxiliary_bundle['__sharding_sparse_fids__feature_size'] = feature_size if fid_size is not None: auxiliary_bundle['__sharding_sparse_fids__fid_size'] = fid_size if emb_size is not None: auxiliary_bundle['__sharding_sparse_fids__emb_size'] = emb_size logging.info(f"sharding_sparse_fids done, {shards}") if ret_lookup_callable_fn: for key, shard_fids in shards.items(): sub_table_name, ps_idx = key.split(':') ps_idx = int(ps_idx) name = '__sharding_sparse_fids__shards@{}@{}'.format( ps_idx, sub_table_name) auxiliary_bundle[name] = shard_fids if self._use_native_multi_hash_table: name = '__sharding_sparse_fids__shards_row_split@{}@{}'.format( ps_idx, sub_table_name) auxiliary_bundle[name] = shards_row_split[key] if shards_row_split_size is not None and shards_row_split_size[ key] is not None: name = 
'__sharding_sparse_fids__shards_row_split_size@{}@{}'.format( ps_idx, sub_table_name) auxiliary_bundle[name] = shards_row_split_size[key] def fused_layout_callable_fn(auxiliary_bundle_, features_): flattened_embs = [] assert auxiliary_bundle_ is not None for sub_table_name in self._sub_table_names: for ps_idx in range(self._num_ps): flattened_embs.append(auxiliary_bundle_[ f'__sharding_sparse_fids__{sub_table_name}:{ps_idx}:embs']) nfl_offset_ = auxiliary_bundle_['__sharding_sparse_fids__nfl_offset'] feature_offset_ = auxiliary_bundle_[ '__sharding_sparse_fids__feature_offset'] fid_offset_ = auxiliary_bundle_['__sharding_sparse_fids__fid_offset'] batch_size_ = auxiliary_bundle_['__sharding_sparse_fids__batch_size'] nfl_size_ = auxiliary_bundle_.get('__sharding_sparse_fids__nfl_size', None) feature_size_ = auxiliary_bundle_.get( '__sharding_sparse_fids__feature_size', None) fid_size_ = auxiliary_bundle_.get('__sharding_sparse_fids__fid_size', None) emb_size_ = auxiliary_bundle_.get('__sharding_sparse_fids__emb_size', None) if export_context.is_exporting(): fused_layout_use_gpu = export_context.get_current_export_ctx( ).with_remote_gpu else: fused_layout_use_gpu = self._use_gpu with tf.device( "/device:GPU:0" if fused_layout_use_gpu else "/device:CPU:0"): layout_tensors = distribution_ops.fused_embedding_to_layout( flattened_embs, None, #self.fids_list_row_split, v3 not need fids_list_row_split fid_offset=fid_offset_, feature_offset=feature_offset_, nfl_offset=nfl_offset_, batch_size=batch_size_, nfl_size=nfl_size_, feature_size=feature_size_, fid_size=fid_size_, emb_size=emb_size_, variant_type=self._inner_data_type, feature_cfgs=self._feature_configs, ps_num=self._num_ps, version=5) layout_embeddings = self.nest_layout(layout_tensors) ''' if not self._use_gpu: # embedding_prefetch logging.info( f"PartitionedHashTable lookup fused_layout enqueue: {auxiliary_bundle_} {features_}" ) (deq_layout_embeddings, deq_auxiliary_bundle, deq_features), queue = 
enqueue_dicts_with_queue_return( (layout_embeddings, auxiliary_bundle_, features_), capacity=embedding_prefetch_capacity) if queue: self.add_queue_hook(EnqueueHook(queue)) features_.update(deq_features) auxiliary_bundle_.update(deq_auxiliary_bundle) logging.info( f"PartitionedHashTable lookup fused_layout dequeue: {auxiliary_bundle_} {features_}" ) else: deq_layout_embeddings = layout_embeddings ''' logging.info("fused_embedding_to_layout done!") return layout_embeddings #deq_layout_embeddings def call_lookup(lookup_data_on_wk: LookupData, lookup_data_on_wk_row_split: LookupData, auxiliary_bundle_, features_): with tf.name_scope("pht_lookup"): # ps_idx_to_multi_type_resp: Dict[int, Dict[str, tf.Tensor]] = {} with tf.device("/device:CPU:0"): if export_context.is_exporting_distributed(): if self._use_native_multi_hash_table: ps_idx_to_multi_type_resp = self._native_hash_table_lookup_with_remote_predict( lookup_data_on_wk, lookup_data_on_wk_row_split) else: ps_idx_to_multi_type_resp = self._lookup_with_remote_predict( lookup_data_on_wk) else: if self._use_native_multi_hash_table: ps_idx_to_multi_type_resp = self._native_hash_table_lookup_raw( lookup_data_on_wk, lookup_data_on_wk_row_split) else: ps_idx_to_multi_type_resp = self._lookup_raw(lookup_data_on_wk) for sub_table_name in self._sub_table_names: for ps_idx in range(self._num_ps): embeddings_tensor = ps_idx_to_multi_type_resp[ps_idx][ sub_table_name] auxiliary_bundle_[ f'__sharding_sparse_fids__{sub_table_name}:{ps_idx}:embs'] = embeddings_tensor if not export_context.is_exporting(): fids_tensor = lookup_data_on_wk[ps_idx][sub_table_name] auxiliary_bundle_[ f'__sharding_sparse_fids__{sub_table_name}:{ps_idx}:fids'] = fids_tensor if self._use_native_multi_hash_table: fids_tensor_row_split = lookup_data_on_wk_row_split[ps_idx][ sub_table_name] auxiliary_bundle_[ f'__sharding_sparse_fids__{sub_table_name}:{ps_idx}:fids_row_split'] = fids_tensor_row_split if self._use_gpu: logging.info( f"PartitionedHashTable lookup 
gpu fused_layout tensor to gpu before: {auxiliary_bundle} {features}" ) self.tensor_move_to_gpu(((auxiliary_bundle_, [ "__sharding_sparse_fids__batch_size", "__sharding_sparse_fids__nfl_size", "__sharding_sparse_fids__feature_size", "__sharding_sparse_fids__fid_size", "__sharding_sparse_fids__emb_size" ]), (features_, ["req_time"]))) logging.info( f"PartitionedHashTable lookup fused_layout enqueue before: {auxiliary_bundle} {features}" ) (dequeued_features, deq_auxiliary_bundle), queue = enqueue_dicts_with_queue_return( (features_, auxiliary_bundle_), capacity=embedding_prefetch_capacity) if queue: self.add_queue_hook(EnqueueHook(queue)) features_.update(dequeued_features) auxiliary_bundle_.update(deq_auxiliary_bundle) logging.info( f"PartitionedHashTable lookup fused_layout dequeue: {auxiliary_bundle} {features}" ) def lookup_callable_fn(auxiliary_bundle_, features_): with tf.name_scope("pht_lookup"): lookup_data_on_wk: LookupData = {} lookup_data_on_wk_row_split: LookupData = {} for sub_table_name in self._sub_table_names: for ps_idx in range(self._num_ps): key = '__sharding_sparse_fids__shards@{}@{}'.format( ps_idx, sub_table_name) if ps_idx not in lookup_data_on_wk: lookup_data_on_wk[ps_idx] = {} lookup_data_on_wk[ps_idx][sub_table_name] = auxiliary_bundle_[key] if self._use_native_multi_hash_table: key = '__sharding_sparse_fids__shards_row_split@{}@{}'.format( ps_idx, sub_table_name) size_key = '__sharding_sparse_fids__shards_row_split_size@{}@{}'.format( ps_idx, sub_table_name) if ps_idx not in lookup_data_on_wk_row_split: lookup_data_on_wk_row_split[ps_idx] = {} if size_key not in auxiliary_bundle_: lookup_data_on_wk_row_split[ps_idx][ sub_table_name] = auxiliary_bundle_[key] else: lookup_data_on_wk_row_split[ps_idx][ sub_table_name] = distribution_ops.normalize_merged_split( auxiliary_bundle_[key], auxiliary_bundle_[size_key]) call_lookup(lookup_data_on_wk, lookup_data_on_wk_row_split, auxiliary_bundle_, features_) return 
fused_layout_callable_fn(auxiliary_bundle_, features_) if ret_lookup_callable_fn: return lookup_callable_fn with tf.name_scope("pht_lookup"): lookup_data_on_wk: LookupData = {} lookup_data_on_wk_row_split: LookupData = {} for key, shard_fids in shards.items(): sub_table_name, ps_idx = key.split(':') ps_idx = int(ps_idx) if self._use_native_multi_hash_table: shards_row_split_part = shards_row_split[key] else: shards_row_split_part = None if ps_idx in lookup_data_on_wk: lookup_data_on_wk[ps_idx][sub_table_name] = shard_fids lookup_data_on_wk_row_split[ps_idx][ sub_table_name] = shards_row_split_part else: lookup_data_on_wk[ps_idx] = {sub_table_name: shard_fids} lookup_data_on_wk_row_split[ps_idx] = { sub_table_name: shards_row_split_part } call_lookup(lookup_data_on_wk, lookup_data_on_wk_row_split, auxiliary_bundle, features) if ret_fused_layout_callable_fn: return fused_layout_callable_fn else: return fused_layout_callable_fn(auxiliary_bundle, features) def apply_gradients( self, layout_grads_and_vars: List[Tuple[tf.Tensor, tf.Tensor]], global_step: tf.Tensor, req_time: Optional[tf.Tensor] = None, auxiliary_bundle: Dict[str, tf.Tensor] = None, async_function_mgr: prefetch_queue.AsyncFunctionMgr = None, async_push: bool = False, grad_scale: tf.Tensor = None) -> PartitionedHashTable: logging.info( f"PartitionedHashTable apply_gradients {async_push} {async_function_mgr}" ) with tf.device( "/device:GPU:0" if self._enable_gpu_emb else "/device:CPU:0"): if req_time is None: req_time = tf.constant(0, dtype=tf.int64) else: req_time = tf.reduce_max(req_time) assert auxiliary_bundle is not None if self._enable_gpu_emb: assert not async_push return self._apply_gradients_gpu(layout_grads_and_vars, global_step, req_time, auxiliary_bundle, grad_scale=grad_scale) with tf.name_scope("pht_apply_gradients"): layout_grad, layout = zip(*layout_grads_and_vars) flattened_fids, flattened_fids_row_split, flattened_embs = [], [], [] for sub_table_name in self._sub_table_names: for ps_idx in 
range(self._num_ps): flattened_fids.append(auxiliary_bundle[ f'__sharding_sparse_fids__{sub_table_name}:{ps_idx}:fids']) if self._use_native_multi_hash_table: flattened_fids_row_split.append(auxiliary_bundle[ f'__sharding_sparse_fids__{sub_table_name}:{ps_idx}:fids_row_split'] ) flattened_embs.append(auxiliary_bundle[ f'__sharding_sparse_fids__{sub_table_name}:{ps_idx}:embs']) nfl_offset = auxiliary_bundle['__sharding_sparse_fids__nfl_offset'] feature_offset = auxiliary_bundle[ '__sharding_sparse_fids__feature_offset'] fid_offset = auxiliary_bundle['__sharding_sparse_fids__fid_offset'] batch_size = auxiliary_bundle['__sharding_sparse_fids__batch_size'] with tf.device("/device:GPU:0" if self._use_gpu else "/device:CPU:0"): embeddings_grad = distribution_ops.fused_embedding_to_layout_grad( nfl_offset=nfl_offset, feature_offset=feature_offset, fid_offset=fid_offset, batch_size=batch_size, embeddings_list=flattened_embs, fid_list_row_split=None, #flattened_fids_row_split, v3 no need layout_tensors_grad=layout_grad, layout_tensors_grad_scale=grad_scale, variant_type=self._inner_data_type, feature_cfgs=self._feature_configs, ps_num=self._num_ps) def hash_table_apply_gradients(flattened_fids_, flattened_fids_row_split_, embeddings_grad_, global_step_, req_time_): if self._use_gpu: logging.info( f"PartitionedHashTable apply_gradients fused_layout before tensor_move_to_cpu: \ {flattened_fids_}, {flattened_fids_row_split_}, {embeddings_grad_}, \ {global_step_}, {req_time_}") def tensor_move_to_cpu(*inputs): inputs_info = [] to_cpu_value_list = [] for tensor_ in inputs: to_cpu_dict = defaultdict(int) if isinstance(tensor_, List): for idx in range(len(tensor_)): part_tensor = tensor_[idx] to_cpu_dict[idx] = len(to_cpu_value_list) to_cpu_value_list.append(part_tensor) elif tensor_ is not None: to_cpu_dict[-1] = len(to_cpu_value_list) to_cpu_value_list.append(tensor_) inputs_info.append(to_cpu_dict) with tf.device("/device:CPU:0"): gpu_value_list = 
tf.identity_n(to_cpu_value_list) outputs = [] for input_idx in range(len(inputs)): to_cpu_dict = inputs_info[input_idx] part_input = inputs[input_idx] for k, idx in to_cpu_dict.items(): if k == -1: part_input = gpu_value_list[idx] continue else: part_input[k] = gpu_value_list[idx] outputs.append(part_input) return tuple(outputs) flattened_fids_, flattened_fids_row_split_, embeddings_grad_, \ global_step_, req_time_ = tensor_move_to_cpu( flattened_fids_, flattened_fids_row_split_, embeddings_grad_, global_step_, req_time_) logging.info( f"PartitionedHashTable apply_gradients fused_layout tensor_move_to_cpu: \ {flattened_fids_}, {flattened_fids_row_split_}, {embeddings_grad_}, \ {global_step_}, {req_time_}") with tf.device("/device:CPU:0"): if self._use_native_multi_hash_table: new_tables = [] for i in range(self._num_ps): splitted_id = tf.RaggedTensor.from_row_splits( flattened_fids_[i], flattened_fids_row_split_[i], validate=False) splitted_flat_grad = embeddings_grad_[i] table: multi_hash_table_ops.RawMultiTypeHashTable = self._tables[ i] with nullcontext() if self._local_ps else ps_device(i): new_tables.append( table.raw_apply_gradients(splitted_id, splitted_flat_grad, global_step=global_step, req_time=req_time_)) else: fids_and_embgrad_pairs = list(zip(flattened_fids_, embeddings_grad_)) logging.info( f"PartitionedHashTable apply_gradients fused_embedding_to_layout_grad done, {fids_and_embgrad_pairs}" ) # reconstruct update data offset = 0 update_data_on_worker: UpdateData = { ps_idx: {} for ps_idx in range(self._num_ps) } for sub_table_name in self._sub_table_names: for ps_idx in range(self._num_ps): update_data_on_worker[ps_idx][ sub_table_name] = fids_and_embgrad_pairs[offset] offset += 1 new_tables = [] for i in range(self._num_ps): keyed_fids, keyed_grads = {}, {} for tbname, (fids, emb_grad) in update_data_on_worker[i].items(): keyed_fids[tbname] = fids keyed_grads[tbname] = emb_grad packed_list = tensor_utils.pack_typed_keyed_tensors( [keyed_fids, 
keyed_grads]) (packed_fids_on_wk, packed_emb_grad_on_wk, packed_sizes_on_wk) = packed_list typed_keyed_shape = tensor_utils.get_typed_keyed_shape( [keyed_fids, keyed_grads]) if self.transfer_float16: packed_emb_grad_on_wk = tf.cast( packed_emb_grad_on_wk, dtype=tf.float16, name='{}_send_{}_CastToFloat16'.format( packed_emb_grad_on_wk.op.name, i)) with nullcontext() if self._local_ps else ps_device(i): packed_fids_on_ps = packed_fids_on_wk packed_emb_grad_on_ps = packed_emb_grad_on_wk packed_sizes_on_ps = packed_sizes_on_wk if self.transfer_float16: packed_emb_grad_on_ps = tf.cast( packed_emb_grad_on_ps, dtype=tf.float32, name='{}_recv_{}_CastToFloat32'.format( packed_emb_grad_on_ps.op.name, i)) keyed_fids_on_ps, keyed_grads_on_ps = tensor_utils.unpack_packed_tensors( typed_keyed_shape, packed_list=[ packed_fids_on_ps, packed_emb_grad_on_ps, packed_sizes_on_ps ]) partitioned_update_data_on_ps = { tbname: (keyed_fids_on_ps[tbname], keyed_grads_on_ps[tbname]) for tbname in keyed_fids_on_ps } new_tables.append(self._tables[i].apply_gradients( partitioned_update_data_on_ps, global_step_, req_time=req_time_)) return self._copy_with_new_table(new_tables).as_op() if async_function_mgr is None or not async_push: return hash_table_apply_gradients(flattened_fids, flattened_fids_row_split, embeddings_grad, global_step, req_time) else: return async_function_mgr.add_async_function( hash_table_apply_gradients, (flattened_fids, flattened_fids_row_split, embeddings_grad, global_step, req_time), is_async=async_push, queue_name="postpush_queue") def tensor_move_to_gpu(self, inputs): inputs_info = [] to_gpu_value_list = [] for tensor_dict, except_list in inputs: to_gpu_dict = defaultdict(int) for k, v in tensor_dict.items(): if k in except_list: continue to_gpu_dict[k] = len(to_gpu_value_list) to_gpu_value_list.append(v) inputs_info.append((to_gpu_dict, tensor_dict)) with tf.device("/device:GPU:0"): gpu_value_list = tf.identity_n(to_gpu_value_list) for to_gpu_dict, tensor_dict in 
inputs_info: for k, idx in to_gpu_dict.items(): tensor_dict[k] = gpu_value_list[idx] def as_op(self, name=None): name = name or "pht_as_op" if self._enable_gpu_emb: with tf.control_dependencies(self._dependency_ops): return self._table.as_op(name) ops = [] for i in range(self._num_ps): with nullcontext() if self._local_ps else ps_device(i): ops.append(self._tables[i].as_op(name="{}/sub_{}".format(name, i))) with tf.control_dependencies(ops): return tf.no_op(name=("{}/done".format(name))) def _lookup_gpu( self, features: Dict[str, tf.Tensor], auxiliary_bundle: Dict[str, tf.Tensor] = None, ) -> Dict[str, Union[tf.Tensor, List[tf.Tensor]]]: if enable_bps: import byteps.tensorflow as bps with tf.name_scope("pht_lookup_gpu"): logging.info( f"PartitionedHashTable lookup_gpu fused_layout tensor to gpu before: {features}" ) slot_num = self._num_table recv_emb_splits_tmp = tf.reshape( tf.matmul( tf.reshape( features["__sharding_sparse_fids__shards_table_row_lengths"], [self._shard_num, slot_num]), tf.expand_dims(tf.constant(self._output_dims, dtype=tf.int32), -1) # [slot_num, 1] ), [-1] # flatten ) features["__sharding_sparse_fids__recv_emb_splits"] = recv_emb_splits_tmp #self.tensor_move_to_gpu( # ((features, ["__sharding_sparse_fids__batch_size"]),)) logging.info( f"PartitionedHashTable lookup_gpu fused_layout enqueue before: {features}" ) sharding_features = ParserCtx.sharding_sparse_fids_features_parse_from_features( features) shards_value, shards_row_lengths, shards_table_row_lengths, fid_offset, feature_offset, \ nfl_offset, batch_size, fid_list_emb_row_lenth, recv_emb_splits = \ sharding_features.get("shards_value"), sharding_features.get("shards_row_lengths"), \ sharding_features.get("shards_table_row_lengths"), sharding_features.get("fid_offset") ,\ sharding_features.get("feature_offset"), sharding_features.get("nfl_offset") ,\ sharding_features.get("batch_size"), sharding_features.get("fid_list_emb_row_lenth"), \ sharding_features.get("recv_emb_splits") if 
auxiliary_bundle is None: auxiliary_bundle = {} auxiliary_bundle['__sharding_sparse_fids__fid_offset'] = fid_offset auxiliary_bundle[ '__sharding_sparse_fids__feature_offset'] = feature_offset auxiliary_bundle['__sharding_sparse_fids__nfl_offset'] = nfl_offset auxiliary_bundle['__sharding_sparse_fids__batch_size'] = batch_size auxiliary_bundle[ "__sharding_sparse_fids__recv_emb_splits"] = recv_emb_splits auxiliary_bundle[ "__sharding_sparse_fids__fid_list_emb_row_lenth"] = fid_list_emb_row_lenth logging.info( f"sharding_sparse_fids done, {shards_value} {shards_row_lengths}") all_fids = shards_value shard_sizes = shards_row_lengths sharded_slot_sizes = shards_table_row_lengths # We exchange the flattened IDs and their splits. # M: num_of_ids, # N: num_of_shards, # K: num_of_merged_tables, # E: num_of_total_embedding_dim. # id_flat_t: [M], id_flat_split_t: [N] # id_size_flat_t: [K*N], id_size_flat_split_t: [N] if enable_bps and enable_bps_fid: logging.info('Enabled BPS for fid alltoall') id_flat_t, id_flat_split_t = bps.alltoall(all_fids, splits=shard_sizes, with_size=True, name='fid_data') # We also add the flat_t sizes. id_size_flat_t = bps.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num, recv_splits=([slot_num] * self._shard_num), name='fid_size') elif enable_custom_optimized_hvd: id_flat_t, id_flat_split_t = hvd.alltoall(all_fids, splits=shard_sizes, with_size=True) # We also add the flat_t sizes. 
id_size_flat_t = hvd.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num, recv_splits=[slot_num] * self._shard_num) elif enable_hvd: if enable_hvd_fid_g2g: logging.info('Enabled hvd for fid alltoall g2g') with tf.device("/device:GPU:0"): id_flat_t = hvd.alltoall(all_fids, splits=shard_sizes) id_flat_split_t = hvd.alltoall(shard_sizes) id_size_flat_t = hvd.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num) else: id_flat_t = hvd.alltoall(all_fids, splits=shard_sizes) id_flat_split_t = hvd.alltoall(shard_sizes) id_size_flat_t = hvd.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num) auxiliary_bundle["__sharding_sparse_fids__shard_sizes"] = shard_sizes auxiliary_bundle["__sharding_sparse_fids__id_flat_t"] = id_flat_t auxiliary_bundle[ "__sharding_sparse_fids__id_size_flat_t"] = id_size_flat_t # fused_embeddings: [E], fused_splits: [N] # id_offsets: [K*N], emb_offsets: [K*N] req_time = features.get("req_time", None) with tf.device( "/device:GPU:0" if self._enable_gpu_emb else "/device:CPU:0"): if req_time is None: logging.warning(f"fused_embedding_to_layout use default req_time") req_time = tf.constant(0, dtype=tf.int64) else: req_time = tf.reduce_max(req_time) with tf.device("/GPU:0"): fused_embeddings, embedding_splits, id_offsets, emb_offsets, indices = \ self._table.fused_lookup(id_flat_t, id_size_flat_t, self._shard_num, req_time) if FLAGS.enable_alltoall_metrics: with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("fused_embedding_splits", embedding_splits) auxiliary_bundle[ "__sharding_sparse_fids__fused_embeddings"] = fused_embeddings auxiliary_bundle[ "__sharding_sparse_fids__embedding_splits"] = embedding_splits auxiliary_bundle["__sharding_sparse_fids__id_offsets"] = id_offsets auxiliary_bundle["__sharding_sparse_fids__emb_offsets"] = emb_offsets auxiliary_bundle["__sharding_sparse_fids__indices"] = indices deq_auxiliary_bundle, queue = enqueue_dicts_with_queue_return( auxiliary_bundle, 
capacity=int(self._queue_configs.get("enable_pipelined_fwda2a", 0)), queue_name="queue_lookup_to_fusedEmbA2A") if queue: self.add_queue_hook(EnqueueHook(queue)) auxiliary_bundle.update(deq_auxiliary_bundle) fused_embeddings = auxiliary_bundle.pop( "__sharding_sparse_fids__fused_embeddings") embedding_splits = auxiliary_bundle[ "__sharding_sparse_fids__embedding_splits"] recv_emb_splits = auxiliary_bundle[ "__sharding_sparse_fids__recv_emb_splits"] # recv_embeddings: [E'], recv_embedding_sizes: [N] if enable_bps and enable_bps_fwd: if enable_bps_fwd_gdr: if enable_bps_fwd_gdr_g2g: logging.info('Enabled BPS for fwd embed alltoall GDR (G2G)') with tf.device("/device:GPU:0"): fused_embeddings_gpu = fused_embeddings with tf.device("/device:GPU:0"): recv_embeddings = bps.alltoall(fused_embeddings_gpu, embedding_splits, recv_splits=recv_emb_splits, name="fwd_alltoall_g2g") else: logging.info('Enabled BPS for fwd embed alltoall GDR (C2G)') with tf.device("/device:GPU:0"): recv_embeddings = bps.alltoall_cpu2gpu( fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, name="fwd_alltoall_c2g") else: logging.info('Enabled BPS for fwd embed alltoall') recv_embeddings = bps.alltoall(fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, name="fwd_alltoall") elif enable_custom_optimized_hvd: if enable_hvd_fwd_g2g: logging.info('Enabled optimized hvd for fwd embed alltoall g2g') with tf.device("/device:GPU:0"): recv_embeddings = hvd.alltoall( fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, ) else: logging.info('Enabled optimized hvd for fwd embed alltoall') recv_embeddings = hvd.alltoall( fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, ) elif enable_hvd: if enable_hvd_fwd_g2g: logging.info('Enabled hvd for fwd embed alltoall g2g') with tf.device("/device:GPU:0"): recv_embeddings = hvd.alltoall(fused_embeddings, embedding_splits, name='hvd_fwd_a2a_g2g') else: logging.info('Enabled hvd for fwd embed alltoall') recv_embeddings = 
hvd.alltoall(fused_embeddings, embedding_splits, name='hvd_fwd_a2a') auxiliary_bundle[ "__sharding_sparse_fids__recv_embeddings"] = recv_embeddings #TODO enable embedding_prefetch_capacity train will slow down ''' deq_auxiliary_bundle, queue = enqueue_dicts_with_queue_return( auxiliary_bundle, capacity=int(self._queue_configs.get("embedding_prefetch_capacity", 0)), queue_name="queue_fusedEmbA2A_to_fusedGather") if queue: self.add_queue_hook(EnqueueHook(queue)) auxiliary_bundle.update(deq_auxiliary_bundle) ''' recv_embeddings = auxiliary_bundle[ "__sharding_sparse_fids__recv_embeddings"] fid_offset = auxiliary_bundle['__sharding_sparse_fids__fid_offset'] feature_offset = auxiliary_bundle[ '__sharding_sparse_fids__feature_offset'] nfl_offset = auxiliary_bundle['__sharding_sparse_fids__nfl_offset'] batch_size = auxiliary_bundle['__sharding_sparse_fids__batch_size'] fid_list_emb_row_lenth = auxiliary_bundle[ '__sharding_sparse_fids__fid_list_emb_row_lenth'] with tf.device("/device:GPU:0"): ''' recv_embeddings_split = tf.split(recv_embeddings, fid_list_emb_row_lenth) flattened_embs = [None] * (self._num_ps * self._num_table) recv_embeddings_split_index = 0 for ps_index in range(self._num_ps): for table_idx in range(self._num_table): flattened_embs[ table_idx * self._num_ps + ps_index] = recv_embeddings_split[recv_embeddings_split_index] recv_embeddings_split_index += 1 ''' layout_tensors = distribution_ops.fused_embedding_to_layout( [recv_embeddings], #flattened_embs, None, #self.fids_list_row_split, v3 not need fids_list_row_split fid_list_emb_row_lenth=fid_list_emb_row_lenth, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, variant_type=self._inner_data_type, feature_cfgs=self._feature_configs, ps_num=self._num_ps, version=4) #auxiliary_bundle[ # '__sharding_sparse_fids__flattened_embs'] = flattened_embs logging.info("fused_embedding_to_layout done!") return self.nest_layout(layout_tensors) def _apply_gradients_gpu( 
self, layout_grads_and_vars: List[Tuple[tf.Tensor, tf.Tensor]], global_step: tf.Tensor, req_time: Optional[tf.Tensor] = None, auxiliary_bundle: Dict[str, tf.Tensor] = None, grad_scale: tf.Tensor = None) -> PartitionedHashTable: with tf.name_scope("pht_apply_gradients_gpu"): auxiliary_bundle['__sharding_sparse_fids__global_step'] = global_step auxiliary_bundle["__sharding_sparse_fids__req_time"] = req_time layout_grad, layout = zip(*layout_grads_and_vars) assert auxiliary_bundle is not None fid_offset = auxiliary_bundle.pop('__sharding_sparse_fids__fid_offset') feature_offset = auxiliary_bundle.pop( '__sharding_sparse_fids__feature_offset') nfl_offset = auxiliary_bundle.pop('__sharding_sparse_fids__nfl_offset') batch_size = auxiliary_bundle.pop('__sharding_sparse_fids__batch_size') recv_embeddings = auxiliary_bundle.pop( "__sharding_sparse_fids__recv_embeddings") fid_list_emb_row_lenth = auxiliary_bundle.pop( '__sharding_sparse_fids__fid_list_emb_row_lenth') #flattened_embs = auxiliary_bundle.pop( # '__sharding_sparse_fids__flattened_embs') with tf.device("/device:GPU:0"): embeddings_grad = distribution_ops.fused_embedding_to_layout_grad( nfl_offset=nfl_offset, feature_offset=feature_offset, fid_offset=fid_offset, batch_size=batch_size, embeddings_list=[recv_embeddings], #flattened_embs, fid_list_row_split=None, #flattened_fids_row_split, v3 no need fid_list_emb_row_lenth=fid_list_emb_row_lenth, layout_tensors_grad=layout_grad, layout_tensors_grad_scale=grad_scale, variant_type=self._inner_data_type, feature_cfgs=self._feature_configs, ps_num=self._num_ps, version=4) ''' embeddings_grad_reorder = [None] * (self._num_ps * self._num_table) embeddings_grad_index = 0 for table_idx in range(self._num_table): for ps_index in range(self._num_ps): embeddings_grad_reorder[ table_idx * self._num_ps + ps_index] = embeddings_grad[embeddings_grad_index] embeddings_grad_index += 1 grad_flat = tf.concat(embeddings_grad_reorder, axis=0) ''' grad_flat = embeddings_grad[0] 
auxiliary_bundle['__sharding_sparse_fids__grad_flat'] = grad_flat deq_auxiliary_bundle, async_optimize_queue = enqueue_dicts_with_queue_return( auxiliary_bundle, capacity=int(self._queue_configs.get("enable_async_optimize", 0)), queue_name="queue_fusedGatherGrad_to_fusedEmbGradA2A") auxiliary_bundle.update(deq_auxiliary_bundle) grad_flat = auxiliary_bundle.pop("__sharding_sparse_fids__grad_flat") recv_emb_splits = auxiliary_bundle.pop( "__sharding_sparse_fids__recv_emb_splits") embedding_splits = auxiliary_bundle.pop( "__sharding_sparse_fids__embedding_splits") if enable_bps and enable_bps_bwd: import byteps.tensorflow as bps from byteps.tensorflow.compression import FP16Compressor as BPSFP16Compressor if enable_bps_bwd_gdr: with tf.device("/device:GPU:0"), tf.name_scope("bps_bwd_alltoall"): if enable_bps_bwd_gdr_g2g: logging.info('Enabled BPS for bwd embed alltoall GDR (G2G)') bwd_tensor_name = "bwd_alltoall_g2g" grad_flat_t = bps.alltoall(grad_flat, recv_emb_splits, recv_splits=embedding_splits, name=bwd_tensor_name) if enable_bps_bwd_cast == 16: # do cast on GPU if enable_bps_bwd_fake_cast: grad_flat_t = grad_flat_t * 0.0 grad_flat_t = tf.cast(grad_flat_t, tf.float32) with tf.device("/device:CPU:0"): grad_flat_t = tf.identity(grad_flat_t) else: logging.info('Enabled BPS for bwd embed alltoall GDR (G2C)') bwd_tensor_name = "bwd_alltoall_g2c" grad_flat_t = bps.alltoall_gpu2cpu(grad_flat, recv_emb_splits, recv_splits=embedding_splits, name=bwd_tensor_name) with tf.device("/device:CPU:0"): # grad_flat_t.device = ._op.device (bps.alltoall_gpu2cpu as GPU op) # However the tensor is on CPU, so we fixed the tensor placement info with an identity. 
grad_flat_t = tf.identity(grad_flat_t) if enable_bps_bwd_cast == 16: if enable_bps_bwd_fake_cast: grad_flat_t = grad_flat_t * 0.0 grad_flat_t = tf.cast(grad_flat_t, tf.float32) else: logging.info( 'Enabled BPS for bwd embed alltoall with cast optimization') grad_flat_t = bps.alltoall(grad_flat, recv_emb_splits, recv_splits=embedding_splits, name="bwd_alltoall") if enable_bps_bwd_cast == 16: if enable_bps_bwd_fake_cast: grad_flat_t = grad_flat_t * 0.0 grad_flat_t = tf.cast(grad_flat_t, tf.float32) sent_grad_split_size = embedding_splits elif enable_custom_optimized_hvd: if enable_hvd_bwd_g2g: logging.info('Enabled optimized hvd for bwd embed alltoall g2g') with tf.device("/device:GPU:0"): grad_flat_t, sent_grad_split_size = hvd.alltoall( grad_flat, recv_emb_splits, recv_splits=embedding_splits, with_size=True, compression=FP16Compressor) with tf.device("/device:CPU:0"): grad_flat_t = tf.identity(grad_flat_t) else: logging.info('Enabled optimized hvd for bwd embed alltoall') grad_flat_t, sent_grad_split_size = hvd.alltoall( grad_flat, recv_emb_splits, recv_splits=embedding_splits, with_size=True, compression=FP16Compressor) if FLAGS.enable_alltoall_metrics and (self._shard_num > 1): # There is some issue with tf.compat.v1.summary on the horovod alltoall input, # using output instead. They are almost equivalent. 
shard_sizes = auxiliary_bundle["__sharding_sparse_fids__shard_sizes"] shard_sizes = tf.slice(shard_sizes, [1], [self._shard_num - 1]) total_alltoall_id_size = tf.reduce_sum(shard_sizes) recv_emb_splits = tf.slice(tf.identity(recv_emb_splits), [1], [self._shard_num - 1]) total_alltoall_emb_size = tf.reduce_sum(tf.identity(recv_emb_splits)) tmp_result = tf.reshape(total_alltoall_id_size, [1]) tmp_result2 = tf.reshape(total_alltoall_emb_size, [1]) total_id_dist = hvd.allgather(tmp_result) total_emb_dist = hvd.allgather(tmp_result2) with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("Alltoall_id_dist", total_id_dist) tf.compat.v1.summary.histogram("Alltoall_emb_dist", total_emb_dist) if self._index == 0: min_idx = tf.math.argmin(shard_sizes) max_idx = tf.math.argmax(shard_sizes) min_idx_size = tf.reshape( tf.slice(shard_sizes, tf.reshape(min_idx, [-1]), [1]), []) max_idx_size = tf.reshape( tf.slice(shard_sizes, tf.reshape(max_idx, [-1]), [1]), []) with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("Alltoall_id_splits", shard_sizes) tf.compat.v1.summary.scalar("Alltoall_id_sizes", total_alltoall_id_size) tf.compat.v1.summary.scalar("Alltoall_id_min_idx", min_idx) tf.compat.v1.summary.scalar("Alltoall_id_max_idx", max_idx) tf.compat.v1.summary.scalar("Alltoall_id_min_size", min_idx_size) tf.compat.v1.summary.scalar("Alltoall_id_max_size", max_idx_size) sent_grad_split_size = tf.slice(sent_grad_split_size, [1], [self._shard_num - 1]) total_alltoall_grad_size = tf.reduce_sum(sent_grad_split_size) with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("Alltoall_grad_splits", sent_grad_split_size) tf.compat.v1.summary.scalar("Alltoall_grad_sizes", total_alltoall_grad_size) min_emb = tf.math.argmin(recv_emb_splits) max_emb = tf.math.argmax(recv_emb_splits) min_emb_size = tf.reshape( tf.slice(recv_emb_splits, tf.reshape(min_emb, [-1]), [1]), []) max_emb_size = tf.reshape( tf.slice(recv_emb_splits, tf.reshape(max_emb, [-1]), [1]), []) with tf.device("/CPU:0"): 
tf.compat.v1.summary.histogram("Alltoall_emb_splits", tf.identity(recv_emb_splits)) tf.compat.v1.summary.scalar("Alltoall_emb_sizes", tf.identity(total_alltoall_emb_size)) tf.compat.v1.summary.scalar("Alltoall_emb_min_idx", min_emb) tf.compat.v1.summary.scalar("Alltoall_emb_max_idx", max_emb) tf.compat.v1.summary.scalar("Alltoall_emb_min_size", min_emb_size) tf.compat.v1.summary.scalar("Alltoall_emb_max_size", max_emb_size) elif enable_hvd: if enable_hvd_bwd_g2g: logging.info('Enabled hvd for bwd embed alltoall g2g') with tf.device("/device:GPU:0"): grad_flat_t = hvd.alltoall(grad_flat, recv_emb_splits, name='hvd_bwd_a2a_g2g') #with tf.device("/device:CPU:0"): # grad_flat_t = tf.identity(grad_flat_t) else: logging.info('Enabled hvd for bwd embed alltoall') grad_flat_t = hvd.alltoall(grad_flat, recv_emb_splits) else: grad_flat_t = grad_flat auxiliary_bundle["__sharding_sparse_fids__grad_flat_t"] = grad_flat_t auxiliary_bundle.pop("__sharding_sparse_fids__shard_sizes") deq_auxiliary_bundle, q = enqueue_dicts_with_queue_return( auxiliary_bundle, capacity=int(self._queue_configs.get("enable_pipelined_bwda2a", 0)), queue_name="queue_fusedEmbGradA2A_to_sparseOptimize") if q: self.add_queue_hook(EnqueueHook(q)) auxiliary_bundle.update(deq_auxiliary_bundle) with tf.device("/GPU:0"): updated_table = self._table.fused_apply_gradient( auxiliary_bundle.pop("__sharding_sparse_fids__id_flat_t"), auxiliary_bundle.pop("__sharding_sparse_fids__indices"), auxiliary_bundle.pop("__sharding_sparse_fids__id_size_flat_t"), auxiliary_bundle.pop("__sharding_sparse_fids__grad_flat_t"), auxiliary_bundle.pop("__sharding_sparse_fids__id_offsets"), auxiliary_bundle.pop("__sharding_sparse_fids__emb_offsets"), auxiliary_bundle.pop("__sharding_sparse_fids__global_step"), auxiliary_bundle.pop("__sharding_sparse_fids__req_time"), self._shard_num) update_op = self._copy_with_new_table_gpu(updated_table) # TODO(zouxuan): add better tests to test the async optimize. 
if async_optimize_queue: self.add_queue_hook( AsyncPushHook(async_optimize_queue, update_op.as_op())) self._dependency_ops.append(async_optimize_queue.enqueue_op) # return self essentially means to call dependency_ops return self.as_op() else: return update_op.as_op() def _copy_with_new_table(self, new_tables: List[tf.Tensor]) -> PartitionedHashTable: copied = copy.copy(self) copied._tables = new_tables return copied def _copy_with_new_table_gpu(self, new_table: tf.Tensor) -> PartitionedHashTable: copied = copy.copy(self) copied._dependency_ops = copy.copy(self._dependency_ops) copied._table = new_table return copied def _native_hash_table_update( self, method_name: str, name_scope: str, update_data: AssignData) -> PartitionedHashTable: with tf.name_scope(name_scope): sharded_slot_to_id_and_value: Dict[int, Dict[str, Tuple[ tf.Tensor, tf.Tensor]]] = collections.defaultdict(dict) for slot, (id, value) in update_data.items(): index = tf.math.floormod(id, self._num_ps) split_ids = distribution_ops.split_by_indices(index, id, self._num_ps) split_values = distribution_ops.split_by_indices( index, value, self._num_ps) for i in range(self._num_ps): sharded_slot_to_id_and_value[i][slot] = (split_ids[i], split_values[i]) new_tables = [] for i in range(self._num_ps): new_tables.append( getattr(self._tables[i], method_name)(sharded_slot_to_id_and_value[i])) return self._copy_with_new_table(new_tables) def _update(self, method_name: str, name_scope: str, update_data: AssignData) -> PartitionedHashTable: if self._enable_gpu_emb: raise NotImplementedError with tf.name_scope(name_scope): new_tables = [] for i in range(self._num_ps): new_tables.append(getattr(self._tables[i], method_name)(update_data[i])) return self._copy_with_new_table(new_tables) def assign(self, data: AssignData) -> PartitionedHashTable: if self._enable_gpu_emb: raise NotImplementedError if self._use_native_multi_hash_table: return self._update("assign", "dmtht_a", data) else: return self._update("assign", 
"pht_assign", data) def assign_add(self, data: AssignData) -> PartitionedHashTable: if self._enable_gpu_emb: raise NotImplementedError if self._use_native_multi_hash_table: return self._update("assign_add", "dmtht_aa", data) else: return self._update("assign_add", "pht_assign_add", data) def flatten_layout( self, nested: Dict[str, Union[tf.Tensor, List[tf.Tensor]]]) -> List[tf.Tensor]: result = [] for name in sorted(self._feature_configs.out_configs): value = nested[name] if isinstance(value, (list, tuple)): assert all(isinstance(v, tf.Tensor) for v in value) result.extend(value) else: assert isinstance(value, tf.Tensor) result.append(value) return result def nest_layout( self, tensors: List[tf.Tensor]) -> Dict[str, Union[tf.Tensor, List[tf.Tensor]]]: offset, result = 0, {} for name in sorted(self._feature_configs.out_configs): conf = self._feature_configs.out_configs[name] if conf.out_type == OutType.NONE: sub_list = [] for _ in range(len(conf.slice_configs)): sub_list.append(tensors[offset]) offset += 1 result[name] = sub_list else: result[name] = tensors[offset] offset += 1 return result def add_queue_hook(self, hook): # Allow pipelined graph execution. if not getattr(self, "_local_queue_hooks", None): self._local_queue_hooks = [] self._local_queue_hooks.append(hook) def get_queue_hooks(self): hooks = copy.copy(getattr(self, "_local_queue_hooks", [])) if getattr(self, "_tables", None): hooks.extend( itertools.chain.from_iterable( [t.get_queue_hooks() for t in self._tables])) return hooks ================================================ FILE: monolith/native_training/distributed_ps_benchmark.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import time import os import shutil from absl import flags import tensorflow as tf from tensorflow.core.protobuf import cluster_pb2, config_pb2 from monolith.native_training import (distributed_ps, hash_filter_ops, \ hash_table_ops, utils) from monolith.native_training.runtime.hash_table import \ embedding_hash_table_pb2 PROFILE = False def _generate_config(servers, job_name=utils.PS_JOB_NAME): """Generates a config based on servers""" cluster_def = cluster_pb2.ClusterDef() job = cluster_def.job.add() job.name = job_name for i, server in enumerate(servers): job.tasks[i] = server.target[len('grpc://'):] return config_pb2.ConfigProto(cluster_def=cluster_def) def _get_vocab_hash_table_factory(dim: int): def factory(name_suffix: str, hash_filter: tf.Tensor): config = embedding_hash_table_pb2.EmbeddingHashTableConfig() config.cuckoo.SetInParent() segment = config.entry_config.segments.add() segment.dim_size = dim segment.opt_config.sgd.learning_rate = 1.0 segment.init_config.zeros.SetInParent() return hash_table_ops.hash_table_from_config(config=config, hash_filter=hash_filter, name_suffix=name_suffix) return factory class DistributedHashTableTest(tf.test.TestCase): def lookup(self, enable_dedup, real_run=True): ps_num = 10 servers = [ tf.distribute.Server.create_local_server() for _ in range(ps_num) ] server0 = servers[0] num_elements, dim = 1000000, 16 config = _generate_config(servers) if PROFILE and real_run: log_dir = "/tmp/distributed_ps_benchmark/lookup{}".format( "_dedup" if enable_dedup else "") if os.path.exists(log_dir): shutil.rmtree(log_dir) 
tf.profiler.experimental.start(log_dir) with tf.compat.v1.Session(server0.target, config=config) as sess: hash_filters = hash_filter_ops.create_hash_filters(ps_num, False) hash_table = distributed_ps.DistributedHashTable( ps_num, hash_filters, _get_vocab_hash_table_factory(dim)) hash_table = hash_table.assign_add( tf.constant([x for x in range(num_elements)], dtype=tf.int64), tf.constant([[x for _ in range(dim)] for x in range(num_elements)], dtype=tf.float32)) start = time.time() if real_run: values = hash_table.lookup(tf.constant( [x // 2 for x in range(num_elements)], dtype=tf.int64), use_multi_threads=True, enable_dedup=enable_dedup) values = sess.run(values) print("wall time(MT) enable_dedup={}: cost {}".format( str(enable_dedup), time.time() - start)) self.assertAllEqual( values, [[x // 2 for _ in range(dim)] for x in range(num_elements)]) else: sess.run(hash_table.as_op()) print("wall time(overhead): cost {}".format(time.time() - start)) if PROFILE and real_run: tf.profiler.experimental.stop() def apply_gradients(self, real_run=True): ps_num = 10 servers = [ tf.distribute.Server.create_local_server() for _ in range(ps_num) ] server0 = servers[0] num_elements, dim = 1000000, 16 config = _generate_config(servers) if PROFILE and real_run: log_dir = "/tmp/distributed_ps_benchmark/apply_gradients" if os.path.exists(log_dir): shutil.rmtree(log_dir) tf.profiler.experimental.start(log_dir) with tf.compat.v1.Session(server0.target, config=config) as sess: hash_filters = hash_filter_ops.create_hash_filters(ps_num, False) hash_table = distributed_ps.DistributedHashTable( ps_num, hash_filters, _get_vocab_hash_table_factory(dim)) hash_table = hash_table.assign_add( tf.constant([x for x in range(num_elements)], dtype=tf.int64), tf.constant([[1 for _ in range(dim)] for _ in range(num_elements)], dtype=tf.float32)) embeddings = hash_table.lookup(tf.constant( [x // 2 for x in range(num_elements)], dtype=tf.int64), use_multi_threads=True, enable_dedup=True) loss = 
tf.multiply(0.3, embeddings) grads = tf.gradients(loss, embeddings) start = time.time() if real_run: hash_table = hash_table.apply_gradients( tf.constant([x // 2 for x in range(num_elements)], dtype=tf.int64), grads[0]) if PROFILE: sess.run(hash_table.as_op()) else: values = hash_table.lookup(tf.constant( [x // 2 for x in range(num_elements)], dtype=tf.int64), use_multi_threads=True, enable_dedup=True) values = sess.run(values) self.assertAllClose( values, [[0.4 for _ in range(dim)] for x in range(num_elements)]) print("wall time(MT): cost {}".format(time.time() - start)) else: grads = sess.run(grads[0]) if not PROFILE: self.assertAllClose( grads, [[0.3 for _ in range(dim)] for x in range(num_elements)]) print("wall time(overhead): cost {}".format(time.time() - start)) if PROFILE and real_run: tf.profiler.experimental.stop() def test_lookup_overhead(self): self.lookup(False, False) def test_lookup(self): self.lookup(False) def test_lookup_dedup(self): self.lookup(True) def test_apply_gradients_overhead(self): self.apply_gradients(False) def test_apply_gradients(self): self.apply_gradients(True) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/distributed_ps_factory.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Several methods to create hash tables.""" from typing import Dict, Iterable, List, Tuple import tensorflow as tf from idl.matrix.proto.example_pb2 import OutConfig from monolith.native_training import distributed_ps from monolith.native_training import distributed_ps_sync from monolith.native_training import entry from monolith.native_training import hash_filter_ops from monolith.native_training import hash_table_ops from monolith.native_training import multi_type_hash_table from monolith.native_training import multi_hash_table_ops import monolith.native_training.embedding_combiners as embedding_combiners class MultiHashTableFactory: def __init__(self, hash_filters, sync_clients): self._cc_dict = {} self.hash_filters = hash_filters self.sync_clients = sync_clients def __call__(self, idx: int, slot_to_config): k = id(slot_to_config) cc = self._cc_dict.get(k, None) if cc is None: cc = multi_hash_table_ops.convert_to_cached_config(slot_to_config) self._cc_dict[k] = cc return multi_hash_table_ops.MultiHashTable.from_cached_config( cc=cc, hash_filter=self.hash_filters[idx], sync_client=self.sync_clients[idx], name_suffix=str(idx)) def create_in_worker_multi_type_hash_table( shard_num: int, slot_to_config: Dict[str, entry.HashTableConfigInstance], hash_filter: tf.Tensor, sync_client: tf.Tensor = None, queue_configs: Dict[str, int] = None, ): """ Creates a in worker multi-type hash table factory. Args: shard_num: the number of shards for distributing hash tables. 
""" # The logic here is # merged_slots -> distributed_fused_multitype_table -> alltoall -> hash_table def distributed_multi_type_table_factory(merged_slot_to_config): def multi_type_table_factory(idx): def table_factory(name_suffix, config): return hash_table_ops.hash_table_from_config( config=config, hash_filter=hash_filter, name_suffix="_".join([name_suffix, str(idx)]), sync_client=sync_client) return multi_type_hash_table.MultiTypeHashTable(merged_slot_to_config, table_factory) return distributed_ps_sync.DistributedMultiTypeHashTableMpi( shard_num, multi_type_table_factory, queue_configs) return multi_type_hash_table.MergedMultiTypeHashTable( slot_to_config, distributed_multi_type_table_factory) def create_multi_type_hash_table( num_ps: int, slot_to_config: Dict[str, entry.HashTableConfigInstance], hash_filters: List[tf.Tensor], sync_clients: List[tf.Tensor] = None, reduce_network_packets: bool = False, max_rpc_deadline_millis: int = 30, ): """Create a distributed multi type hash table. Args: reduce_network_packets - if True, it will compact all tensors locally so ps will get less load. Useful when there are a lot of workers. 
""" if num_ps and sync_clients: assert num_ps == len( sync_clients ), "Number of PS should be equal to number of sync clients, while got {} vs {}".format( num_ps, len(sync_clients)) if not sync_clients: sync_clients = [None] * max(num_ps, 1) if num_ps == 0: def factory(name_suffix, config): return hash_table_ops.hash_table_from_config(config, hash_filter=hash_filters[0], name_suffix=name_suffix, sync_client=sync_clients[0]) def multi_type_factory(merged_slot_to_config): return multi_type_hash_table.MultiTypeHashTable(merged_slot_to_config, factory) return multi_type_hash_table.MergedMultiTypeHashTable( slot_to_config, multi_type_factory) elif not reduce_network_packets: # The logic here is # dedup_slots -> multi hash table -> distributed_hash_table -> hash_table # | worker | ps | def multi_type_factory(merged_slot_to_config): def distributed_factory(name_suffix, config): def factory(idx, config_on_ps): return hash_table_ops.hash_table_from_config( config_on_ps, hash_filter=hash_filters[idx], name_suffix="_".join([name_suffix, str(idx)]), sync_client=sync_clients[idx]) return distributed_ps.DistributedHashTable(num_ps, config, factory) return multi_type_hash_table.MultiTypeHashTable(merged_slot_to_config, distributed_factory) return multi_type_hash_table.MergedMultiTypeHashTable( slot_to_config, multi_type_factory) else: # The logic here is # dedup_slots -> distributed multi hash table -> multi hash table -> hash table # | worker | ps | def distributed_multi_type_factory(merged_slot_to_config): def multi_type_factory(idx: int, slot_to_config_on_ps): def factory(name_suffix, config): return hash_table_ops.hash_table_from_config( config, hash_filter=hash_filters[idx], name_suffix="_".join([name_suffix, str(idx)]), sync_client=sync_clients[idx]) return multi_type_hash_table.MultiTypeHashTable(slot_to_config_on_ps, factory) return distributed_ps.DistributedMultiTypeHashTable( num_ps, merged_slot_to_config, multi_type_factory, 
max_rpc_deadline_millis=max_rpc_deadline_millis) return multi_type_hash_table.MergedMultiTypeHashTable( slot_to_config, distributed_multi_type_factory) def create_native_multi_hash_table( num_ps: int, slot_to_config: Dict[str, entry.HashTableConfigInstance], hash_filters: List[tf.Tensor], sync_clients: List[tf.Tensor] = None, max_rpc_deadline_millis: int = 30, ): """Create a distributed native multi hash table.""" if num_ps and sync_clients: assert num_ps == len( sync_clients ), "Number of PS should be equal to number of sync clients, while got {} vs {}".format( num_ps, len(sync_clients)) if not sync_clients: sync_clients = [None] * max(num_ps, 1) if num_ps == 0: return multi_hash_table_ops.MultiHashTable.from_configs( configs=slot_to_config, hash_filter=hash_filters[0], sync_client=sync_clients[0]) else: # The logic here is # slots -> distributed multi hash table -> multi hash table # | worker | ps | return distributed_ps.DistributedMultiTypeHashTable( num_ps, slot_to_config, MultiHashTableFactory(hash_filters, sync_clients), max_rpc_deadline_millis=max_rpc_deadline_millis) def create_in_worker_native_multi_hash_table( shard_num: int, slot_to_config: Dict[str, entry.HashTableConfigInstance], hash_filter: tf.Tensor, sync_client: tf.Tensor = None, queue_configs: Dict[str, int] = None, ): # The logic here is # DistributedMultiTypeHashTableMpi -> alltoall -> multi_hash_table def table_factory(idx): return multi_hash_table_ops.MultiHashTable.from_configs( configs=slot_to_config, hash_filter=hash_filter, sync_client=sync_client, name_suffix=str(idx)) return distributed_ps_sync.DistributedMultiTypeHashTableMpi( shard_num, table_factory, queue_configs) def create_partitioned_hash_table( num_ps: int, use_native_multi_hash_table: bool, max_rpc_deadline_millis: int = 30, hash_filters: List[tf.Tensor] = None, sync_clients: List[tf.Tensor] = None, enable_gpu_emb: bool = False, queue_configs: Dict[str, int] = None, ) -> distributed_ps.PartitionedHashTable: num_ps_tmp = num_ps 
if num_ps > 0 else 1 if hash_filters is None: hash_filters = [None] * num_ps_tmp if sync_clients is None: sync_clients = [None] * num_ps_tmp if use_native_multi_hash_table: # assert enable_gpu_emb == False, "gpu_emb not imple native_multi_hash_table" multi_type_factory = MultiHashTableFactory(hash_filters, sync_clients) else: def multi_type_factory(idx: int, slot_to_config_on_ps): def factory(name_suffix, config): name_suffix = name_suffix if num_ps == 0 else "_".join( [name_suffix, str(idx)]) return hash_table_ops.hash_table_from_config( config, hash_filter=hash_filters[idx], name_suffix=name_suffix, sync_client=sync_clients[idx]) return multi_type_hash_table.MultiTypeHashTable(slot_to_config_on_ps, factory) return distributed_ps.PartitionedHashTable(num_ps, multi_type_factory, use_native_multi_hash_table, max_rpc_deadline_millis, queue_configs=queue_configs) ================================================ FILE: monolith/native_training/distributed_ps_factory_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os

# Must be set before the monolith imports below so they see the variable at
# import time.
os.environ["MONOLITH_WITH_HOROVOD"] = "True"

import tensorflow as tf

from monolith.native_training import distributed_ps_factory
from monolith.native_training import hash_filter_ops
from monolith.native_training import test_utils

import horovod.tensorflow as hvd


def _get_test_slot_to_config():
  """Returns a two-slot config dict ("1" and "2") sharing one test config."""
  config = test_utils.generate_test_hash_table_config(4, learning_rate=0.1)
  return {
      "1": config,
      "2": config,
  }


def _get_test_hash_filters(num):
  """Returns hash_filter_ops.create_hash_filters(num, False)."""
  return hash_filter_ops.create_hash_filters(num, False)


class FactoryTest(tf.test.TestCase):
  # The factory itself is very difficult to test, so these tests only check
  # that each factory function can be called without raising (smoke tests);
  # none of them asserts on the returned table.

  def test_create_in_worker_multi_type_hash_table(self):
    """Smoke test: in-worker multi-type table with first argument 1."""
    hvd.init()
    distributed_ps_factory.create_in_worker_multi_type_hash_table(
        1, _get_test_slot_to_config(),
        _get_test_hash_filters(0)[0])

  def test_create_in_worker_multi_type_hash_table_with_reduced_alltoall(self):
    """Smoke test for the in-worker multi-type table.

    NOTE(review): the arguments are identical to
    test_create_in_worker_multi_type_hash_table; no reduced-alltoall option is
    passed here -- confirm whether that is intended.
    """
    hvd.init()
    distributed_ps_factory.create_in_worker_multi_type_hash_table(
        1, _get_test_slot_to_config(),
        _get_test_hash_filters(0)[0])

  def test_create_multi_type_hash_table_0_ps(self):
    """Smoke test: create_multi_type_hash_table with 0 as first argument."""
    distributed_ps_factory.create_multi_type_hash_table(
        0, _get_test_slot_to_config(), _get_test_hash_filters(0))

  def test_create_multi_type_hash_table_2_ps(self):
    """Smoke test: create_multi_type_hash_table inside a 2-PS test cluster."""
    servers, config = test_utils.create_test_ps_cluster(2)
    with tf.compat.v1.Session(servers[0].target, config=config):
      distributed_ps_factory.create_multi_type_hash_table(
          2, _get_test_slot_to_config(), _get_test_hash_filters(2))

  def test_create_multi_type_hash_table_2_ps_with_reduced_packets(self):
    """Same as the 2-PS smoke test, with reduce_network_packets=True."""
    servers, config = test_utils.create_test_ps_cluster(2)
    with tf.compat.v1.Session(servers[0].target, config=config):
      distributed_ps_factory.create_multi_type_hash_table(
          2,
          _get_test_slot_to_config(),
          _get_test_hash_filters(2),
          reduce_network_packets=True)

  def test_create_native_multi_hash_table_0_ps(self):
    """Smoke test: create_native_multi_hash_table with num_ps 0."""
    distributed_ps_factory.create_native_multi_hash_table(
        0, _get_test_slot_to_config(), _get_test_hash_filters(0))

  def test_create_native_multi_hash_table_2_ps(self):
    """Smoke test: create_native_multi_hash_table inside a 2-PS test cluster."""
    servers, config = test_utils.create_test_ps_cluster(2)
    with tf.compat.v1.Session(servers[0].target, config=config):
      distributed_ps_factory.create_native_multi_hash_table(
          2, _get_test_slot_to_config(), _get_test_hash_filters(2))


if __name__ == "__main__":
  tf.compat.v1.disable_v2_behavior()
  tf.test.main()


================================================
FILE: monolith/native_training/distributed_ps_sync.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import collections
import os
from typing import Callable, Dict, Tuple, List, Union

from absl import flags, logging
import tensorflow as tf

from monolith.native_training import multi_type_hash_table
from monolith.native_training import multi_hash_table_ops
from monolith.native_training import distribution_ops
from monolith.native_training.prefetch_queue import \
    enqueue_dicts_with_queue_return, AsyncPushHook, EnqueueHook
from monolith.native_training import feature_utils

# Communication backend switches, all read once from the environment at
# import time.
#
# enable_hvd / enable_custom_optimized_hvd hold the raw environment strings
# (None when unset), so any value -- including an empty string for the
# `!= None` check below -- counts as "set".
enable_hvd = os.getenv("MONOLITH_WITH_HOROVOD")
enable_custom_optimized_hvd = os.getenv("MONOLITH_WITH_OPTIMIZED_HOROVOD")
# Horovod is imported only when MONOLITH_WITH_HOROVOD is present.
# NOTE(review): code guarded by enable_custom_optimized_hvd also uses `hvd`
# and `FP16Compressor`; presumably both variables are expected to be set
# together -- confirm.
if enable_hvd != None:
  import horovod.tensorflow as hvd
  from horovod.tensorflow.compression import FP16Compressor

# Per-stage Horovod "g2g" switches (fid exchange / forward / backward),
# parsed as ints and defaulting to 1.
enable_hvd_fid_g2g = int(os.getenv("MONOLITH_WITH_HOROVOD_FID_G2G", 1))
enable_hvd_fwd_g2g = int(os.getenv("MONOLITH_WITH_HOROVOD_FWD_G2G", 1))
enable_hvd_bwd_g2g = int(os.getenv("MONOLITH_WITH_HOROVOD_BWD_G2G", 1))

# BytePS master switch (default 0) and per-stage switches (default 1),
# parsed as ints.
enable_bps = int(os.getenv("MONOLITH_WITH_BYTEPS", "0"))
enable_bps_fid = int(os.getenv("MONOLITH_WITH_BYTEPS_FID", "1"))
enable_bps_fwd = int(os.getenv("MONOLITH_WITH_BYTEPS_FWD", "1"))
enable_bps_bwd = int(os.getenv("MONOLITH_WITH_BYTEPS_BWD", "1"))
# MONOLITH_WITH_BYTEPS_BWD_CAST
# 32: fp32 for embed grad (default)
# 16: fp16 for embed grad
enable_bps_bwd_cast = int(os.getenv("MONOLITH_WITH_BYTEPS_BWD_CAST", "32"))
# When non-zero (and the cast above is 16), the received gradient is
# multiplied by 0.0 before being cast back to fp32.
enable_bps_bwd_fake_cast = int(
    os.getenv("MONOLITH_WITH_BYTEPS_BWD_FAKE_CAST", "0"))
# enable forward alltoall with GDR
enable_bps_fwd_gdr = int(os.getenv("MONOLITH_WITH_BYTEPS_FWD_GDR", "0"))
enable_bps_fwd_gdr_g2g = int(os.getenv("MONOLITH_WITH_BYTEPS_FWD_GDR_G2G", "0"))
# enable backward alltoall with GDR
enable_bps_bwd_gdr = int(os.getenv("MONOLITH_WITH_BYTEPS_BWD_GDR", "0"))
enable_bps_bwd_gdr_g2g = int(os.getenv("MONOLITH_WITH_BYTEPS_BWD_GDR_G2G", "0"))

FLAGS = flags.FLAGS

flags.DEFINE_bool("enable_alltoall_metrics",
                  default=False,
                  help=("Whether to turn on alltoall detailed stats."))
flags.DEFINE_string( "enable_alltoall_metrics_for_slot", default=None, help="ID of the merged slot to summary alltoall stats. For example:" "(af17bbdba2be72580bf5c8c43975078c for merged slot of fc_clk_ads_4d)") class DistributedMultiTypeHashTableMpi( multi_type_hash_table.BaseMultiTypeHashTable): def __init__( self, shard_num: int, table_factory: Callable[[int], Union[ # when use_native_multi_hash_table=False multi_type_hash_table.MultiTypeHashTable, # when use_native_multi_hash_table=True multi_hash_table_ops.MultiHashTable]], queue_configs: Dict[str, int] = None): self._shard_num = shard_num if enable_bps: import byteps.tensorflow as bps assert bps.size() == self._shard_num self._index = bps.rank() else: assert hvd.size() == self._shard_num self._index = hvd.rank() self._table = table_factory(self._index) self._output_dims = self._table.get_table_dim_sizes() self._queue_configs = queue_configs or {} self._dependency_ops = [] def lookup(self, slot_to_id: Dict[str, tf.Tensor], auxiliary_bundle: Dict[str, tf.Tensor | tf.RaggedTensor], early_reorder_indicies_res_pack=None) -> Dict[str, tf.Tensor]: if enable_bps: import byteps.tensorflow as bps sorted_slot_keys = sorted(slot_to_id.keys()) slot_num = len(sorted_slot_keys) assert early_reorder_indicies_res_pack is not None, \ "Support for reorder_fids_in_data_pipeline=False is dropped. 
Please set it to True" all_fids, shard_sizes, sharded_slot_sizes, emb_offset_sz, fused_embedding_offsets, req_time = \ early_reorder_indicies_res_pack if FLAGS.enable_alltoall_metrics: slot_name = FLAGS.enable_alltoall_metrics_for_slot if slot_name and slot_name in sorted_slot_keys: m = sorted_slot_keys.index(slot_name) with tf.device("/CPU:0"): tf.compat.v1.summary.scalar( "{}_size".format(slot_name), tf.reduce_sum( tf.gather( sharded_slot_sizes, [m + i * slot_num for i in range(self._shard_num)]))) with tf.device("/CPU:0"): tf.compat.v1.summary.scalar("all_fids_size", tf.size(all_fids)) tf.compat.v1.summary.histogram("shard_sizes", shard_sizes) tf.compat.v1.summary.histogram("sharded_slot_sizes", sharded_slot_sizes) # We exchange the flattened IDs and their splits. # M: num_of_ids, # N: num_of_shards, # K: num_of_merged_tables, # E: num_of_total_embedding_dim. # id_flat_t: [M], id_flat_split_t: [N] # id_size_flat_t: [K*N], id_size_flat_split_t: [N] if enable_bps and enable_bps_fid: logging.info('Enabled BPS for fid alltoall') id_flat_t, id_flat_split_t = bps.alltoall(all_fids, splits=shard_sizes, with_size=True, name='fid_data') # We also add the flat_t sizes. id_size_flat_t = bps.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num, recv_splits=([slot_num] * self._shard_num), name='fid_size') elif enable_custom_optimized_hvd: id_flat_t, id_flat_split_t = hvd.alltoall(all_fids, splits=shard_sizes, with_size=True) # We also add the flat_t sizes. 
id_size_flat_t = hvd.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num, recv_splits=[slot_num] * self._shard_num) elif enable_hvd: if enable_hvd_fid_g2g: logging.info('Enabled hvd for fid alltoall g2g') with tf.device("/device:GPU:0"): id_flat_t = hvd.alltoall(all_fids, splits=shard_sizes) id_size_flat_t = hvd.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num) else: id_flat_t = hvd.alltoall(all_fids, splits=shard_sizes) id_size_flat_t = hvd.alltoall(sharded_slot_sizes, splits=[slot_num] * self._shard_num) auxiliary_bundle["shard_sizes"] = shard_sizes with tf.device("/device:GPU:0"): auxiliary_bundle["fused_embedding_offsets"] = tf.split( fused_embedding_offsets, emb_offset_sz, axis=0, name="concat_emb_offsets_split") auxiliary_bundle["emb_offset_sz"] = emb_offset_sz auxiliary_bundle["id_flat_t"] = id_flat_t # Note: id_flat_split_t is not being used in later computation. auxiliary_bundle["id_size_flat_t"] = id_size_flat_t # fused_embeddings: [E], fused_splits: [N] # id_offsets: [K*N], emb_offsets: [K*N] with tf.device("/GPU:0"): fused_embeddings, embedding_splits, id_offsets, emb_offsets, indices = \ self._table.fused_lookup(id_flat_t, id_size_flat_t, self._shard_num, req_time) if FLAGS.enable_alltoall_metrics: with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("fused_embedding_splits", embedding_splits) auxiliary_bundle["fused_embeddings"] = fused_embeddings auxiliary_bundle["embedding_splits"] = embedding_splits auxiliary_bundle["id_offsets"] = id_offsets auxiliary_bundle["emb_offsets"] = emb_offsets auxiliary_bundle["indices"] = indices auxiliary_bundle["recv_emb_splits"] = tf.reshape( tf.matmul( tf.reshape(sharded_slot_sizes, [self._shard_num, slot_num]), tf.expand_dims(tf.constant(self._output_dims, dtype=tf.int32), -1) # [slot_num, 1] ), [-1] # flatten ) auxiliary_bundle["recv_embeddings_size"] = tf.reduce_sum( auxiliary_bundle["recv_emb_splits"]) auxiliary_bundle, queue = enqueue_dicts_with_queue_return( auxiliary_bundle, 
capacity=int(self._queue_configs.get("enable_pipelined_fwda2a", 0)), queue_name="queue_lookup_to_fusedEmbA2A") if queue: self.add_queue_hook(EnqueueHook(queue)) fused_embeddings = auxiliary_bundle.pop("fused_embeddings") embedding_splits = auxiliary_bundle["embedding_splits"] recv_emb_splits = auxiliary_bundle["recv_emb_splits"] # recv_embeddings: [E'], recv_embedding_sizes: [N] if enable_bps and enable_bps_fwd: if enable_bps_fwd_gdr: if enable_bps_fwd_gdr_g2g: logging.info('Enabled BPS for fwd embed alltoall GDR (G2G)') with tf.device("/device:GPU:0"): fused_embeddings_gpu = fused_embeddings with tf.device("/device:GPU:0"): recv_embeddings = bps.alltoall(fused_embeddings_gpu, embedding_splits, recv_splits=recv_emb_splits, name="fwd_alltoall_g2g") else: logging.info('Enabled BPS for fwd embed alltoall GDR (C2G)') with tf.device("/device:GPU:0"): recv_embeddings = bps.alltoall_cpu2gpu(fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, name="fwd_alltoall_c2g") else: logging.info('Enabled BPS for fwd embed alltoall') recv_embeddings = bps.alltoall(fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, name="fwd_alltoall") elif enable_custom_optimized_hvd: if enable_hvd_fwd_g2g: logging.info('Enabled optimized hvd for fwd embed alltoall g2g') with tf.device("/device:GPU:0"): recv_embeddings = hvd.alltoall( fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, ) else: logging.info('Enabled optimized hvd for fwd embed alltoall') recv_embeddings = hvd.alltoall( fused_embeddings, embedding_splits, recv_splits=recv_emb_splits, ) elif enable_hvd: if enable_hvd_fwd_g2g: logging.info('Enabled hvd for fwd embed alltoall g2g') with tf.device("/device:GPU:0"): recv_embeddings = hvd.alltoall(fused_embeddings, embedding_splits, name='hvd_fwd_a2a_g2g') else: logging.info('Enabled hvd for fwd embed alltoall') recv_embeddings = hvd.alltoall(fused_embeddings, embedding_splits, name='hvd_fwd_a2a') auxiliary_bundle["recv_embeddings"] = recv_embeddings 
with tf.device("/device:GPU:0"): # GPUQueue: Pass to GPU at Enqueue auxiliary_bundle["recv_embeddings"] = tf.identity( auxiliary_bundle["recv_embeddings"]) auxiliary_bundle, queue = enqueue_dicts_with_queue_return( auxiliary_bundle, capacity=int(self._queue_configs.get("embedding_prefetch_capacity", 0)), queue_name="queue_fusedEmbA2A_to_fusedGather") if queue: self.add_queue_hook(EnqueueHook(queue)) # auxiliary_bundle includes all dequeued tensors, if a prefetch queue in-between. recv_embeddings = auxiliary_bundle.pop("recv_embeddings") fused_embedding_offsets = auxiliary_bundle["fused_embedding_offsets"] with tf.device("/device:GPU:0"): outputs = distribution_ops.fused_gather_embeddings_by_input( recv_embeddings, fused_embedding_offsets, self._output_dims) # a.k.a merged_slot_to_embedding slot_to_embedding = {k: outputs[i] for i, k in enumerate(sorted_slot_keys)} return slot_to_embedding, auxiliary_bundle # TODO(zouxuan): assign is broken. def assign( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> DistributedMultiTypeHashTableMpi: raise NotImplementedError # TODO(zouxuan): assign_add is broken. def assign_add( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> DistributedMultiTypeHashTableMpi: raise NotImplementedError def reinitialize( self, slot: str, ids: tf.Tensor) -> Tuple[DistributedMultiTypeHashTableMpi, tf.Tensor]: raise NotImplementedError( "DistributedMultiTypeHashTableMpi dost not support reinitialize!") # Apply_gradients uses fused update. 
def apply_gradients(self, slot_to_grad: Dict[str, tf.Tensor], auxiliary_bundle: Dict[str, tf.Tensor], global_step: tf.Tensor, req_time: tf.Tensor = None, scale: tf.Tensor = 1) -> DistributedMultiTypeHashTableMpi: auxiliary_bundle['global_step'] = global_step if req_time is None: req_time = tf.constant(0, dtype=tf.int64) auxiliary_bundle["req_time"] = req_time sorted_slot_keys = sorted(slot_to_grad.keys()) sorted_grads = [slot_to_grad[k] for k in sorted_slot_keys] recv_embeddings_size = auxiliary_bundle.pop("recv_embeddings_size") fused_embedding_offsets = auxiliary_bundle.pop("fused_embedding_offsets") # make this depend on fusion op before allreduce, # so allreduce can be dispatched before alltoall with tf.control_dependencies(feature_utils.control_ops): with tf.device("/device:GPU:0"): grad_flat = distribution_ops.fused_gather_embeddings_by_input_gradient( recv_embeddings_size, sorted_grads, fused_embedding_offsets, self._output_dims, scale) with tf.device("/device:GPU:0"): if enable_bps_bwd_cast == 16: auxiliary_bundle['grad_flat'] = tf.cast(grad_flat, tf.float16) else: auxiliary_bundle['grad_flat'] = tf.identity(grad_flat) # Here we add a queue to let the optimize stage non-blocking and # interleaving at the next round of update. 
auxiliary_bundle, async_optimize_queue = enqueue_dicts_with_queue_return( auxiliary_bundle, capacity=int(self._queue_configs.get("enable_async_optimize", 0)), queue_name="queue_fusedGatherGrad_to_fusedEmbGradA2A") grad_flat = auxiliary_bundle.pop("grad_flat") recv_emb_splits = auxiliary_bundle.pop("recv_emb_splits") embedding_splits = auxiliary_bundle.pop("embedding_splits") if enable_bps and enable_bps_bwd: import byteps.tensorflow as bps from byteps.tensorflow.compression import FP16Compressor as BPSFP16Compressor if enable_bps_bwd_gdr: with tf.device("/device:GPU:0"), tf.name_scope("bps_bwd_alltoall"): if enable_bps_bwd_gdr_g2g: logging.info('Enabled BPS for bwd embed alltoall GDR (G2G)') bwd_tensor_name = "bwd_alltoall_g2g" grad_flat_t = bps.alltoall(grad_flat, recv_emb_splits, recv_splits=embedding_splits, name=bwd_tensor_name) if enable_bps_bwd_cast == 16: # do cast on GPU if enable_bps_bwd_fake_cast: grad_flat_t = grad_flat_t * 0.0 grad_flat_t = tf.cast(grad_flat_t, tf.float32) with tf.device("/device:CPU:0"): grad_flat_t = tf.identity(grad_flat_t) else: logging.info('Enabled BPS for bwd embed alltoall GDR (G2C)') bwd_tensor_name = "bwd_alltoall_g2c" grad_flat_t = bps.alltoall_gpu2cpu(grad_flat, recv_emb_splits, recv_splits=embedding_splits, name=bwd_tensor_name) with tf.device("/device:CPU:0"): # grad_flat_t.device = ._op.device (bps.alltoall_gpu2cpu as GPU op) # However the tensor is on CPU, so we fixed the tensor placement info with an identity. 
grad_flat_t = tf.identity(grad_flat_t) if enable_bps_bwd_cast == 16: if enable_bps_bwd_fake_cast: grad_flat_t = grad_flat_t * 0.0 grad_flat_t = tf.cast(grad_flat_t, tf.float32) else: logging.info( 'Enabled BPS for bwd embed alltoall with cast optimization') grad_flat_t = bps.alltoall(grad_flat, recv_emb_splits, recv_splits=embedding_splits, name="bwd_alltoall") if enable_bps_bwd_cast == 16: if enable_bps_bwd_fake_cast: grad_flat_t = grad_flat_t * 0.0 grad_flat_t = tf.cast(grad_flat_t, tf.float32) sent_grad_split_size = embedding_splits elif enable_custom_optimized_hvd: if enable_hvd_bwd_g2g: logging.info('Enabled optimized hvd for bwd embed alltoall g2g') with tf.device("/device:GPU:0"): grad_flat_t, sent_grad_split_size = hvd.alltoall( grad_flat, recv_emb_splits, recv_splits=embedding_splits, with_size=True, compression=FP16Compressor) with tf.device("/device:CPU:0"): grad_flat_t = tf.identity(grad_flat_t) else: logging.info('Enabled optimized hvd for bwd embed alltoall') grad_flat_t, sent_grad_split_size = hvd.alltoall( grad_flat, recv_emb_splits, recv_splits=embedding_splits, with_size=True, compression=FP16Compressor) if FLAGS.enable_alltoall_metrics and (self._shard_num > 1): # There is some issue with tf.compat.v1.summary on the horovod alltoall input, # using output instead. They are almost equivalent. 
shard_sizes = auxiliary_bundle["shard_sizes"] shard_sizes = tf.slice(shard_sizes, [1], [self._shard_num - 1]) total_alltoall_id_size = tf.reduce_sum(shard_sizes) recv_emb_splits = tf.slice(tf.identity(recv_emb_splits), [1], [self._shard_num - 1]) total_alltoall_emb_size = tf.reduce_sum(tf.identity(recv_emb_splits)) tmp_result = tf.reshape(total_alltoall_id_size, [1]) tmp_result2 = tf.reshape(total_alltoall_emb_size, [1]) total_id_dist = hvd.allgather(tmp_result) total_emb_dist = hvd.allgather(tmp_result2) with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("Alltoall_id_dist", total_id_dist) tf.compat.v1.summary.histogram("Alltoall_emb_dist", total_emb_dist) if self._index == 0: min_idx = tf.math.argmin(shard_sizes) max_idx = tf.math.argmax(shard_sizes) min_idx_size = tf.reshape( tf.slice(shard_sizes, tf.reshape(min_idx, [-1]), [1]), []) max_idx_size = tf.reshape( tf.slice(shard_sizes, tf.reshape(max_idx, [-1]), [1]), []) with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("Alltoall_id_splits", shard_sizes) tf.compat.v1.summary.scalar("Alltoall_id_sizes", total_alltoall_id_size) tf.compat.v1.summary.scalar("Alltoall_id_min_idx", min_idx) tf.compat.v1.summary.scalar("Alltoall_id_max_idx", max_idx) tf.compat.v1.summary.scalar("Alltoall_id_min_size", min_idx_size) tf.compat.v1.summary.scalar("Alltoall_id_max_size", max_idx_size) sent_grad_split_size = tf.slice(sent_grad_split_size, [1], [self._shard_num - 1]) total_alltoall_grad_size = tf.reduce_sum(sent_grad_split_size) with tf.device("/CPU:0"): tf.compat.v1.summary.histogram("Alltoall_grad_splits", sent_grad_split_size) tf.compat.v1.summary.scalar("Alltoall_grad_sizes", total_alltoall_grad_size) min_emb = tf.math.argmin(recv_emb_splits) max_emb = tf.math.argmax(recv_emb_splits) min_emb_size = tf.reshape( tf.slice(recv_emb_splits, tf.reshape(min_emb, [-1]), [1]), []) max_emb_size = tf.reshape( tf.slice(recv_emb_splits, tf.reshape(max_emb, [-1]), [1]), []) with tf.device("/CPU:0"): 
tf.compat.v1.summary.histogram("Alltoall_emb_splits", tf.identity(recv_emb_splits)) tf.compat.v1.summary.scalar("Alltoall_emb_sizes", tf.identity(total_alltoall_emb_size)) tf.compat.v1.summary.scalar("Alltoall_emb_min_idx", min_emb) tf.compat.v1.summary.scalar("Alltoall_emb_max_idx", max_emb) tf.compat.v1.summary.scalar("Alltoall_emb_min_size", min_emb_size) tf.compat.v1.summary.scalar("Alltoall_emb_max_size", max_emb_size) elif enable_hvd: if enable_hvd_bwd_g2g: logging.info('Enabled hvd for bwd embed alltoall g2g') with tf.device("/device:GPU:0"): grad_flat_t = hvd.alltoall(grad_flat, recv_emb_splits, name='hvd_bwd_a2a_g2g') else: logging.info('Enabled hvd for bwd embed alltoall') grad_flat_t = hvd.alltoall(grad_flat, recv_emb_splits) else: grad_flat_t = grad_flat auxiliary_bundle["grad_flat_t"] = grad_flat_t auxiliary_bundle.pop("shard_sizes") auxiliary_bundle, q = enqueue_dicts_with_queue_return( auxiliary_bundle, capacity=int(self._queue_configs.get("enable_pipelined_bwda2a", 0)), queue_name="queue_fusedEmbGradA2A_to_sparseOptimize") if q: self.add_queue_hook(EnqueueHook(q)) with tf.control_dependencies(feature_utils.dense_opt_ops): with tf.device("/GPU:0"): updated_table = self._table.fused_apply_gradient( auxiliary_bundle.pop("id_flat_t"), auxiliary_bundle.pop("indices"), auxiliary_bundle.pop("id_size_flat_t"), auxiliary_bundle.pop("grad_flat_t"), auxiliary_bundle.pop("id_offsets"), auxiliary_bundle.pop("emb_offsets"), auxiliary_bundle.pop("global_step"), auxiliary_bundle.pop("req_time"), self._shard_num) update_op = self._copy_with_new_table(updated_table) # TODO(zouxuan): add better tests to test the async optimize. 
if async_optimize_queue: self.add_queue_hook(AsyncPushHook(async_optimize_queue, update_op.as_op())) self._dependency_ops.append(async_optimize_queue.enqueue_op) # return self essentially means to call dependency_ops return self else: return update_op def _update(self, method: str, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]], *args, **kwargs): raise NotImplementedError def as_op(self, name=None): with tf.control_dependencies(self._dependency_ops): return self._table.as_op(name) def get_table_dim_sizes(self): return self._tables[0].get_table_dim_sizes() def _copy_with_new_table(self, table: multi_type_hash_table.BaseMultiTypeHashTable): copied = copy.copy(self) copied._dependency_ops = copy.copy(self._dependency_ops) copied._table = table return copied ================================================ FILE: monolith/native_training/distributed_ps_sync_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os
from absl.testing import parameterized

# Must be set before the monolith imports below so they see the variable at
# import time.
os.environ["MONOLITH_WITH_HOROVOD"] = "True"

import tensorflow as tf

from monolith.native_training import distributed_ps
from monolith.native_training import distribution_ops
from monolith.native_training import distributed_ps_sync
from monolith.native_training import hash_table_ops
from monolith.native_training import learning_rate_functions
from monolith.native_training import multi_type_hash_table
from monolith.native_training import test_utils
from monolith.native_training.multi_hash_table_ops import MultiHashTable

import horovod.tensorflow as hvd


def gen_test_configs():
  """Returns test configs for slots "1" and "2".

  Slot "1" uses a constant learning rate of 1.0; slot "2" uses a
  PolynomialDecay learning rate (1.0 -> 2.0 over 10 decay steps). The first
  positional argument (1 and 2) is presumably the embedding dim -- the
  assertions in testBasic use widths 1 and 2 respectively.
  """
  return {
      "1":
          test_utils.generate_test_hash_table_config(1, learning_rate=1.0),
      "2":
          test_utils.generate_test_hash_table_config(
              2,
              learning_rate=learning_rate_functions.PolynomialDecay(
                  initial_learning_rate=1.0,
                  decay_steps=10,
                  end_learning_rate=2.0))
  }


def multi_type_table_factory(idx: int):
  """Builds a MultiTypeHashTable whose table names end with `idx`."""

  def table_factory(name_suffix: str, config):
    return hash_table_ops.hash_table_from_config(config,
                                                 name_suffix=name_suffix +
                                                 str(idx))

  return multi_type_hash_table.MultiTypeHashTable(gen_test_configs(),
                                                  table_factory)


def native_multi_hash_table_factory(idx: int):
  """Builds a native MultiHashTable with name suffix str(idx)."""
  return MultiHashTable.from_configs(configs=gen_test_configs(),
                                     name_suffix=str(idx))


class DistributedMultiTypeHashTableMpiTest(tf.test.TestCase,
                                           parameterized.TestCase):

  # Only the non-native variant (False) is parameterized here, so
  # native_multi_hash_table_factory is not exercised by this test.
  @parameterized.parameters([(False,)])
  def testBasic(self, use_native_multi_hash_table):
    """Lookup -> apply_gradients -> lookup round trip on the Mpi table.

    The first lookup is expected to return zeros. After one apply_gradients
    call the second lookup is expected to return the negated, summed gradient
    per id, multiplied by hvd.size().
    """
    table_factory = (native_multi_hash_table_factory
                     if use_native_multi_hash_table else
                     multi_type_table_factory)
    hvd.init()
    with self.session() as sess:
      global_step = tf.compat.v1.train.get_or_create_global_step()
      self.evaluate(tf.compat.v1.global_variables_initializer())
      self.evaluate(tf.compat.v1.assign(global_step, 0))
      table = distributed_ps_sync.DistributedMultiTypeHashTableMpi(
          hvd.size(), table_factory)
      slot_to_ids = {
          "1": tf.constant([1, 1], dtype=tf.int64),
          "2": tf.constant([2], dtype=tf.int64)
      }
      # First lookup, nothing exists, returns 0 simply.
      reordred = distribution_ops.fused_reorder_by_indices(
          [slot_to_ids["1"], slot_to_ids["2"]], 1, [1, 2])
      # lookup() unpacks six values; the trailing None fills the req_time slot.
      reordred = (*reordred, None)  # add timestamp
      emb, auxiliary_bundle = table.lookup(slot_to_ids, {}, reordred)
      emb_value = sess.run(emb)
      self.assertAllClose(emb_value["1"], [[0], [0]])
      self.assertAllClose(emb_value["2"], [[0, 0]])
      updated_table = table.apply_gradients(
          {
              "1": tf.constant([[0.5], [0.5]], dtype=tf.float32),
              "2": tf.constant([[0.5, 1.0]], dtype=tf.float32)
          },
          auxiliary_bundle=auxiliary_bundle,
          global_step=tf.constant(0, dtype=tf.int64),
          req_time=tf.constant(0, dtype=tf.int64))
      emb, auxiliary_bundle = updated_table.lookup(slot_to_ids, {}, reordred)
      emb_value = sess.run(emb)
      # Expected values scale with the number of Horovod workers.
      sum_multiplier = hvd.size()
      self.assertAllClose(emb_value["1"],
                          [[-1 * sum_multiplier], [-1 * sum_multiplier]])
      self.assertAllClose(emb_value["2"],
                          [[-0.5 * sum_multiplier, -1.0 * sum_multiplier]])


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()


================================================
FILE: monolith/native_training/distributed_ps_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict import math import os import copy os.environ["MONOLITH_WITH_HOROVOD"] = "1" import itertools from typing import Dict, List import numpy as np from absl.testing import parameterized import logging import tensorflow as tf import tensorflow.python.ops.resources as resources from tensorflow.python.framework import test_util from monolith.native_training import distributed_ps from monolith.native_training import hash_table_ops from monolith.native_training.multi_hash_table_ops import MultiHashTable from monolith.native_training import learning_rate_functions from monolith.native_training import multi_type_hash_table from monolith.native_training import utils from monolith.native_training import test_utils from monolith.native_training.model_export import export_context from monolith.native_training import entry from idl.matrix.proto.example_pb2 import FeatureConfigs, FeatureConfig, PoolingType, OutType, OutConfig import monolith.native_training.embedding_combiners as embedding_combiners from monolith.native_training.runtime.hash_table import embedding_hash_table_pb2 from monolith.native_training.data.feature_utils import string_to_variant from idl.matrix.proto.example_pb2 import Example from monolith.native_training.data.parsers import ParserCtx, sharding_sparse_fids_with_context def factory(idx: int, config): return hash_table_ops.hash_table_from_config(config=config, name_suffix=str(idx)) class DistributedHashTableTest(tf.test.TestCase): def test_basic(self): servers, config = test_utils.create_test_ps_cluster(2) table_config = test_utils.generate_test_hash_table_config(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: hash_table = distributed_ps.DistributedHashTable(2, table_config, factory) hash_table = hash_table.assign_add( tf.constant([1, 2, 3], dtype=tf.int64), tf.constant([[1, 1], [2, 2], [3, 3]], dtype=tf.float32)) values = hash_table.lookup(tf.constant([1, 2, 3], dtype=tf.int64)) values = 
sess.run(values) self.assertAllEqual(values, [[1, 1], [2, 2], [3, 3]]) def test_assign(self): servers, config = test_utils.create_test_ps_cluster(2) table_config = test_utils.generate_test_hash_table_config(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: hash_table = distributed_ps.DistributedHashTable(2, table_config, factory) hash_table = hash_table.assign( tf.constant([1, 2, 3], dtype=tf.int64), tf.constant([[1, 1], [2, 2], [3, 3]], dtype=tf.float32)) values1 = hash_table.lookup(tf.constant([1, 2, 3], dtype=tf.int64)) # Ensure the second assign happens after the first lookup with tf.control_dependencies([values1]): hash_table = hash_table.assign( tf.constant([2, 3, 4], dtype=tf.int64), tf.constant([[1, 1], [2, 2], [3, 3]], dtype=tf.float32)) values2 = hash_table.lookup(tf.constant([1, 2, 3, 4], dtype=tf.int64)) values1, values2 = sess.run([values1, values2]) self.assertAllEqual(values1, [[1, 1], [2, 2], [3, 3]]) self.assertAllEqual(values2, [[1, 1], [1, 1], [2, 2], [3, 3]]) def test_lookup_dedup(self): servers, config = test_utils.create_test_ps_cluster(2) table_config = test_utils.generate_test_hash_table_config(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: hash_table = distributed_ps.DistributedHashTable(2, table_config, factory) hash_table = hash_table.assign_add( tf.constant([1, 2, 3], dtype=tf.int64), tf.constant([[1, 1], [2, 2], [3, 3]], dtype=tf.float32)) values = hash_table.lookup(tf.constant([1, 1, 3], dtype=tf.int64)) values = sess.run(values) self.assertAllEqual(values, [[1, 1], [1, 1], [3, 3]]) def test_apply_gradients(self): table_config = test_utils.generate_test_hash_table_config(1) g = tf.Graph() with g.as_default(): hash_table = distributed_ps.DistributedHashTable(2, table_config, factory) ids = tf.constant([0, 1], dtype=tf.int64) values = hash_table.lookup(ids) loss = 2 * values grads = tf.gradients(loss, values) global_step = tf.constant(0, dtype=tf.int64) hash_table = 
hash_table.apply_gradients(ids, grads[0], global_step) new_values = hash_table.lookup(ids) servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config, graph=g) as sess: new_values = sess.run(new_values) self.assertAllEqual(new_values, [[-2], [-2]]) def test_apply_gradients_with_learning_rate_function(self): table_config = test_utils.generate_test_hash_table_config( 1, learning_rate=learning_rate_functions.PolynomialDecay( initial_learning_rate=1.0, decay_steps=10, end_learning_rate=2.0)) g = tf.Graph() with g.as_default(): global_step = tf.compat.v1.train.get_or_create_global_step() hash_table = distributed_ps.DistributedHashTable(2, table_config, factory) ids = tf.constant([0, 1], dtype=tf.int64) values = hash_table.lookup(ids) loss = 2 * values grads = tf.gradients(loss, values) hash_table = hash_table.apply_gradients(ids, grads[0], global_step) new_values = hash_table.lookup(ids) servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config, graph=g) as sess: resources.initialize_resources(resources.shared_resources()).run() self.evaluate(tf.compat.v1.global_variables_initializer()) values_eval = sess.run(new_values) self.assertAllEqual(values_eval, [[-2], [-2]]) self.evaluate(tf.compat.v1.assign_add(global_step, 1)) values_eval = sess.run(new_values) self.assertAllClose(values_eval, [[-4.2], [-4.2]]) def test_apply_gradients_with_duplicates(self): table_config = test_utils.generate_test_hash_table_config(1) g = tf.Graph() with g.as_default(): hash_table = distributed_ps.DistributedHashTable(2, table_config, factory) ids = tf.constant([0, 3, 0, 1], dtype=tf.int64) values = hash_table.lookup(ids) loss = 2 * values grads = tf.gradients(loss, values) global_step = tf.constant(0, dtype=tf.int64) hash_table = hash_table.apply_gradients(ids, grads[0], global_step) new_values = hash_table.lookup(ids) servers, config = test_utils.create_test_ps_cluster(2) with 
tf.compat.v1.Session(servers[0].target, config=config, graph=g) as sess: new_values = sess.run(new_values) self.assertAllEqual(new_values, [[-4], [-2], [-4], [-2]]) def test_apply_gradients_with_different_ids(self): table_config = test_utils.generate_test_hash_table_config(1) g = tf.Graph() with g.as_default(): hash_table = distributed_ps.DistributedHashTable(2, table_config, factory) ids = tf.constant([1, 0], dtype=tf.int64) bp_ids = tf.constant([1, 1], dtype=tf.int64) values = hash_table.lookup(ids) loss = -2 * values grads = tf.gradients(loss, values) global_step = tf.constant(0, dtype=tf.int64) hash_table = hash_table.apply_gradients(bp_ids, grads[0], global_step) new_values = hash_table.lookup(ids) servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config, graph=g) as sess: new_values = sess.run(new_values) self.assertAllEqual(new_values, [[4], [0]]) def gen_multi_type_table_factory(global_name_prefix=""): def multi_type_table_factory(ps_num: int, slot_to_config_on_ps): def table_factory(name_suffix: str, config): return hash_table_ops.hash_table_from_config( config, name_suffix=global_name_prefix + name_suffix + str(ps_num)) return multi_type_hash_table.MultiTypeHashTable(slot_to_config_on_ps, table_factory) return multi_type_table_factory def gen_native_multi_hash_table_factory(global_name_prefix=""): def native_multi_hash_table_factory(ps_num: int, slot_to_config): return MultiHashTable.from_configs(configs=slot_to_config, name_suffix=global_name_prefix + str(ps_num)) return native_multi_hash_table_factory class DistributedMultiTypeHashTableTest(tf.test.TestCase, parameterized.TestCase): @parameterized.parameters([(True,), (False,)]) def testBasic(self, use_native_multi_hash_table): servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: slot_to_config = { "1": test_utils.generate_test_hash_table_config(1), "2": 
test_utils.generate_test_hash_table_config( 2, learning_rate=lambda: 1.0) } table_factory = (gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_type_table_factory()) hash_table = distributed_ps.DistributedMultiTypeHashTable( 2, slot_to_config, table_factory) ids1 = tf.constant([1, 2], dtype=tf.int64) values1 = tf.constant([[-1], [-2]], dtype=tf.float32) ids2 = tf.constant([3], dtype=tf.int64) values2 = tf.constant([[-3, -3]], dtype=tf.float32) updated_hash_table = hash_table.assign_add({ "1": (ids1, values1), "2": (ids2, values2) }) values = updated_hash_table.lookup({"1": ids1, "2": ids2}) sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() values = sess.run(values) self.assertAllEqual(values["1"], [[-1], [-2]]) self.assertAllEqual(values["2"], [[-3, -3]]) global_step = tf.constant(0, dtype=tf.int64) updated_hash_table = hash_table.apply_gradients( { "1": (ids1, values1 / 2), "2": (ids2, values2 / 2) }, global_step, req_time=tf.constant(0, dtype=tf.int64)) values = updated_hash_table.lookup({"1": ids1, "2": ids2}) values, _ = sess.run([values, updated_hash_table.as_op()]) self.assertAllEqual(values["1"], [[-0.5], [-1]]) self.assertAllEqual(values["2"], [[-1.5, -1.5]]) @parameterized.parameters([(True,), (False,)]) def test_assign_and_reinitialize(self, use_native_multi_hash_table): servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: slot_to_config = { "1": test_utils.generate_test_hash_table_config(1), "2": test_utils.generate_test_hash_table_config(2) } table_factory = (gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_type_table_factory()) hash_table = distributed_ps.DistributedMultiTypeHashTable( 2, slot_to_config, table_factory) ids1 = tf.constant([1, 2], dtype=tf.int64) values1 = tf.constant([[-1], 
[-2]], dtype=tf.float32) ids2 = tf.constant([3], dtype=tf.int64) values2 = tf.constant([[-3, -3]], dtype=tf.float32) updated_hash_table = hash_table.assign({ "1": (ids1, values1), "2": (ids2, values2) }) values = updated_hash_table.lookup({"1": ids1, "2": ids2}) sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() values = sess.run(values) self.assertAllEqual(values["1"], [[-1], [-2]]) self.assertAllEqual(values["2"], [[-3, -3]]) updated_hash_table = hash_table.assign({ "1": (ids1, values1 / 2), "2": (ids2, values2 / 2) }) values = updated_hash_table.lookup({"1": ids1, "2": ids2}) values = sess.run(values) self.assertAllEqual(values["1"], [[-0.5], [-1]]) self.assertAllEqual(values["2"], [[-1.5, -1.5]]) if use_native_multi_hash_table: ids11 = tf.constant([1, 2, 3], dtype=tf.int64) updated_hash_table, status1 = updated_hash_table.reinitialize( "1", ids11) updated_hash_table, status2 = updated_hash_table.reinitialize( "3", ids11) values = updated_hash_table.lookup({"1": ids11, "2": ids2, "3": ids11}) values, status1, status2 = sess.run([values, status1, status2]) self.assertAllEqual(values["1"], [[0], [0], [0]]) self.assertAllEqual(values["2"], [[-1.5, -1.5]]) self.assertAllEqual(status1, [1, 1, 0]) self.assertAllEqual(status2, [-1, -1, -1]) @parameterized.parameters([(True,), (False,)]) def test_apply_gradients_with_learning_rate_function( self, use_native_multi_hash_table): servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: global_step = tf.compat.v1.train.get_or_create_global_step() slot_to_config = { "1": test_utils.generate_test_hash_table_config( 1, learning_rate=learning_rate_functions.PolynomialDecay( initial_learning_rate=1.0, decay_steps=10, end_learning_rate=2.0)), "2": test_utils.generate_test_hash_table_config( 2, learning_rate=lambda: 1.0) } table_factory = 
(gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_type_table_factory()) hash_table = distributed_ps.DistributedMultiTypeHashTable( 2, slot_to_config, table_factory) ids1 = tf.constant([1, 2], dtype=tf.int64) values1 = tf.constant([[-1], [-2]], dtype=tf.float32) ids2 = tf.constant([3], dtype=tf.int64) values2 = tf.constant([[-3, -3]], dtype=tf.float32) updated_hash_table = hash_table.assign_add({ "1": (ids1, values1), "2": (ids2, values2) }) values = updated_hash_table.lookup({"1": ids1, "2": ids2}) global_step = tf.compat.v1.train.get_or_create_global_step() sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() values = sess.run(values) self.assertAllEqual(values["1"], [[-1], [-2]]) self.assertAllEqual(values["2"], [[-3, -3]]) updated_hash_table = hash_table.apply_gradients( { "1": (ids1, values1 / 2), "2": (ids2, values2 / 2) }, global_step) values = updated_hash_table.lookup({"1": ids1, "2": ids2}) values, _ = sess.run([values, updated_hash_table.as_op()]) self.assertAllEqual(values["1"], [[-0.5], [-1]]) self.assertAllEqual(values["2"], [[-1.5, -1.5]]) self.evaluate(tf.compat.v1.assign_add(global_step, 1)) updated_hash_table = hash_table.apply_gradients( { "1": (ids1, values1 / 2), "2": (ids2, values2 / 2) }, global_step) values = updated_hash_table.lookup({"1": ids1, "2": ids2}) values, _ = sess.run([values, updated_hash_table.as_op()]) self.assertAllClose(values["1"], [[0.05], [0.1]]) self.assertAllClose(values["2"], [[0, 0]]) @parameterized.parameters([(True,), (False,)]) def test_apply_gradients_float16(self, use_native_multi_hash_table): servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: slot_to_config = { "1": test_utils.generate_test_hash_table_config(dim=1, use_float16=True), "2": test_utils.generate_test_hash_table_config(dim=2, 
use_float16=True), } table_factory = (gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_type_table_factory()) hash_table = distributed_ps.DistributedMultiTypeHashTable( num_ps=2, slot_to_config=slot_to_config, table_factory=table_factory, transfer_float16=True) ids1 = tf.constant([1, 2], dtype=tf.int64) values1 = tf.constant([[-1], [-2]], dtype=tf.float32) ids2 = tf.constant([3], dtype=tf.int64) values2 = tf.constant([[-3, -3]], dtype=tf.float32) loss1 = 2 * values1 loss2 = 3 * values2 grads1 = tf.gradients(loss1, values1) grads2 = tf.gradients(loss2, values2) hash_table = hash_table.apply_gradients( { '1': {ids1, grads1[0]}, '2': {ids2, grads2[0]}, }, global_step=tf.constant(1, dtype=tf.int64)) values = hash_table.lookup({"1": ids1, "2": ids2}) sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() res = sess.run(values) self.assertAllEqual(res["1"], [[-2.], [-2.]]) self.assertAllEqual(res["2"], [[-3., -3.]]) class DistributedMultiTypeHashTableServingTest(tf.test.TestCase, parameterized.TestCase): @parameterized.parameters([(True,), (False,)]) def test_export_model(self, use_native_multi_hash_table): table_factory = (gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_type_table_factory()) servers, config = test_utils.create_test_ps_cluster(2) with tf.compat.v1.Session(servers[0].target, config=config) as sess: slot_to_config = { "1": test_utils.generate_test_hash_table_config(1), "2": test_utils.generate_test_hash_table_config(2) } # Exporting distributed saved model with tf.Graph().as_default(): export_ctx = export_context.ExportContext() with export_context.enter_export_mode( export_context.ExportMode.DISTRIBUTED, export_ctx): hash_table = distributed_ps.DistributedMultiTypeHashTable( 2, slot_to_config, table_factory) self.assertAllEqual(export_ctx.sub_graph_num, 2) result = 
hash_table.lookup({"1": tf.constant([1, 2], dtype=tf.int64)}) self.assertEqual(result["1"].shape, [2, 1]) # Exporting standalone saved model with tf.Graph().as_default(): export_ctx = export_context.ExportContext() with export_context.enter_export_mode( export_context.ExportMode.STANDALONE, export_ctx): hash_table = distributed_ps.DistributedMultiTypeHashTable( 2, slot_to_config, table_factory) self.assertAllEqual(export_ctx.sub_graph_num, 0) hash_table.lookup({"1": tf.constant([1], dtype=tf.int64)}) # Normal training with tf.Graph().as_default(): hash_table = distributed_ps.DistributedMultiTypeHashTable( 2, slot_to_config, table_factory) hash_table.lookup({"1": tf.constant([1], dtype=tf.int64)}) def gen_multi_table_factory(global_name_prefix=""): def multi_table_factory(idx: int, configs: Dict[str, entry.HashTableConfigInstance]): def factory(name_suffix, config): return hash_table_ops.hash_table_from_config( config, name_suffix=global_name_prefix + "_".join([name_suffix, str(idx)])) return multi_type_hash_table.MultiTypeHashTable(configs, factory) return multi_table_factory class PartitionedHashTableTest(tf.test.TestCase, parameterized.TestCase): @classmethod def gen_table_config(cls, dims: List[int], use_float16: bool = False, learning_rate: float = 1.0, enable_gpu_emb: bool = False): assert len(dims) >= 1 table_config = embedding_hash_table_pb2.EmbeddingHashTableConfig() if enable_gpu_emb: table_config.gpucuco.SetInParent() else: table_config.cuckoo.SetInParent() for i, dim in enumerate(dims): segment = table_config.entry_config.segments.add() segment.dim_size = dim segment.init_config.zeros.SetInParent() segment.comp_config.fp32.SetInParent() if i == 0: segment.opt_config.ftrl.SetInParent() segment.opt_config.stochastic_rounding_float16 = use_float16 else: segment.opt_config.adagrad.SetInParent() segment.opt_config.stochastic_rounding_float16 = use_float16 return entry.HashTableConfigInstance(table_config, [learning_rate] * len(dims)) @classmethod def 
gen_out_config(cls, feature_to_unmerged_slice_dims: Dict[str, List[int]], layout_names: List[str]): features = list(feature_to_unmerged_slice_dims.keys()) feature_stats = {name: 0 for name in features} layout_configs = {} for i, layout in enumerate( itertools.zip_longest(*feature_to_unmerged_slice_dims.values())): out_conf = OutConfig(out_type=OutType.CONCAT) assert len(features) == len(layout) out_dim = 0 for name, dim_size in zip(features, layout): if dim_size is None: continue out_dim += dim_size slice_config = out_conf.slice_configs.add() slice_config.feature_name = name slice_config.start = feature_stats[name] slice_config.end = slice_config.start + dim_size pooling_type = PoolingType.SUM feature_stats[name] += dim_size shape = out_conf.shape.add() shape.dims.extend([-1, out_dim]) layout_configs[layout_names[i]] = out_conf return layout_configs @classmethod def setUpClass(cls): if test_util.is_gpu_available(cuda_only=True): import horovod.tensorflow as hvd hvd.init() @classmethod def get_parser_ctx(cls, num_ps, enable_gpu_emb, use_gpu, use_native_multi_hash_table): feature_to_unmerged_slice_dims = { "uid": [1, 4], 'gid': [1, 4, 8], "cid": [1, 4, 8], } feature_to_combiner = { 'uid': embedding_combiners.ReduceSum(), 'gid': embedding_combiners.ReduceSum(), 'cid': embedding_combiners.ReduceSum(), } feature_name_to_config = { name: cls.gen_table_config(dims=dims, enable_gpu_emb=enable_gpu_emb) for name, dims in feature_to_unmerged_slice_dims.items() } layout_configs = cls.gen_out_config(feature_to_unmerged_slice_dims, layout_names=['bias', 'vec', 'deep']) parser_ctx = ParserCtx(True) parser_ctx.parser_type = 'example' parser_ctx.sharding_sparse_fids_op_params = distributed_ps.PartitionedHashTable.gen_feature_configs( num_ps=num_ps, feature_name_to_config=feature_name_to_config, layout_configs=layout_configs, feature_to_combiner=feature_to_combiner, feature_to_unmerged_slice_dims=feature_to_unmerged_slice_dims, 
use_native_multi_hash_table=use_native_multi_hash_table, unique=lambda: True, transfer_float16=False, enable_gpu_emb=enable_gpu_emb, use_gpu=use_gpu) return parser_ctx @classmethod def gen_data(cls, num_ps, sub_table_name_to_config, with_emb: bool = False, method: str = 'random', value: float = 1.0): assert method in {'random', 'const'} data_tf, data_np = {}, {} for i in range(num_ps): data_tf[i], data_np[i] = {}, {} for tbname, conf in sub_table_name_to_config.items(): size = sum( seg.dim_size for seg in conf._table_config.entry_config.segments) fids_np = np.array([i + num_ps * j for j in range(size)], dtype=np.int64) fids_tf = tf.constant(value=fids_np, dtype=tf.int64, name=f'{tbname}_fids') if with_emb: emb_size = sum( segment.dim_size for segment in conf._table_config.entry_config.segments) if method == 'random': embs_np = np.random.uniform(-1, 1, size=(size, emb_size)) else: embs_np = np.ones( shape=(size, emb_size), dtype=np.float32) * value #fids_np.reshape([size, 1]) embs_tf = tf.constant(value=embs_np, dtype=tf.float32, name=f'{tbname}_embss') data_tf[i][tbname] = (fids_tf, embs_tf) data_np[i][tbname] = (fids_np, embs_np) else: data_tf[i][tbname] = fids_tf data_np[i][tbname] = fids_np return data_tf, data_np @classmethod def gen_variant_tensor(cls, batch_size: int): examples = [] for i in range(batch_size): example = Example() start = i * 3 named_feature = example.named_feature.add() named_feature.name = 'uid' named_feature.feature.fid_v2_list.value.append(start) named_feature = example.named_feature.add() named_feature.name = 'gid' named_feature.feature.fid_v2_list.value.append(start + 1) named_feature = example.named_feature.add() named_feature.name = 'cid' named_feature.feature.fid_v2_list.value.append(start + 2) logging.info(f" {i}/{batch_size}:{example}") examples.append(example.SerializeToString()) example_strs = tf.constant(value=examples, dtype=tf.string, name='examples') return string_to_variant(example_strs, variant_type='example') #with 
tf.compat.v1.Session(servers[0].target, config=config) as sess, test_util.use_gpu() if use_gpu else sess.graph.device(lambda op: '/CPU:0') def _test_basic(self, use_native_multi_hash_table, use_gpu): enable_gpu_emb = False #not support hash_table.assign num_ps = 2 parser_ctx = self.get_parser_ctx(num_ps, enable_gpu_emb, use_gpu, use_native_multi_hash_table) sub_table_name_to_config = parser_ctx.sharding_sparse_fids_op_params.sub_table_name_to_config servers, config = test_utils.create_test_ps_cluster(num_ps) config.share_cluster_devices_in_session = True config.experimental.share_session_state_in_clusterspec_propagation = True # grappler doesn't really understand RaggedTensor. config.graph_options.rewrite_options.disable_meta_optimizer = True with tf.compat.v1.Session(servers[0].target, config=config) as sess: hash_table = distributed_ps.PartitionedHashTable( num_ps, gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_table_factory(), use_native_multi_hash_table=use_native_multi_hash_table, parser_ctx=parser_ctx) if use_native_multi_hash_table: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() assign_data, assign_data_np = self.gen_data(num_ps, sub_table_name_to_config, with_emb=True) hash_table = hash_table.assign(assign_data) assign_add_data, assign_add_data_np = self.gen_data( num_ps, sub_table_name_to_config, with_emb=True) hash_table = hash_table.assign_add(assign_add_data) lookup_data, _ = self.gen_data(num_ps, sub_table_name_to_config, with_emb=False) values = hash_table._lookup_raw(lookup_data) real_result = sess.run(values) #logging.info( # f"xxx {lookup_data} {assign_data_np} {assign_add_data_np} {real_result}" #) for part in assign_data_np: for tbname in assign_data_np[part]: fid1, emb1 = assign_data_np[part][tbname] fid2, emb2 = assign_add_data_np[part][tbname] 
self.assertAllClose(real_result[part][tbname], emb1 + emb2) @parameterized.parameters([(True,), (False,)]) def test_basic(self, use_native_multi_hash_table): self._test_basic(use_native_multi_hash_table, use_gpu=False) @parameterized.parameters([(True,), (False,)]) @test_util.run_gpu_only def test_basic_gpu(self, use_native_multi_hash_table): self._test_basic(use_native_multi_hash_table, use_gpu=True) def _test_lookup(self, use_native_multi_hash_table, use_gpu): enable_gpu_emb = False #not support hash_table.assign num_ps = 2 parser_ctx = self.get_parser_ctx(num_ps, enable_gpu_emb, use_gpu, use_native_multi_hash_table) sub_table_name_to_config = parser_ctx.sharding_sparse_fids_op_params.sub_table_name_to_config servers, config = test_utils.create_test_ps_cluster(num_ps) config.share_cluster_devices_in_session = True config.experimental.share_session_state_in_clusterspec_propagation = True # grappler doesn't really understand RaggedTensor. config.graph_options.rewrite_options.disable_meta_optimizer = True with tf.compat.v1.Session(servers[0].target, config=config) as sess: hash_table = distributed_ps.PartitionedHashTable( num_ps, gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_table_factory(), use_native_multi_hash_table=use_native_multi_hash_table, parser_ctx=parser_ctx) hash_table._inner_data_type = 'example' if use_native_multi_hash_table: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() x = 2.0 assign_data, assign_data_np = self.gen_data(num_ps, sub_table_name_to_config, with_emb=True, method='const', value=x) logging.info(f"show assign_data_np {assign_data_np}") hash_table = hash_table.assign(assign_data) sparse_features = self.gen_variant_tensor(batch_size=num_ps * 3) auxiliary_bundle = {} features = {} sharding_sparse_fids_with_context(sparse_features, features, parser_ctx) layouts = 
hash_table.lookup(features, auxiliary_bundle=auxiliary_bundle) layouts = sess.run(layouts) auxiliary_bundle_ret = sess.run(auxiliary_bundle) expect = { 'bias': np.array([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [0., 1., 1.], [0., 1., 1.]], dtype=np.float32), 'deep': np.array([[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. ], [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. ], [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. ], [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. ], [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. ], [ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. ]], dtype=np.float32), 'vec': np.array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=np.float32) } for key, value in layouts.items(): logging.info(f" {key} {value} --- {expect[key] * x}") self.assertAllClose(value, expect[key] * x) @parameterized.parameters([(True,), (False,)]) def test_lookup(self, use_native_multi_hash_table): self._test_lookup(use_native_multi_hash_table, use_gpu=False) @parameterized.parameters([(True,), (False,)]) @test_util.run_gpu_only def test_lookup_gpu(self, use_native_multi_hash_table): self._test_lookup(use_native_multi_hash_table, use_gpu=True) def _test_apply_gradients(self, use_native_multi_hash_table, use_gpu): enable_gpu_emb = False #not support hash_table.assign num_ps = 2 parser_ctx = self.get_parser_ctx(num_ps, enable_gpu_emb, use_gpu, use_native_multi_hash_table) sub_table_name_to_config = parser_ctx.sharding_sparse_fids_op_params.sub_table_name_to_config servers, config = test_utils.create_test_ps_cluster(num_ps) config.share_cluster_devices_in_session = True 
config.experimental.share_session_state_in_clusterspec_propagation = True # grappler doesn't really understand RaggedTensor. config.graph_options.rewrite_options.disable_meta_optimizer = True with tf.compat.v1.Session(servers[0].target, config=config) as sess: hash_table = distributed_ps.PartitionedHashTable( num_ps, gen_native_multi_hash_table_factory() if use_native_multi_hash_table else gen_multi_table_factory(), use_native_multi_hash_table=use_native_multi_hash_table, parser_ctx=parser_ctx) hash_table._inner_data_type = 'example' if use_native_multi_hash_table: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() init_val = 3.0 assign_data, assign_data_np = self.gen_data(num_ps, sub_table_name_to_config, with_emb=True, method='const', value=init_val) hash_table = hash_table.assign(assign_data) sparse_features = self.gen_variant_tensor(batch_size=num_ps * 3) auxiliary_bundle = {} features = {} sharding_sparse_fids_with_context(sparse_features, features, parser_ctx) layouts = hash_table.lookup(features, auxiliary_bundle=auxiliary_bundle) layout_grads_and_vars, init_grad = [], 2.0 for name in sorted(hash_table._feature_configs.out_configs): layout = layouts[name] layout_grads_and_vars.append( (tf.ones_like(layout, dtype=tf.float32) * init_grad, layout)) global_step = tf.constant(0, dtype=tf.int64) apply_gradients_op = hash_table.apply_gradients( layout_grads_and_vars, global_step, auxiliary_bundle=auxiliary_bundle) lookup_data, lookup_data_np = self.gen_data(num_ps, sub_table_name_to_config, with_emb=False) with tf.control_dependencies([apply_gradients_op]): values = hash_table._lookup_raw(lookup_data) values = sess.run(values) #logging.info(f"xx values: {lookup_data_np} {values}") if use_native_multi_hash_table: shards = { 'uid:0': [0, 6, 12], 'uid:1': [3, 9, 15], 'cid:0': [2, 8, 14], 'cid:1': [5, 11, 17], 'gid:0': [4, 10, 16], 'gid:1': [1, 7, 
13] } else: shards = { '9871d3a2c554b27151cacf1422eec048:0': [0, 6, 12], '9871d3a2c554b27151cacf1422eec048:1': [3, 9, 15], 'c4a398dea30b21551ae4c09454001dba:0': [2, 8, 14, 4, 10, 16], 'c4a398dea30b21551ae4c09454001dba:1': [5, 11, 17, 1, 7, 13] } learning_rate = 1.0 initial_accumulator_value = 0.1 beta = 0.0 ada_grad = learning_rate / math.sqrt( init_grad * init_grad + initial_accumulator_value) * init_grad n = initial_accumulator_value + init_grad * init_grad sigma = (math.sqrt(n) - math.sqrt(initial_accumulator_value)) / learning_rate z = init_grad - sigma * init_val ftrl_grad = -learning_rate * z / (math.sqrt(n) + beta) for name, value in shards.items(): tb_name, idx = name.split(':') lu_value = values[int(idx)][tb_name] for i in value: k = int(i / num_ps) if k < lu_value.shape[0]: for j, x in enumerate(lu_value[k]): if j == 0: self.assertAlmostEqual(x, ftrl_grad, delta=1e-6) else: self.assertAlmostEqual(init_val - ada_grad, x, delta=1e-6) @parameterized.parameters([(True,), (False,)]) def test_apply_gradients(self, use_native_multi_hash_table): self._test_apply_gradients(use_native_multi_hash_table, use_gpu=False) @parameterized.parameters([(True,), (False,)]) @test_util.run_gpu_only def test_apply_gradients_gpu(self, use_native_multi_hash_table): self._test_apply_gradients(use_native_multi_hash_table, use_gpu=True) @parameterized.parameters([(True,), (False,)]) @test_util.run_gpu_only def test_apply_gradients_for_gpu_emb(self, use_native_multi_hash_table): use_gpu = True enable_gpu_emb = True num_ps = 1 parser_ctx = self.get_parser_ctx(num_ps, enable_gpu_emb, use_gpu, use_native_multi_hash_table) sub_table_name_to_config = parser_ctx.sharding_sparse_fids_op_params.sub_table_name_to_config def run_once(task_name, loop, parser_ctx_, grad_list=None): servers, config = test_utils.create_test_ps_cluster(num_ps) config.share_cluster_devices_in_session = True config.experimental.share_session_state_in_clusterspec_propagation = True # grappler doesn't really understand 
RaggedTensor. config.graph_options.rewrite_options.disable_meta_optimizer = True with tf.Graph().as_default(), tf.compat.v1.Session(servers[0].target, config=config) as sess: hash_table = distributed_ps.PartitionedHashTable( num_ps, gen_native_multi_hash_table_factory(task_name) if use_native_multi_hash_table else gen_multi_table_factory(task_name), use_native_multi_hash_table=use_native_multi_hash_table, parser_ctx=parser_ctx_) hash_table._inner_data_type = 'example' if use_native_multi_hash_table: sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) resources.initialize_resources(resources.shared_resources()).run() sparse_features = self.gen_variant_tensor(batch_size=num_ps * 3) auxiliary_bundle = {} features = {} sharding_sparse_fids_with_context(sparse_features, features, parser_ctx_) ret_grad_list = defaultdict(lambda: []) apply_gradients_op = tf.no_op() for loop_idx in range(loop): with tf.control_dependencies([apply_gradients_op]): layouts = hash_table.lookup(features.copy(), auxiliary_bundle=auxiliary_bundle) layouts = sess.run(layouts) layout_grads_and_vars = [] for name in sorted(hash_table._feature_configs.out_configs): layout = layouts[name] if grad_list: grad = grad_list[name][loop_idx] else: grad = tf.random.uniform(layout.shape) grad = sess.run(grad) ret_grad_list[name].append(grad) layout_grads_and_vars.append((grad, layout)) global_step = tf.constant(0, dtype=tf.int64) apply_gradients_op = hash_table.apply_gradients( layout_grads_and_vars, global_step, auxiliary_bundle=auxiliary_bundle) with tf.control_dependencies([apply_gradients_op]): layouts = hash_table.lookup(features, auxiliary_bundle=auxiliary_bundle) layouts = sess.run(layouts) return layouts, ret_grad_list layouts, ret_grad_list = run_once('gpu_emb_', 2, parser_ctx) #logging.info(f"xx values: {layouts} {ret_grad_list}") layouts_cpu, ret_grad_list = run_once( "cpu_", 2, self.get_parser_ctx(num_ps, False, False, 
use_native_multi_hash_table), ret_grad_list) #logging.info(f"xx values: {layouts_cpu} {ret_grad_list}") assert len(layouts) == len(layouts_cpu) for name, value in layouts.items(): assert name in layouts_cpu value_2 = layouts_cpu[name] assert len(value) == len(value_2) for a, b in zip(value, value_2): assert len(a) == len(b) #print(a, b) for i in range(len(a)): self.assertAlmostEqual(a[i], b[i], delta=1e-6) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/distributed_serving_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import nullcontext
from typing import List

from absl import logging
import tensorflow as tf

from monolith.agent_service.agent_service_pb2 import ServerType
from monolith.agent_service.backends import SyncBackend
from monolith.utils import get_libops_path
from monolith.native_training.runtime.ops import gen_monolith_ops
from monolith.native_training import utils
from monolith.native_training.model_export.export_context import is_exporting_standalone
from monolith.native_training.runtime.parameter_sync import \
    parameter_sync_pb2

# Alias so call sites in this module read as "distributed serving ops"; it is
# the same generated op module object as gen_monolith_ops.
gen_distributed_serving_ops = gen_monolith_ops


def remote_predict(input_tensor_alias,
                   input_tensors,
                   output_tensor_alias,
                   model_name,
                   task,
                   old_model_name,
                   model_version=-1,
                   fail_op_on_rpc_error=True,
                   max_rpc_deadline_millis=30,
                   output_types=None,
                   name=None,
                   signature_name='serving_default'):
  """Runs a predict in remote process through rpc.

  Thin wrapper over the `tf_serving_remote_predict` op: every argument is
  forwarded unchanged and element 2 of the op's outputs is returned.

  Args:
    input_tensor_alias: input tensor alias for Predict
    input_tensors: input tensors for Predict
    output_tensor_alias: output tensor alias for Predict
    model_name: model_name that the Predict is running on. Must not be None.
    task: Parameter Server index
    old_model_name: forwarded to the op's `old_model_name` argument.
      NOTE(review): its meaning is defined by the op kernel, which is not
      visible here -- presumably a legacy name for the same model; confirm.
    model_version: the model version for the Predict call. If unset, the
      highest version available for serving will be targeted.
    fail_op_on_rpc_error: forwarded to the op. Presumably controls whether an
      rpc error fails the op itself -- TODO confirm against the kernel.
    max_rpc_deadline_millis: rpc deadline in millis
    output_types: output types for Predict
    name: name for the op in the graph
    signature_name: the signature def for remote graph inference

  Returns:
    output_tensors as a result of the Predict.

  Raises:
    ValueError: if model_name value is missing.
  """
  if model_name is None:
    raise ValueError('model_name must be specified.')
  # The op yields several outputs; only the one at index 2 is handed back to
  # the caller.
  return (gen_distributed_serving_ops.tf_serving_remote_predict(
      input_tensor_alias,
      input_tensors,
      output_tensor_alias,
      model_name=model_name,
      old_model_name=old_model_name,
      task=task,
      model_version=model_version,
      fail_op_on_rpc_error=fail_op_on_rpc_error,
      max_rpc_deadline_millis=max_rpc_deadline_millis,
      signature_name=signature_name,
      output_types=output_types,
      name=name))[2]


def create_parameter_sync_clients(ps_num: int,) -> List[tf.Tensor]:
  """Creates the parameter sync client ops, one per parameter server.

  Args:
    ps_num: number of parameter servers. When 0, a single client is created
      with the default config and an empty name suffix.

  Returns:
    A list of client tensors. For ps_num > 0 the i-th entry is created with
    name_suffix=str(i).
  """
  logging.info("Create parameter sync clients.")
  if ps_num == 0:
    return [parameter_sync_client_from_config()]

  sync_clients = list()
  for i in range(ps_num):
    ps_device_name = utils.ps_device(i)
    # No device scope is entered while exporting a standalone model; otherwise
    # the client op is created under the i-th PS device.
    with nullcontext() if is_exporting_standalone() else tf.device(
        ps_device_name):
      sync_clients.append(parameter_sync_client_from_config(name_suffix=str(i)))
  return sync_clients


def parameter_sync_client_from_config(
    config: parameter_sync_pb2.ClientConfig = None,
    name_suffix: str = "") -> tf.Tensor:
  """Builds a `MonolithParameterSyncClient` op.

  Args:
    config: client config proto. It is serialized and passed to the op; when
      it is falsy (e.g. None) an empty string is passed instead.
    name_suffix: appended to "MonolithSyncClient_" to form the op's
      shared_name.

  Returns:
    The tensor produced by the client op.
  """
  return gen_distributed_serving_ops.MonolithParameterSyncClient(
      config=config.SerializeToString() if config else '',
      shared_name="MonolithSyncClient_" + name_suffix)


def refresh_sync_config(sync_backend: SyncBackend, ps_index: int) -> bytes:
  """Builds a serialized `ClientConfig` from the backend's sync targets.

  Args:
    sync_backend: backend queried through get_sync_targets("ps_<ps_index>").
    ps_index: index of the parameter server whose replicas are requested.

  Returns:
    The serialized ClientConfig. Its model_name is the first value returned
    by the backend, signature_name is "hashtable_assign" and timeout_in_ms is
    3000.
  """
  saved_model, online_ps_replicas = sync_backend.get_sync_targets(
      f"ps_{ps_index}")
  config = parameter_sync_pb2.ClientConfig()
  # Two result shapes are handled: a list of targets, or a dict whose keys are
  # targets and whose values are appended to targets_extra_info in the same
  # order. Any other type leaves both fields empty.
  if isinstance(online_ps_replicas, list):
    config.targets.extend(online_ps_replicas)
  elif isinstance(online_ps_replicas, dict):
    for addr, target_extra_info in online_ps_replicas.items():
      config.targets.append(addr)
      config.targets_extra_info.append(target_extra_info)
  config.model_name = saved_model
  config.signature_name = "hashtable_assign"
  config.timeout_in_ms = 3000
  return config.SerializeToString()


def create_dummy_sync_client() -> tf.Tensor:
  """Builds a `MonolithDummySyncClient` op with a fixed shared_name."""
  return gen_distributed_serving_ops.MonolithDummySyncClient(
      shared_name="MonolithDummySyncClient")


def create_dummy_sync_server(address: str) -> tf.Tensor:
  """Builds a `MonolithDummySyncServer` op for the given address."""
  return gen_distributed_serving_ops.MonolithDummySyncServer(address=address)


class ParameterSyncClient(object):
  """Wrapper around a parameter sync client tensor."""

  def __init__(self, client: tf.Tensor):
    # client: tensor produced by one of the client-creating functions above.
    self._client = client

  def create_sync_op(self, config_str: tf.Tensor):
    """Returns a `monolith_parameter_sync` op for this client and config_str."""
    return gen_distributed_serving_ops.monolith_parameter_sync(
        self._client, config_str)

  def as_op(self):
    """Returns a tf.group op wrapping the client tensor."""
    return tf.group(self._client)

  @property
  def handle(self):
    """The wrapped client tensor."""
    return self._client


class DummySyncServer(object):
  """Wrapper around a dummy sync server tensor created for `address`."""

  def __init__(self, address: str):
    self._server = create_dummy_sync_server(address)

  def shutdown(self):
    """Returns the op that shuts this server down."""
    return gen_distributed_serving_ops.monolith_dummy_sync_server_shutdown(
        self._server)

  def get_port(self):
    """Returns the op that reports this server's port."""
    return gen_distributed_serving_ops.monolith_dummy_sync_server_get_port(
        self._server)

  def as_op(self):
    """Returns a tf.group op wrapping the server tensor."""
    return tf.group(self._server)

  @property
  def handle(self):
    """The wrapped server tensor."""
    return self._server

================================================
FILE: monolith/native_training/distributed_serving_ops_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import tensorflow as tf from absl import logging from monolith.agent_service import utils from monolith.agent_service import backends from monolith.agent_service.replica_manager import ReplicaWatcher from monolith.agent_service.mocked_zkclient import FakeKazooClient from monolith.agent_service.data_def import ReplicaMeta from monolith.native_training import hash_table_ops from monolith.native_training import distributed_serving_ops from monolith.native_training.distributed_serving_ops import ParameterSyncClient, \ DummySyncServer from monolith.native_training.runtime.parameter_sync import \ parameter_sync_pb2 def test_dummy_sync_server(server_num: int): return [DummySyncServer("localhost:0") for _ in range(server_num)] def test_parameter_sync_client(targets): config = parameter_sync_pb2.ClientConfig() config.targets.extend(targets) return ParameterSyncClient( distributed_serving_ops.parameter_sync_client_from_config(config=config)) def _get_id_tensor(x): return tf.constant(x, dtype=tf.int64) class ParameterSyncOpsTest(tf.test.TestCase): def test_parameter_sync_client(self): servers = test_dummy_sync_server(2) ports = [server.get_port() for server in servers] with self.session() as sess: ports = sess.run(ports) targets = ["localhost:{}".format(port[0]) for port in ports] client = test_parameter_sync_client(targets) dim = 3 hash_table = hash_table_ops.test_hash_table(dim, learning_rate=0.1, sync_client=client.handle) id_tensor = _get_id_tensor([0, 0, 1]) embeddings = hash_table.lookup(id_tensor) loss = -embeddings grads = tf.gradients(loss, embeddings) global_step = _get_id_tensor(0) hash_table = hash_table.apply_gradients(id_tensor, grads[0], global_step=global_step) new_embeddings = hash_table.lookup(_get_id_tensor([0, 1])) new_embeddings = sess.run(new_embeddings) self.assertAllClose(new_embeddings, [[0.2, 0.2, 0.2], [0.1, 0.1, 0.1]]) config = parameter_sync_pb2.ClientConfig() config.targets.extend(targets) result = 
sess.run(client.create_sync_op(config.SerializeToString())) print(json.dumps(json.loads(result[0]), indent=2)) sess.run([server.shutdown() for server in servers]) def test_refresh_sync_config_1(self): def mock_replica_watcher(ps_index: int): zk = FakeKazooClient() zk.start() config = utils.AgentConfig(bzid="demo", base_name="test_ffm_model", deploy_type='ps', replica_id=0, num_ps=10) path_prefix = f'/{config.bzid}/service/{config.base_name}' replica_path = f'{path_prefix}/ps:{ps_index}/{config.replica_id}' replica_meta = ReplicaMeta(address="localhost:8500", stat=utils.ModelState.AVAILABLE) replica_meta_bytes = bytes(replica_meta.to_json(), encoding='utf-8') zk.ensure_path(replica_path) zk.set(replica_path, replica_meta_bytes) replica_watcher = ReplicaWatcher(zk, config) replica_watcher.watch_data() return replica_watcher, zk replica_watcher, zk = mock_replica_watcher(1) config_str = distributed_serving_ops.refresh_sync_config( replica_watcher.to_sync_wrapper(), 1) config = parameter_sync_pb2.ClientConfig() config.ParseFromString(config_str) self.assertEqual(config.model_name, "ps_1") logging.info('targets: %s', config.targets) self.assertEqual(config.targets, ["localhost:8500"]) replica_watcher.stop() zk.stop() def test_refresh_sync_config_2(self): # prepare envs bd = backends.ZKBackend('demo', zk_servers='127.0.0.1:9999') bd._zk = FakeKazooClient() bd.start() container = backends.Container("default", "asdf") service_info = backends.ContainerServiceInfo(grpc="localhost:8888", http="localhost:8889", archon="localhost:8890", agent="localhost:8891", idc="lf") bd.report_service_info(container, service_info) bd.sync_available_saved_models( container, { backends.SavedModel("test_ffm_model", "ps_0"), backends.SavedModel("test_ffm_model", "ps_1"), backends.SavedModel("test_ffm_model", "ps_2"), }) # test sync targets bd.subscribe_model("test_ffm_model") config = parameter_sync_pb2.ClientConfig() config_str = distributed_serving_ops.refresh_sync_config(bd, 1) 
config.ParseFromString(config_str) self.assertEqual(config.model_name, "test_ffm_model:ps_1") self.assertEqual(config.targets, ["localhost:8888"]) bd.stop() if __name__ == "__main__": logging.set_verbosity(logging.INFO) tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/distribution_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import os from typing import List, Tuple, Optional, NamedTuple import tensorflow as tf from idl.matrix.proto.example_pb2 import FeatureConfigs from monolith.native_training.runtime.ops import gen_monolith_ops gen_distribution_ops = gen_monolith_ops def split_by_indices(indices: tf.Tensor, tensor: tf.Tensor, num_splits: int) -> tf.Tensor: """ Split |input| elements in into |num_splits| tensors based on indices. |input| is treated as a list of tensors. 
""" return gen_distribution_ops.monolith_split_by_indices(indices, tensor, num_splits) @tf.RegisterGradient("MonolithSplitByIndices") def _split_by_indices_gradient(op: tf.Operation, *grads): indices = op.inputs[0] tensor = op.inputs[1] tensor_grad = gen_distribution_ops.monolith_split_by_indices_gradient( indices, tensor, grads) return None, tensor_grad def ragged_split_by_indices( indices: tf.Tensor, num: tf.RaggedTensor, num_splits: int) -> Tuple[List[tf.RaggedTensor], List[tf.RaggedTensor]]: """Split a int64 ragged tensor into |num_splits| ragged tensor based on indices. Returns splitted ragged tensor and splitted original position of each number in ragged tensor. For example, indices = [0, 1, 0, 1] num = [[4, 3, 2], [1]] num_splits = 2 ===> [ [[4, 2], []], [[3], [1]], ], [ [[0, 2], []], [[1], [3]], ] """ splitted_num, splitted_num_splits, splitted_pos = gen_distribution_ops.monolith_ragged_split_by_indices( indices, num.values, num.row_splits, num_splits=num_splits) results = [] pos = [] for i in range(num_splits): results.append( tf.RaggedTensor.from_row_splits(splitted_num[i], splitted_num_splits[i], validate=False)) pos.append( tf.RaggedTensor.from_row_splits(splitted_pos[i], splitted_num_splits[i], validate=False)) return results, pos class _UniqueKeyWithValueAndOffsetResult(NamedTuple): unique_key: tf.RaggedTensor value_offset: tf.RaggedTensor value_buffer: tf.Tensor def unique_key_with_value_and_offset(key: tf.RaggedTensor, dims: List[int], generate_buffer=True): """Uniques the keys within each key[i], and generates the corresponding value offset map. For key[i][j], the coresponding value's length is dims[i]. unique_key - the unique result of each key[i] value_buffer - a SharedTensor represents all values concated. value_offset - a ragged tensor with ragged_rank=2. value_offset[i][j] repensents the all offsets in value_buffer for unique_key[i][j]. So if we know the value of key[i][j] is v. 
We need to fill value_buffer[value_offset[i][j][0]:value_offset[i][j]+dims[i]] = v value_buffer[value_offset[i][j][1]:value_offset[i][j]+dims[i]] = v ... For example, key = [[0, 1, 0], [0]] dims = [2, 3] => unique_key = [[0, 1], [0]] value_offset = [[[0, 4], [2]], [[6]]] value_buffer = float buffer with length 2*3 + 3*1 = 9 """ results = gen_distribution_ops.monolith_unique_key_with_value_and_offset( key.values, key.row_splits, dims=dims, generate_buffer=generate_buffer) return _UniqueKeyWithValueAndOffsetResult( unique_key=tf.RaggedTensor.from_row_splits(results[0], results[1], validate=False), value_offset=tf.RaggedTensor.from_nested_row_splits( results[2], [results[1], results[3]], validate=False), value_buffer=results[4]) def fill_with_offset_map(pos: tf.RaggedTensor, value: tf.Tensor, value_offset_map: tf.RaggedTensor, value_buffer: tf.Tensor, dims: List[int]) -> tf.Tensor: """Fill the |value| to |value_buffer| for each |pos| in |value_offset_map|. Specifically, for each pos[i][j], we extrac value slice from value (v), we got all positions for pos[i][j], which are value_offset_map.values[pos[i][j]][0], value_offset_map.values[pos[i][j]][1] ... And fill the value_buffer. 
For example, pos = [[0, 1], [2]] value = [0, 1, 2, 3, 4, 5, 6] value_offset_map = [[[0, 4], [2]], [[6]]] dims = [2, 3] => value_buffer = [0, 1, 2, 3, 0, 1, 4, 5, 6] """ value_offset_map_1d = value_offset_map.values return gen_distribution_ops.monolith_fill_with_offset_map( pos.values, pos.row_splits, value, value_offset_map_1d.values, value_offset_map_1d.row_splits, value_buffer, dims=dims, ) def fill_with_offset_map_gradient(pos: tf.RaggedTensor, grad: tf.Tensor, value_offset_map: tf.RaggedTensor, dims: List[int]) -> tf.Tensor: value_offset_map_1d = value_offset_map.values return gen_distribution_ops.monolith_fill_with_offset_map_gradient( pos.values, pos.row_splits, grad, value_offset_map_1d.values, value_offset_map_1d.row_splits, dims=dims, ) @tf.RegisterGradient("MonolithFillWithOffsetMap") def _fill_with_offset_map_gradient(op: tf.Operation, grad): value_offset_map = tf.RaggedTensor.from_nested_row_splits( op.inputs[3], [op.inputs[1], op.inputs[4]], validate=False) pos = tf.RaggedTensor.from_row_splits(op.inputs[0], op.inputs[1]) backprop_grad = fill_with_offset_map_gradient(pos, grad, value_offset_map, dims=op.get_attr("dims")) return None, None, backprop_grad, None, None, None def finalize_shared_tensor(shared_tensor_handles: List[tf.Tensor], dtype, shape): """Finalize a shared tensor and it won't be accessible in the future. shared_tensor_handles - the *same handle* which repeats several times. The reason why it is a list is to build a meaningful dependencies for output tensor, which is useful for gradient calculation. For example, t = SharedTensor() t1 = FillPart(t, data0) t2 = FillPart(t, data1) t = finalize_shared_tensor([t1, t2]) In this case, we want the grad on t can be propagated back to data0, data1. 
""" return gen_distribution_ops.monolith_finalize_shared_tensor( shared_tensor_handles, dtype=dtype, shape=shape) @tf.RegisterGradient("MonolithFinalizeSharedTensor") def _finalize_shared_tensor_gradient(op: tf.Operation, grad): return grad def reorder_by_indices(input: tf.Tensor, shard_ids: tf.Tensor, num_of_shards: int) -> List[tf.Tensor]: """ Reorder the input based on precomputed shard_ids from the caller. Example 1: input: [1, 2, 3, 2] shard_ids: [1, 0, 1, 0] num_of_shards: 2 output => [2, 3, 1] shard_sizes => [1, 2] Args: input: 1-D int64/2-D float tensor with shape [N,] shard_ids: 1-D int32 tensor with shape [N], shard_ids[i] represents the shard is for input[i, ...] num_of_shards: a int32 scalar, representing the number of shards. Returns: Output: reordered 1-D int64/2-D float tensor with shape [M,], M<=N. Shard_sizes: 1-D int32 tensor with shape [num_of_shards]. """ return gen_distribution_ops.monolith_reorder_by_indices( input, shard_ids, num_of_shards) def fused_reorder_by_indices( inputs: List[tf.Tensor], num_of_shards: int, dim_sizes: List[int]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: """ Reorder and dedup int64 values in a list of tensors according to the sharding. Note that the deduplication is applied per tensor of inputs. In other words, the dedup process does not check duplicated ints across tensors of inputs. With dim_sizes of the embedding merged slot, it maps each fid to the offset in the expected fused embedding to be generated based on the output fids. For more intuitive cases and explanations, check out the unit test cases. Examples: inputs: [[0, 1, 0], [3, 2, 3], [5, 6, 7]] num_of_shards: 2 dim_sizes: [1, 2, 3] output => [0,2,6,1,3,5,7] shard_sizes => [3,4] sharded_slot_sizes => [1,1,1,1,1,2] fused_embedding_offsets => [0,6,0,7,1,7,9,3,12] Args: inputs: List of 1-D int64 tensors. num_of_shards: a int32 scalar, representing the number of shards. Returns: output: reordered 1-D int64 tensor. 
shard_sizes: 1-D int32 tensor with shape (num_of_shards). sharded_slot_sizes: 1-D int32 tensor with shape (num_of_shards * len(inputs)). fused_embedding_offsets: 1-D int32 tensor """ # We only trigger this N-1 sharding scheme when it is beyond single host mode. rank0_empty_shard = os.environ.get('MONOLITH_SYNC_EMPTY_RANK0_PS_SHARD', '1') == '1' and num_of_shards > 4 return gen_distribution_ops.fused_reorder_by_indices(inputs, num_of_shards, dim_sizes, rank0_empty_shard) # # An Alternative Implementation based on TensorFlow Builtin Ops: # with tf.name_scope('fused_reorder_by_indicies'): # inputs = [tf.unique(ids)[0] for ids in inputs] # shard_indicies = [tf.cast(tf.math.floormod(ids, num_of_shards), dtype=tf.int32) for ids in inputs] # sharded_slot_lists = [tf.dynamic_partition(ids, indicies, num_of_shards) for ids, indicies in zip(inputs, shard_indicies)] # outputs = [] # shard_sizes = [] # sharded_slot_sizes = [] # for i in range(num_of_shards): # sizes_per_shard = [] # for m in range(len(inputs)): # outputs.append(sharded_slot_lists[m][i]) # sizes_per_shard.append(tf.size(sharded_slot_lists[m][i])) # shard_sizes.append(tf.reduce_sum(sizes_per_shard)) # sharded_slot_sizes.extend(sizes_per_shard) # output = tf.concat(outputs, axis=0) # return output, tf.convert_to_tensor(shard_sizes), tf.convert_to_tensor(sharded_slot_sizes) def map_id_to_embedding(ids: List[tf.Tensor], embeddings: List[tf.Tensor], input: tf.Tensor, use_multi_threads: bool = True) -> tf.Tensor: """ Map int64 in input to embedding. Output will have an extra dim at last which equals to embedding dim. The length of ids and embeddings must match. Args: ids: a list of 1-D int64 tensor. embeddings: a list of 2-D float32 tensor. Represents mapping. use_multi_threads: True if the caller wants to use multi-threads. """ if len(ids) != len(embeddings): raise ValueError( "ids length and embeddings lenght must match. 
{} vs {}".format( len(ids), len(embeddings))) return gen_distribution_ops.monolith_map_id_to_embedding( ids, embeddings, input, use_multi_threads=use_multi_threads) def fused_embedding_to_layout( embeddings_list: List[tf.Tensor], fid_list_row_split: List[tf.Tensor], fid_offset: tf.Tensor, feature_offset: tf.Tensor, nfl_offset: tf.Tensor, batch_size: tf.Tensor, variant_type: str, feature_cfgs: FeatureConfigs, ps_num: int, fid_list_emb_row_lenth: tf.Tensor = None, nfl_size: tf.Tensor = None, feature_size: tf.Tensor = None, fid_size: tf.Tensor = None, emb_size: tf.Tensor = None, parallel_flag: int = 0, version: int = 3, ): assert variant_type in { 'example', 'example_batch', 'examplebatch', 'instance' } variant_type = 'example_batch' if variant_type == 'examplebatch' else variant_type feature_cfgs_str = feature_cfgs.SerializeToString() N = 0 for layout, conf in feature_cfgs.out_configs.items(): N += len(conf.shape) if version != 4: assert fid_list_emb_row_lenth is None if version == 5: layout_tensors = gen_distribution_ops.monolith_embedding_to_layout_v5( embeddings_list=embeddings_list, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, nfl_size=nfl_size, feature_size=feature_size, fid_size=fid_size, emb_size=emb_size, num_out=N, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=parallel_flag) elif version == 4: assert fid_list_emb_row_lenth is not None layout_tensors = gen_distribution_ops.monolith_embedding_to_layout_v4( embeddings_list=embeddings_list, fid_list_emb_row_lenth=fid_list_emb_row_lenth, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, num_out=N, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=parallel_flag) elif version == 3: layout_tensors = gen_distribution_ops.monolith_embedding_to_layout_v3( embeddings_list=embeddings_list, fid_offset=fid_offset, 
feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, num_out=N, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=parallel_flag) elif version == 2: layout_tensors = gen_distribution_ops.monolith_embedding_to_layout_v2( embeddings_list=embeddings_list, fid_list_row_split=fid_list_row_split, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, num_out=N, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=parallel_flag) else: layout_tensors = gen_distribution_ops.monolith_embedding_to_layout( embeddings_list=embeddings_list, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, num_out=N, variant_type=variant_type, feature_cfgs=feature_cfgs_str) return layout_tensors @tf.RegisterGradient("MonolithEmbeddingToLayout") def _fused_embedding_to_layout_grad(op: tf.Operation, *grads): M = op.get_attr("M") # fid_num embeddings_list = op.inputs[0:M] fid_offset, feature_offset, nfl_offset, batch_size = op.inputs[M], op.inputs[ M + 1], op.inputs[M + 2], op.inputs[M + 3] variant_type = op.get_attr("variant_type") feature_cfgs_str = op.get_attr("feature_cfgs") embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad( embeddings_list=embeddings_list, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=grads, variant_type=variant_type, feature_cfgs=feature_cfgs_str) return embeddings_grad_list + [None] * 4 @tf.RegisterGradient("MonolithEmbeddingToLayoutV2") def _fused_embedding_to_layout_grad_v2(op: tf.Operation, *grads): M = op.get_attr("M") # fid_num pre = 0 embeddings_list = op.inputs[0:M] pre += M fid_list_row_split = op.inputs[pre:pre + M] pre += M fid_offset, feature_offset, nfl_offset, batch_size = op.inputs[ pre], op.inputs[pre + 1], op.inputs[pre + 2], op.inputs[pre + 3] variant_type = 
op.get_attr("variant_type") feature_cfgs_str = op.get_attr("feature_cfgs") ps_num = op.get_attr("ps_num") embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad_v2( embeddings_list=embeddings_list, fid_list_row_split=fid_list_row_split, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=grads, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=0) return embeddings_grad_list + [None] * (M + 4) @tf.RegisterGradient("MonolithEmbeddingToLayoutV3") def _fused_embedding_to_layout_grad_v3(op: tf.Operation, *grads): M = op.get_attr("M") # fid_num pre = 0 embeddings_list = op.inputs[0:M] pre += M fid_offset, feature_offset, nfl_offset, batch_size = op.inputs[ pre], op.inputs[pre + 1], op.inputs[pre + 2], op.inputs[pre + 3] variant_type = op.get_attr("variant_type") feature_cfgs_str = op.get_attr("feature_cfgs") ps_num = op.get_attr("ps_num") embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad_v3( embeddings_list=embeddings_list, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=grads, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=0) return embeddings_grad_list + [None] * 4 @tf.RegisterGradient("MonolithEmbeddingToLayoutV4") def _fused_embedding_to_layout_grad_v4(op: tf.Operation, *grads): M = op.get_attr("M") # fid_num pre = 0 embeddings_list = op.inputs[0:M] pre += M fid_list_emb_row_lenth = op.inputs[pre] pre += 1 fid_offset, feature_offset, nfl_offset, batch_size = op.inputs[ pre], op.inputs[pre + 1], op.inputs[pre + 2], op.inputs[pre + 3] variant_type = op.get_attr("variant_type") feature_cfgs_str = op.get_attr("feature_cfgs") ps_num = op.get_attr("ps_num") embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad_v4( embeddings_list=embeddings_list, fid_list_emb_row_lenth=fid_list_emb_row_lenth, 
fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=grads, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=0) return embeddings_grad_list + [None] * 5 @tf.RegisterGradient("MonolithEmbeddingToLayoutV5") def _fused_embedding_to_layout_grad_v5(op: tf.Operation, *grads): M = op.get_attr("M") # fid_num pre = 0 embeddings_list = op.inputs[0:M] pre += M fid_offset, feature_offset, nfl_offset, batch_size = op.inputs[ pre], op.inputs[pre + 1], op.inputs[pre + 2], op.inputs[pre + 3] variant_type = op.get_attr("variant_type") feature_cfgs_str = op.get_attr("feature_cfgs") ps_num = op.get_attr("ps_num") embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad_v3( embeddings_list=embeddings_list, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=grads, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=0) return embeddings_grad_list + [None] * 8 def fused_embedding_to_layout_grad( nfl_offset: tf.Tensor, feature_offset: tf.Tensor, fid_offset: tf.Tensor, batch_size: tf.Tensor, embeddings_list: List[tf.Tensor], fid_list_row_split: List[tf.Tensor], layout_tensors_grad: List[tf.Tensor], variant_type: str, feature_cfgs: FeatureConfigs, ps_num: int, fid_list_emb_row_lenth: tf.Tensor = None, layout_tensors_grad_scale=None, parallel_flag=0, version: int = 3, ) -> List[tf.Tensor]: feature_cfgs_str = feature_cfgs.SerializeToString() assert variant_type in { 'example', 'example_batch', 'examplebatch', 'instance' } variant_type = 'example_batch' if variant_type == 'examplebatch' else variant_type if layout_tensors_grad_scale is not None: logging.info( f"fused_embedding_to_layout_grad use layout_tensors_grad_scale") #TODO fuse layout_tensors_grad_scale to op layout_tensors_grad = [ layout_tensors_grad_scale * grad for grad in layout_tensors_grad ] if version != 4: assert 
fid_list_emb_row_lenth is None if version == 4: assert fid_list_emb_row_lenth is not None embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad_v4( embeddings_list=embeddings_list, fid_list_emb_row_lenth=fid_list_emb_row_lenth, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=layout_tensors_grad, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=parallel_flag) elif version == 3 or version == 5: embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad_v3( embeddings_list=embeddings_list, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=layout_tensors_grad, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=parallel_flag) elif version == 2: embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad_v2( embeddings_list=embeddings_list, fid_list_row_split=fid_list_row_split, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=layout_tensors_grad, variant_type=variant_type, feature_cfgs=feature_cfgs_str, ps_num=ps_num, parallel_flag=parallel_flag) else: embeddings_grad_list = gen_distribution_ops.monolith_embedding_to_layout_grad( embeddings_list=embeddings_list, fid_offset=fid_offset, feature_offset=feature_offset, nfl_offset=nfl_offset, batch_size=batch_size, tensors_grad=layout_tensors_grad, variant_type=variant_type, feature_cfgs=feature_cfgs_str) return embeddings_grad_list @tf.RegisterGradient("MonolithMapIdToEmbedding") def _map_id_to_embedding_gradient(op: tf.Operation, grads: tf.Tensor): num_splits = op.get_attr("num_splits") ids = [op.inputs[i] for i in range(num_splits)] input = op.inputs[2 * num_splits] embedding_grads = gen_distribution_ops.monolith_map_id_to_embedding_gradient( ids, input, grads) return [None] * num_splits + 
embedding_grads + [None] def map_id_to_embedding_gradient_back_prop(ids: tf.Tensor, input: tf.Tensor, grads: tf.Tensor): """ The manual back prop for MonolithMapIdToEmbedding. Returns: output: A list of 2-D tensors [K, dim], sum(K)=N """ embedding_grads = gen_distribution_ops.monolith_map_id_to_embedding_gradient( ids, input, grads) return embedding_grads def gather_embeddings_by_input(ids: tf.Tensor, embeddings: tf.Tensor, input: tf.Tensor, use_multi_threads: bool = False) -> tf.Tensor: """ Gather embeddings based on input with a shape [N] and an ids:embeddings map. The ids with a shape [M] is mapped element-wise to embeddings with a shape [M, dim], e.g., for any index i, ids(i)'s embedding is embeddings(i). Example: ids: [1, 2, 3] embeddings: [[1., 1.], [2., 2.], [3., 3.]] input: [1, 3, 2, 3] output=>[[1., 1.], [3., 3.], [2., 2.], [3., 3.]] index_mapping=>[0, 2, 1, 2] Args: ids: a 1-D int64 tensor [M]. embeddings: a 2-D float32 tensor [M, dim]. Mapped in order with ids. input: a int32 tensor with shape [N], N >= M. Input value is range from 0 to M-1. Returns: output: a 2-D tensor [N, dim]. index_mapping: a 1-D tensor [N]. 
""" return gen_distribution_ops.monolith_gather_embeddings_by_input( ids, embeddings, input, use_multi_threads=use_multi_threads) @tf.RegisterGradient("MonolithGatherEmbeddingsByInput") def _gather_embeddings_by_ids_gradient( op: tf.Operation, grads: tf.Tensor, index_mapping_grads: Optional[tf.Tensor]): ids = op.inputs[0] index_mapping = op.outputs[1] embedding_grads = gen_distribution_ops.monolith_gather_embeddings_by_input_gradient( ids, grads, index_mapping) return [None, embedding_grads, None] def fused_gather_embeddings_by_input( fused_embeddings: tf.Tensor, fused_embedding_offsets: List[tf.Tensor], embedding_dims: List[int]) -> List[tf.Tensor]: return gen_distribution_ops.monolith_fused_gather_embeddings_by_input( fused_embeddings, fused_embedding_offsets, embedding_dims=embedding_dims) def fused_gather_embeddings_by_input_gradient( fused_embeddings: tf.Tensor, grads: List[tf.Tensor], embedding_offsets: List[tf.Tensor], embedding_dims: List[int], scale = 1 ) -> tf.Tensor: return gen_distribution_ops.monolith_fused_gather_embeddings_by_input_gradient( fused_embeddings, grads, embedding_offsets, embedding_dims=embedding_dims, scale=scale) def reduce_mean(id_indices: tf.Tensor, id_values: tf.Tensor, id_length: tf.Tensor, name: str = None): """ Very similar to tf.sparse.reduce_mean. The difference is now id_values is a 2-D tensors instead of 1-D tensor. Args: id_indices: 2-D tensor represents a list of positions of id_values. id_values: 2-D tensor which represents a list of actual values. (Value is 1-D tensor) id_length: should be a shape which equals to [batch_size] """ return gen_distribution_ops.monolith_reduce_mean(id_indices, id_values, id_length, name=name) def gather_embeddings_by_ids_gradient_back_prop(ids: tf.Tensor, grads: tf.Tensor, index_mapping: tf.Tensor): """ The manual back prop for MonolithGatherEmbeddingsByInput. Returns: output: a 2-D tensor [N, dim]. 
""" embedding_grads = gen_distribution_ops.monolith_gather_embeddings_by_input_gradient( ids, grads, index_mapping) return embedding_grads @tf.RegisterGradient("MonolithReduceMean") def _reduce_mean_gradient(op: tf.Operation, grads: tf.Tensor): id_indices = op.inputs[0] id_value_grads = gen_distribution_ops.monolith_reduce_mean_gradient( id_indices, grads) return None, id_value_grads, None def reduce_sum(id_indices: tf.Tensor, id_values: tf.Tensor, id_length: tf.Tensor, name=None): """ Very similar to tf.sparse.reduce_sum. The difference is now id_values is a 2-D tensors instead of 1-D tensor. Args: id_indices: 2-D tensor represents a list of positions of id_values. id_values: 2-D tensor which represents a list of actual values. (Value is 1-D tensor) id_length: should be a shape which equals to [batch_size] """ return gen_distribution_ops.monolith_reduce_sum(id_indices, id_values, id_length, name=name) @tf.RegisterGradient("MonolithReduceSum") def _reduce_sum_gradient(op: tf.Operation, grads: tf.Tensor): id_indices = op.inputs[0] id_value_grads = gen_distribution_ops.monolith_reduce_sum_gradient( id_indices, grads) return None, id_value_grads, None def reduce_sqrtn(id_indices: tf.Tensor, id_values: tf.Tensor, id_length: tf.Tensor): """ Very similar to the combiner method in tf.tpu.experimental.embedding.TPUEmbedding The input is a Args: id_indices: 2-D tensor represents a list of positions of id_values. id_values: 2-D tensor which represents a list of actual values. 
(Value is 1-D tensor) id_length: should be a shape which equals to [batch_size] """ return gen_distribution_ops.monolith_reduce_square_norm( id_indices, id_values, id_length) @tf.RegisterGradient("MonolithReduceSquareNorm") def _reduce_sum_gradient(op: tf.Operation, grads: tf.Tensor): id_indices = op.inputs[0] id_values = op.inputs[1] id_value_grads = gen_distribution_ops.monolith_reduce_square_norm_gradient( id_indices, id_values, grads) return None, id_value_grads, None def fused_sorted_segment_sum(indices: List[tf.Tensor], values: List[tf.Tensor], shapes: List[tf.Tensor]): """ It combines multiple segment_sum into one GPU kernel. Args: indicies: List of Indices a.k.s 1-D SORTED segment ids values: List of Values to scatter into the output tensor. shapes: List of Shapes that Must have the same type as indices. Output: reduced: output tensors, i-the tensor has a shape `shapes[i]`. """ return gen_distribution_ops.monolith_fused_segment_sum( indices, values, shapes) @tf.RegisterGradient("MonolithFusedSegmentSum") def _FusedSegmentSumGrad(op, *grads): n = len(grads) updates_grads = [ tf.gather_nd(grad, tf.expand_dims(indices, -1)) # Similar to fused ScatterNd for grad, indices in zip(grads, op.inputs[:n]) ] return [None] * n + updates_grads + [None] * n def fused_reduce_sum_and_split(id_indices: tf.Tensor, id_values: tf.Tensor, id_length: tf.Tensor, split_dims: List[int], name: str = None): """ Very similar to tf.sparse.reduce_sum. It combines with a fused split op. Args: id_indices: 1-D tensor represents a list of positions of id_values. id_values: 2-D tensor which represents a list of actual values. (Value is 1-D tensor) id_length: should be a shape which equals to "batch_size" split_dims: dimensions for the split vectors. Sum(split_dims)=id_values.dim(1) Output: reduced: M output tensors, and i-th tensor has a shape [bs, split_dims[i]]. 
""" id_indices = tf.expand_dims(id_indices, -1) id_length = tf.cast(tf.expand_dims(id_length, 0), dtype=tf.int64) # To remove cast support int32 for cpu num_of_splits = len(split_dims) return gen_distribution_ops.monolith_fused_reduce_sum_and_split(id_indices, id_values, id_length, num_of_splits, split_dims, name=name) @tf.RegisterGradient("MonolithFusedReduceSumAndSplit") def _fused_reduce_sum_and_split_gradient(op: tf.Operation, *grads): id_indices = op.inputs[0] split_dims = op.get_attr("split_dims") id_value_grads = gen_distribution_ops.monolith_fused_reduce_sum_and_split_gradient( id_indices, grads, split_dims=split_dims) return None, id_value_grads, None def fused_reduce_and_split_gpu(splits: List[tf.Tensor], embeddings: List[tf.Tensor], slice_dims: List[List[int]], name: str = None) -> List[tf.Tensor]: """ Output: Args: splits: list of N 'row_splits' attribute of fid ragged tensors embeddings: list of N embeddings slice_dims: list of N slice_dims Output: reduced: M output tensors, and i-th tensor has a shape [bs, flat_slice_dims[i]]. 
where flat_slice_dims=concat(slice_dims), and M=len(flat_slice_dims) """ flat_slice_dims = [] row_split_splits = [] row_split_idx = 0 for i in range(len(slice_dims)): s = slice_dims[i] flat_slice_dims.extend(s) row_split_splits.append(row_split_idx) row_split_idx += splits[i].shape[0] row_split_splits.append(row_split_idx) with tf.device("/device:CPU:0"): fused_splits = tf.cast(tf.concat(splits, 0), tf.int32) return gen_distribution_ops.monolith_fused_reduce_and_split_gpu( fused_splits, embeddings, slice_dims=flat_slice_dims, num_slices=len(flat_slice_dims), row_split_splits=row_split_splits, name=name) @tf.RegisterGradient("MonolithFusedReduceAndSplitGPU") def _fused_reduce_and_split_gpu_grad(op: tf.Operation, *grads): row_split_splits = op.get_attr('row_split_splits') slice_dims = op.get_attr('slice_dims') return [None] + gen_distribution_ops.monolith_fused_reduce_and_split_gpu_grad( op.inputs[0], op.inputs[1:len(row_split_splits)], grads, row_split_splits=row_split_splits, slice_dims=slice_dims ) def normalize_merged_split(row_split: tf.Tensor, row_split_size: tf.Tensor) -> tf.Tensor: return gen_distribution_ops.monolith_normalize_merged_split(row_split, row_split_size) ================================================ FILE: monolith/native_training/distribution_ops_benchmark.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import shutil import time import tensorflow as tf from monolith.native_training import distribution_ops class DistributionOpsBenchmarkTest(tf.test.TestCase): def map_id_to_embedding(self, use_multi_threads): log_dir = "/tmp/distribution_ops_benchmark/map_id_to_embedding{}".format( "_multi_threads" if use_multi_threads else "") if os.path.exists(log_dir): shutil.rmtree(log_dir) options = tf.profiler.experimental.ProfilerOptions(host_tracer_level=3, python_tracer_level=0, device_tracer_level=0) tf.profiler.experimental.start(log_dir, options=options) num_elements, dim, ps_num = 1000000, 16, 10 ids = tf.constant([x for x in range(num_elements)], dtype=tf.int64) embeddings = tf.constant( [[x for x in range(dim)] for _ in range(num_elements)], dtype=tf.float32) indices = tf.math.floormod(ids, ps_num) split_ids = distribution_ops.split_by_indices(indices, ids, ps_num) split_embeddings = distribution_ops.split_by_indices( indices, embeddings, ps_num) embeddings_mapped = distribution_ops.map_id_to_embedding( split_ids, split_embeddings, ids, use_multi_threads=use_multi_threads) self.assertAllEqual(embeddings, embeddings_mapped) tf.profiler.experimental.stop() def test_gather_embeddings_by_ids_basic(self): num_features = 100000 with tf.compat.v1.Session() as sess: embeddings = tf.ones([num_features, 32]) id_tensor = tf.constant([x for x in range(num_features)], dtype=tf.int64) input = tf.constant([[y % num_features for y in range(x, x + 4)] for x in range(num_features)], dtype=tf.int64) output = distribution_ops.gather_embeddings_by_input( id_tensor, embeddings, input) start = time.time() output = sess.run(output) total_wall_time = time.time() - start print('wall time: {}'.format(total_wall_time)) with tf.compat.v1.Session() as sess: embeddings = tf.ones([num_features, 256]) id_tensor = tf.constant([x for x in range(num_features)], dtype=tf.int64) input = tf.constant([[y % num_features for y in range(x, x + 2)] for x in range(num_features)], dtype=tf.int64) output 
= distribution_ops.gather_embeddings_by_input( id_tensor, embeddings, input) start = time.time() output = sess.run(output) total_wall_time = time.time() - start print('wall time: {}'.format(total_wall_time)) def test_gather_embeddings_by_ids_multi_threads(self): num_features = 100000 with tf.compat.v1.Session() as sess: embeddings = tf.ones([num_features, 32]) id_tensor = tf.constant([x for x in range(num_features)], dtype=tf.int64) input = tf.constant([[y % num_features for y in range(x, x + 4)] for x in range(num_features)], dtype=tf.int64) output = distribution_ops.gather_embeddings_by_input( id_tensor, embeddings, input, use_multi_threads=True) start = time.time() output = sess.run(output) total_wall_time = time.time() - start print('wall time: {}'.format(total_wall_time)) with tf.compat.v1.Session() as sess: embeddings = tf.ones([num_features, 256]) id_tensor = tf.constant([x for x in range(num_features)], dtype=tf.int64) input = tf.constant([[y % num_features for y in range(x, x + 2)] for x in range(num_features)], dtype=tf.int64) output = distribution_ops.gather_embeddings_by_input( id_tensor, embeddings, input, use_multi_threads=True) start = time.time() output = sess.run(output) total_wall_time = time.time() - start print('wall time: {}'.format(total_wall_time)) def test_map_id_to_embedding(self): self.map_id_to_embedding(False) def test_map_id_to_embedding_multi_threads(self): self.map_id_to_embedding(True) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/distribution_ops_fused_benchmark.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import shutil import time import numpy as np import tensorflow as tf from tensorflow.python.framework import ops from monolith.native_training import distribution_ops def run_fused_reorder_by_indicies(suffices=None): # Generate num_slots of lists of int64, where number of unique ids is around num_ids num_ids = int(1e6) num_slots = 30 num_of_shards = 256 int64 = np.iinfo(np.int64) ids = list( set(np.random.randint(low=int64.min, high=int64.max + 1, size=num_ids))) split_indicies = [0] + sorted(np.random.choice(num_ids, num_slots)) ids_list = [] for i in range(1, len(split_indicies)): slot_ids = ids[split_indicies[i - 1]:split_indicies[i]] slot_ids = np.concatenate([slot_ids, slot_ids]) # force dups np.random.shuffle(slot_ids) ids_list.append(slot_ids) # input: ids_list session_config = tf.compat.v1.ConfigProto() session_config.graph_options.rewrite_options.disable_meta_optimizer = False session_config.graph_options.rewrite_options.memory_optimization = 1 session_config.intra_op_parallelism_threads = 4 with tf.compat.v1.Session(config=session_config) as sess: ids_list = [ops.convert_to_tensor(ids, dtype=tf.int64) for ids in ids_list] reorder_op = distribution_ops.fused_reorder_by_indices( ids_list, num_of_shards=num_of_shards) start = time.time() _ = sess.run(reorder_op) return time.time() - start if __name__ == "__main__": tf.compat.v1.disable_eager_execution() # np.random.seed(1234) print('> Sess.run Wall Time:', np.average([run_fused_reorder_by_indicies() for _ in range(5)])) ================================================ FILE: 
monolith/native_training/distribution_ops_fused_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import tensorflow as tf
import itertools

from monolith.native_training import distribution_ops


class DistributionOpsTest(tf.test.TestCase):
  """Tests for distribution_ops.fused_reorder_by_indices."""

  def test_benchmark(self):
    """Smoke test: runs the op once on 128 tables of 10240 random ids."""
    num_tables = 128
    with tf.compat.v1.Session() as sess:
      ids_list = [tf.random.uniform((10240,), 0, 1<<31, dtype=tf.int64) for i in range(num_tables)]
      reorder_op = distribution_ops.fused_reorder_by_indices(
          ids_list, num_of_shards=8, dim_sizes=[16 for _ in range(len(ids_list))])
      sess.run(reorder_op)

  def _test_fused_reorder_by_indices(self,
                                     ids_list,
                                     shard_num,
                                     expected_output,
                                     expected_split_sizes,
                                     expected_sharded_slot_sizes,
                                     dim_sizes=None,
                                     expected_embedding_offsets=None):
    """Runs fused_reorder_by_indices and compares its outputs.

    Args:
      ids_list: one list (or tensor) of int64 ids per merged slot.
      shard_num: passed as num_of_shards.
      expected_output: expected first output of the op.
      expected_split_sizes: expected second output of the op.
      expected_sharded_slot_sizes: expected third output of the op.
      dim_sizes: per-slot dims; defaults to 2 for every slot.
      expected_embedding_offsets: optional per-slot nested list; flattened
        before being compared with the fifth output of the op. The fourth
        output is not checked.
    """
    if dim_sizes is None:
      # Fake dim_sizes for testing
      dim_sizes = [2 for _ in range(len(ids_list))]
    with tf.compat.v1.Session() as sess:
      ids_list = [tf.convert_to_tensor(ids, dtype=tf.int64) for ids in ids_list]
      reorder_op = distribution_ops.fused_reorder_by_indices(
          ids_list, num_of_shards=shard_num, dim_sizes=dim_sizes)
      print('>>>', reorder_op[4])
      output, split_sizes, sharded_slot_sizes, _, embedding_offsets = sess.run(
          reorder_op)
      self.assertAllEqual(output, expected_output)
      self.assertAllEqual(split_sizes, expected_split_sizes)
      self.assertAllEqual(sharded_slot_sizes, expected_sharded_slot_sizes)
      # Skipped when no offsets were supplied (None or empty).
      if expected_embedding_offsets:
        self.assertAllEqual(embedding_offsets,
                            list(itertools.chain(*expected_embedding_offsets)))

  def test_fused_reorder_by_indices(self):
    """Table of cases; the expected values are consistent with sharding each
    id by `id % shard_num` and dropping duplicates within a slot."""
    # ids_list, shard_num
    # expected_output, expected_split_sizes, expected_sharded_slot_sizes
    self._test_fused_reorder_by_indices(
        # Fallback to original reorder_by_indices,
        # but keeping the inner-merged-slot order
        [[0, 1, 2, 2, 3, 5]],
        3,
        [0, 3, 1, 2, 5],
        [2, 1, 2],
        [2, 1, 2])
    self._test_fused_reorder_by_indices(
        # Extra slot
        [[0, 1, 2, 2, 3, 5], []],
        3,
        [0, 3, 1, 2, 5],
        [2, 1, 2],
        [2, 0, 1, 0, 2, 0])
    self._test_fused_reorder_by_indices(
        # plus 2*shard_num
        [[0, 1, 2, 2, 3, 5], [6, 7, 8, 8, 9, 11]],
        3,
        [0, 3, 6, 9, 1, 7, 2, 5, 8, 11],
        [4, 2, 4],
        [2, 2, 1, 1, 2, 2])
    self._test_fused_reorder_by_indices(
        # Empty slots
        [[], []],
        2,
        [],
        [0, 0],
        [0, 0, 0, 0])
    self._test_fused_reorder_by_indices([[0, 1, 4, 5], [2, 3, 6, 7]], 2,
                                        [0, 4, 2, 6, 1, 5, 3, 7], [4, 4],
                                        [2, 2, 2, 2])
    self._test_fused_reorder_by_indices([[0, 1, 0], [3, 2, 3], [5, 6, 7]],
                                        2, [0, 2, 6, 1, 3, 5, 7], [3, 4],
                                        [1, 1, 1, 1, 1, 2],
                                        dim_sizes=[1, 2, 3],
                                        expected_embedding_offsets=[[0, 6, 0],
                                                                    [7, 1, 7],
                                                                    [9, 3, 12]])
    self._test_fused_reorder_by_indices(
        # Imagine the expected fused_embeddings as follows:
        # [1.1, 1.2, 1.3, # 3 # slot 0, dim 3, offset 0
        # 2.1, 2.2, # 6 # slot 1, dim 2, offset 3
        # 3.1, 3.2, 3.3, # 1 # slot 0, dim 3, offset 5
        # 4.1, 4.2, 4.3, # 7 # slot 0, dim 3, offset 8
        # 5.1, 5.2, # 4 # slot 1, dim 2, offset 11
        # 6.1, 6.2, 6.3, # 2 # slot 0, dim 3, offset 13
        # 7.1, 7.2, # 5 # slot 1, dim 2, offset 16
        # 8.1, 8.2, # 8 # slot 1, dim 2, offset 18
        # 9.1, 9.2], # 11 # slot 1, dim 2, offset 20
        [[2, 3, 1, 2, 7, 2], [5, 8, 4, 4, 5, 11, 6]],
        3,
        [3, 6, 1, 7, 4, 2, 5, 8, 11],
        [2, 3, 4],
        [1, 1, 2, 1, 1, 3],
        dim_sizes=[3, 2],
        expected_embedding_offsets=[[13, 0, 5, 13, 8, 13],
                                    [16, 18, 11, 11, 16, 20, 3]])

  def test_ragged_tensor_workflow(self):
    """Feeds the concatenated `.values` of ragged tensors as merged slots."""
    with tf.Graph().as_default():
      a = tf.RaggedTensor.from_tensor(tf.constant([[0], [1]], dtype=tf.int64))
      b = tf.RaggedTensor.from_tensor(tf.constant([[2], [3]], dtype=tf.int64))
      c = tf.RaggedTensor.from_tensor(tf.constant([[4], [5]], dtype=tf.int64))
      d = tf.RaggedTensor.from_tensor(tf.constant([[6], [7]], dtype=tf.int64))
      # Currently for merged slots A, B
      # the order ['A', 'B'] is based on merged_slot_to_config;
      # the mapping is based on MergedMultiTypeHashTable._slot_mapping: {'a': 'A', 'b': 'B', 'c': 'A', 'd', 'B'}
      merged_slot_values = [
          tf.concat([a.values, c.values], 0),
          tf.concat([b.values, d.values], 0)
      ]
      self._test_fused_reorder_by_indices(merged_slot_values, 2,
                                          [0, 4, 2, 6, 1, 5, 3, 7], [4, 4],
                                          [2, 2, 2, 2])


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()


================================================
FILE: monolith/native_training/distribution_ops_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import tensorflow as tf from tensorflow.python.framework import test_util from monolith.native_training import distribution_ops import random class DistributionOpsTest(tf.test.TestCase): def test_split_by_indices(self): with tf.compat.v1.Session() as sess: ids = tf.constant([0, 1, 2, 2, 3], dtype=tf.int64) indices = tf.math.floormod(ids, 3) splits = distribution_ops.split_by_indices(indices, ids, num_splits=3) splits = sess.run(splits) expected_splits = [[0, 3], [1], [2, 2]] for split, expected_split in zip(splits, expected_splits): self.assertAllEqual(split, expected_split) def test_reorder_by_indices(self): with tf.compat.v1.Session() as sess: ids = tf.constant([0, 1, 2, 2, 3, 5], dtype=tf.int64) indices = tf.cast(tf.math.floormod(ids, 3), dtype=tf.int32) reorder_op = distribution_ops.reorder_by_indices(ids, indices, num_of_shards=3) output, split_sizes = sess.run(reorder_op) expected_output = [3, 0, 1, 5, 2] expected_split_sizes = [2, 1, 2] self.assertAllEqual(output, expected_output) self.assertAllEqual(split_sizes, expected_split_sizes) def test_split_by_indices_gradient(self): with self.session() as sess: indices = tf.constant([0, 1, 0], dtype=tf.int64) tensor = tf.constant([[0, 0], [1, 1], [2, 2]], dtype=tf.float32) splits = distribution_ops.split_by_indices(indices, tensor, num_splits=3) grad = tf.gradients(splits, tensor)[0] grad = sess.run(grad) self.assertAllEqual(grad, [[1, 1], [1, 1], [1, 1]]) def test_split_by_indices_empty_gradient(self): with self.session() as sess: indices = tf.constant([], dtype=tf.int64) tensor = tf.constant([], dtype=tf.float32) splits = distribution_ops.split_by_indices(indices, tensor, num_splits=3) grad, = tf.gradients(splits, tensor) grad = sess.run(grad) self.assertAllEqual(grad, []) def test_ragged_split_by_indices(self): with self.session() as sess: indices = tf.constant([0, 1, 0, 1], dtype=tf.int64) num = tf.ragged.constant([[], [], [4, 3, 2], [1], [], []], dtype=tf.int64) splits, pos = 
distribution_ops.ragged_split_by_indices(indices, num, num_splits=2) splits, pos = sess.run([splits, pos]) expected_splits = ( [[], [], [4, 2], [], [], []], [[], [], [3], [1], [], []], ) for split, expected_split in zip(splits, expected_splits): self.assertAllEqual(split, expected_split) expected_pos = ( [[], [], [0, 2], [], [], []], [[], [], [1], [3], [], []], ) for p1, p2 in zip(pos, expected_pos): self.assertAllEqual(p1, p2) def test_unique_key_with_value_and_offset_and_fill_with_offset_map(self): key = tf.ragged.constant([[], [0, 1, 2, 1, 0], [0, 1, 0], []], dtype=tf.int64) dims = [1, 2, 3, 4] result = distribution_ops.unique_key_with_value_and_offset(key, dims) self.assertAllEqual(result.unique_key, [[], [0, 1, 2], [0, 1], []]) self.assertAllEqual(result.value_offset, [[], [[0, 8], [2, 6], [4]], [[10, 16], [13]], []]) value = tf.range(12, dtype=tf.float32) filled_tensor = distribution_ops.fill_with_offset_map( tf.ragged.constant([[], [0, 1, 2], [3, 4], []], dtype=tf.int64), value, result.value_offset, result.value_buffer, dims) buffer = distribution_ops.finalize_shared_tensor([filled_tensor], dtype=tf.float32, shape=[None]) self.assertAllEqual( buffer, [0, 1, 2, 3, 4, 5, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 6, 7, 8]) grad, = tf.gradients([buffer], [value], [tf.range(19, dtype=tf.float32)]) self.assertAllEqual(grad, [8, 10, 8, 10, 4, 5, 26, 28, 30, 13, 14, 15]) def test_fill_with_offset_map_error_case(self): key = tf.ragged.constant([[], [0, 1, 2, 1, 0], [0, 1, 0], []], dtype=tf.int64) dims = [1, 2, 3, 4] result = distribution_ops.unique_key_with_value_and_offset(key, dims) value = tf.range(10, dtype=tf.float32) # expected size: 12 filled_tensor = distribution_ops.fill_with_offset_map( tf.ragged.constant([[], [0, 1, 2], [3, 4], []], dtype=tf.int64), value, result.value_offset, result.value_buffer, dims) with self.assertRaises(tf.errors.InvalidArgumentError): self.evaluate(filled_tensor) def test_unique_key_with_value_and_offset_empty(self): key = 
tf.ragged.constant([[], [], []], dtype=tf.int64) result = distribution_ops.unique_key_with_value_and_offset(key, [1, 2, 3]) self.assertAllEqual(result.unique_key, [[], [], []]) self.assertAllEqual(result.value_offset, [[], [], []]) def test_map_id_to_embedding(self): with tf.compat.v1.Session() as sess: ids1 = tf.constant([1], dtype=tf.int64) embeddings1 = tf.constant([[1, 1]], dtype=tf.float32) ids2 = tf.constant([2], dtype=tf.int64) embeddings2 = tf.constant([[2, 2]], dtype=tf.float32) input = tf.constant([[1], [2]], dtype=tf.int64) output = distribution_ops.map_id_to_embedding([ids1, ids2], [embeddings1, embeddings2], input, use_multi_threads=False) output = sess.run(output) self.assertAllEqual(output, [[[1, 1]], [[2, 2]]]) def test_map_id_to_embedding_multi_threads(self): with tf.compat.v1.Session() as sess: num_elements, dim, ps_num = 1000, 16, 10 ids = tf.constant([x for x in range(num_elements)], dtype=tf.int64) embeddings = tf.constant( [[x for x in range(dim)] for _ in range(num_elements)], dtype=tf.float32) indices = tf.math.floormod(ids, ps_num) split_ids = distribution_ops.split_by_indices(indices, ids, ps_num) split_embeddings = distribution_ops.split_by_indices( indices, embeddings, ps_num) embeddings_mapped = distribution_ops.map_id_to_embedding( split_ids, split_embeddings, ids, use_multi_threads=True) embeddings = sess.run(embeddings) embeddings_mapped = sess.run(embeddings_mapped) self.assertAllEqual(embeddings, embeddings_mapped) def test_map_id_to_embedding_gradient(self): with self.session() as sess: ids1 = tf.constant([1], dtype=tf.int64) embeddings1 = tf.constant([[0, 0]], dtype=tf.float32) ids2 = tf.constant([2], dtype=tf.int64) embeddings2 = tf.constant([[0, 0]], dtype=tf.float32) input = tf.constant([1, 1, 2], dtype=tf.int64) output = distribution_ops.map_id_to_embedding([ids1, ids2], [embeddings1, embeddings2], input, use_multi_threads=False) target_output = tf.constant([[2, 2], [2, 2], [2, 2]], dtype=tf.float32) loss = target_output - 
output grads = tf.gradients(loss, [embeddings1, embeddings2]) grads = sess.run(grads) expected_grads = [[[-2, -2]], [[-1, -1]]] for grads_part, expexted_grads_part in zip(grads, expected_grads): self.assertAllEqual(grads_part, expexted_grads_part) def test_gather_embeddings_by_ids(self): with tf.compat.v1.Session() as sess: ids = tf.constant([1, 2, 3], dtype=tf.int64) embeddings = tf.constant([[1, 1], [2, 2], [3, 3]], dtype=tf.float32) input = tf.constant([[2], [1], [2]], dtype=tf.int64) output = distribution_ops.gather_embeddings_by_input( ids, embeddings, input) output, index_mapping = sess.run(output) self.assertAllEqual(output, [[[2, 2]], [[1, 1]], [[2, 2]]]) self.assertAllEqual(index_mapping, [[1], [0], [1]]) def test_gather_embeddings_by_ids_gradient(self): with self.session() as sess: ids = tf.constant([1, 2, 3], dtype=tf.int64) embeddings = tf.constant([[1, 1], [2, 2], [3, 3]], dtype=tf.float32) input = tf.constant([[1], [2], [1]], dtype=tf.int64) output, index_mapping = distribution_ops.gather_embeddings_by_input( ids, embeddings, input) target_output = tf.constant([[[2, 2]], [[2, 2]], [[2, 2]]], dtype=tf.float32) loss = target_output - output grads = tf.gradients(loss, embeddings) grads = sess.run(grads) expected_grads = [[-2, -2], [-1, -1], [0, 0]] self.assertAllEqual(grads[0], expected_grads) def test_gather_embeddings_by_ids_gradient_back_prop(self): with self.session() as sess: ids = tf.constant([2, 3, 1], dtype=tf.int64) grads = tf.constant([[1, 1], [2, 2], [4, 4], [8, 8]], dtype=tf.float32) # implies the input tensor with id value [3, 2, 3, 1] index_mapping = tf.constant([1, 0, 1, 2], dtype=tf.int64) emb_grads = distribution_ops.gather_embeddings_by_ids_gradient_back_prop( ids, grads, index_mapping) self.assertAllEqual(emb_grads, [[2, 2], [5, 5], [8, 8]]) @test_util.run_gpu_only def test_fused_gather_embeddings_by_input(self): with tf.compat.v1.Session() as sess, test_util.use_gpu(): # inputs = [ # tf.constant([2, 3, 1, 2, 7, 2], dtype=tf.int64), # 
tf.constant([5, 8, 4, 4, 5, 11, 6], dtype=tf.int64) # ] # shard_indices: [[2, 0, 1, 2, 1, 2], [2, 2, 1, 1, 2, 2, 0]] # fused_ids: [3, 6, 1, 7, 4, 2, 5, 8, 11] # fused_slot_sizes: [1, 1, 2, 1, 1, 3] embedding_dims = [3, 2] fused_embeddings = tf.constant([ 1.1, 1.2, 1.3, 2.1, 2.2, 3.1, 3.2, 3.3, 4.1, 4.2, 4.3, 5.1, 5.2, 6.1, 6.2, 6.3, 7.1, 7.2, 8.1, 8.2, 9.1, 9.2 ], dtype=tf.float32) SCALE = (12345, 11777 ) # To test the number of elements larger than GPU grid fused_embedding_offsets = [ tf.constant([13, 0, 5, 13, 8, 13] * SCALE[0], dtype=tf.int32), tf.constant([16, 18, 11, 11, 16, 20, 3] * SCALE[1], dtype=tf.int32) ] output = distribution_ops.fused_gather_embeddings_by_input( fused_embeddings, fused_embedding_offsets, embedding_dims) outputs = sess.run(output) expected_outputs = [[[6.1, 6.2, 6.3], [1.1, 1.2, 1.3], [3.1, 3.2, 3.3], [6.1, 6.2, 6.3], [4.1, 4.2, 4.3], [6.1, 6.2, 6.3]] * SCALE[0], [[7.1, 7.2], [8.1, 8.2], [5.1, 5.2], [5.1, 5.2], [7.1, 7.2], [9.1, 9.2], [2.1, 2.2]] * SCALE[1]] self.assertAllClose(outputs, expected_outputs) def test_fused_gather_embeddings_by_input_gradient(self): with tf.compat.v1.Session() as sess, test_util.use_gpu(): # The size of one-dimensional fused_embeddings. 
with tf.device("CPU:0"): fused_embeddings_size = tf.constant(22, dtype=tf.int32) embedding_dims = [3, 2] SCALE = 888 # To test float sum precision loss on CPU and GPU grads = [ tf.constant([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3], [6.1, 6.2, 6.3]] * SCALE, dtype=tf.float32), tf.constant([[1.4, 1.5], [2.4, 2.5], [3.4, 3.5], [4.4, 4.5], [5.4, 5.5], [6.4, 6.5], [7.4, 7.5]] * SCALE, dtype=tf.float32) ] embedding_offsets = [ tf.constant([13, 0, 5, 13, 8, 13] * SCALE, dtype=tf.int32), tf.constant([16, 18, 11, 11, 16, 20, 3] * SCALE, dtype=tf.int32) ] output_t = distribution_ops.fused_gather_embeddings_by_input_gradient( fused_embeddings_size, grads, embedding_offsets, embedding_dims) self.assertAllEqual(output_t.shape[0], 22) # shape inference when applicable output = sess.run(output_t) expected_output = [ 2.1, 2.2, 2.3, # id 3 offset 0 7.4, 7.5, # id 6 offset 3 3.1, 3.2, 3.3, # id 1 offset 5 5.1, 5.2, 5.3, # id 7 offset 8 7.8, 8.0, # id 4 offset 11 11.3, 11.6, 11.9, # id 2 offset 13 6.8, 7.0, # id 5 offset 16 2.4, 2.5, # id 8 offset 18 6.4, 6.5, # id 11 offset 20 ] self.assertAllClose(output, np.asarray(expected_output) * SCALE, rtol=1e-7 * SCALE) def test_reduce_mean(self): with tf.compat.v1.Session() as sess: id_indices = tf.constant([[0], [0], [1]], dtype=tf.int64) id_values = tf.constant([[4, 4], [2, 2], [1, 1]], dtype=tf.float32) reduced = distribution_ops.reduce_mean(id_indices, id_values, [2]) reduced = sess.run(reduced) self.assertAllEqual(reduced, [[3, 3], [1, 1]]) def test_reduce_mean_gradient(self): with self.session() as sess: id_indices = tf.constant([[0], [0]], dtype=tf.int64) id_values = tf.constant([[0, 0], [0, 0]], dtype=tf.float32) reduced = distribution_ops.reduce_mean(id_indices, id_values, [1]) target = tf.constant([[-2, -4]], dtype=tf.float32) loss = target - 2 * reduced grads = tf.gradients(loss, id_values)[0] grads = sess.run(grads) self.assertAllEqual(grads, [[-1, -1], [-1, -1]]) def 
test_reduce_sum(self): with tf.compat.v1.Session() as sess: id_indices = tf.constant([[0], [0], [1]], dtype=tf.int64) id_values = tf.constant([[1, 1], [2, 2], [4, 4]], dtype=tf.float32) reduced = distribution_ops.reduce_sum(id_indices, id_values, [2]) reduced = sess.run(reduced) self.assertAllEqual(reduced, [[3, 3], [4, 4]]) def test_reduce_sum_gradient(self): with self.session() as sess: id_indices = tf.constant([[0], [0]], dtype=tf.int64) id_values = tf.constant([[0, 0], [0, 0]], dtype=tf.float32) reduced = distribution_ops.reduce_sum(id_indices, id_values, [1]) target = tf.constant([[10, 99]], dtype=tf.float32) loss = target - reduced grads = tf.gradients(loss, id_values)[0] grads = sess.run(grads) self.assertAllEqual(grads, [[-1, -1], [-1, -1]]) def test_reduce_sqrtn(self): with tf.compat.v1.Session() as sess: id_indices = tf.constant([[0], [0], [1]], dtype=tf.int64) id_values = tf.constant([[3, 3], [4, 4], [4, 4]], dtype=tf.float32) reduced = distribution_ops.reduce_sqrtn(id_indices, id_values, [2]) reduced = sess.run(reduced) self.assertAllClose(reduced, [[5, 5], [4, 4]]) def test_reduce_sqrtn_gradient(self): with self.session() as sess: id_indices = tf.constant([[0], [0]], dtype=tf.int64) id_values = tf.constant([[3, 4], [4, 3]], dtype=tf.float32) reduced = distribution_ops.reduce_sqrtn(id_indices, id_values, [1]) target = tf.constant([[10, 15]], dtype=tf.float32) loss = target - reduced grads = tf.gradients(loss, id_values)[0] grads = sess.run(grads) self.assertAllClose(grads, [[-0.6, -0.8], [-0.8, -0.6]]) def test_reduce_sqrtn_gradient_zero(self): with self.session() as sess: id_indices = tf.constant([[0], [0]], dtype=tf.int64) id_values = tf.constant([[0, 0], [0, 0]], dtype=tf.float32) reduced = distribution_ops.reduce_sqrtn(id_indices, id_values, [1]) target = tf.constant([[10, 15]], dtype=tf.float32) loss = target - reduced grads = tf.gradients(loss, id_values)[0] grads = sess.run(grads) self.assertAllClose(grads, [[0, 0], [0, 0]]) def 
def test_fused_reduce_sum_and_split_grad(self):
  """Checks the gradient of `fused_reduce_sum_and_split` w.r.t. its values.

  Three input rows are reduced into two output rows and split into column
  groups of width 2 and 1. Differentiating the outputs with respect to the
  inputs is expected to give an all-ones matrix of the input shape.
  """

  def _cpu_only(op):
    # Device function: place every op of the graph on the CPU.
    return '/CPU:0'

  # Test split.
  with tf.compat.v1.Session() as session, session.graph.device(_cpu_only):
    segment_ids = tf.constant([0, 0, 1], dtype=tf.int64)
    rows = tf.constant([[1, 1, 1], [2, 2, 1], [4, 4, 2]], dtype=tf.float32)
    split_outputs = distribution_ops.fused_reduce_sum_and_split(
        segment_ids, rows, 2, [2, 1])
    grad_tensor = tf.gradients(split_outputs, rows)[0]
    grad_values = session.run(grad_tensor)
    self.assertAllEqual(grad_values, [[1, 1, 1], [1, 1, 1], [1, 1, 1]])
@test_util.run_gpu_only
def test_aligned_concat_split(self):
  """Round-trips random tensors through aligned flat concat followed by split.

  Builds 155 uniformly random tensors of rank 1-4 (each dimension 1-50),
  concatenates them with `monolith_aligned_flat_concat`, splits the result
  back with `monolith_aligned_flat_split`, and checks that every recovered
  tensor equals its input.
  """
  item_count = 155
  with tf.compat.v1.Session() as session, test_util.use_gpu():
    # The rank is drawn first, then one size per dimension, for each tensor.
    inputs = [
        tf.random.uniform(
            [random.randint(1, 50) for _ in range(random.randint(1, 4))])
        for _ in range(item_count)
    ]
    raw_ops = distribution_ops.gen_distribution_ops
    flat = raw_ops.monolith_aligned_flat_concat(inputs)
    recovered = raw_ops.monolith_aligned_flat_split(inputs, flat)
    # Fetch inputs and outputs in one run so both see the same random draw.
    input_values, recovered_values = session.run([inputs, recovered])
    for idx in range(item_count):
      self.assertAllEqual(input_values[idx], recovered_values[idx])
""" # init bps only if needed if os.environ.get('BYTEPS_ALLTOALL_SESSION_SIZE') is None: os.environ["BYTEPS_ALLTOALL_SESSION_SIZE"] = '3' # set size, rank based on OMPI env vars if os.environ.get('BYTEPS_LOCAL_SIZE', None) is None: os.environ["BYTEPS_LOCAL_SIZE"] = os.environ.get( 'OMPI_COMM_WORLD_LOCAL_SIZE') local_size = int(os.environ.get('BYTEPS_LOCAL_SIZE')) rank = int(os.environ.get('OMPI_COMM_WORLD_RANK')) size = int(os.environ.get('OMPI_COMM_WORLD_SIZE')) local_rank = rank % local_size phy_node_id = int(rank / local_size) socket_path = f"/tmp/bps_{uuid}_socket_{phy_node_id}" gdr_alltoall = os.environ.get('MONOLITH_WITH_BYTEPS_FWD_GDR', '0') == '1' or \ os.environ.get('MONOLITH_WITH_BYTEPS_BWD_GDR', '0') == '1' # gpu_nic_binding_mode: Default False, when True we bind gpu_id (0,1) to eth0, (2,3) to eth1... # This is useful for A100 systems where we have topology in which some gpus are closer to some # NICs. gpu_nic_binding_mode = int(os.environ.get('BYTEPS_GPU_NIC_BINDING_MODE', 0)) if not gpu_nic_binding_mode: # Constant binding mode (default), all GPUs use one NIC interface = os.getenv("DMLC_INTERFACE", "eth0") else: # gpu_nic_binding_mode binding mode NUM_GPU_PER_NIC = 2 nic_id = int(local_rank // NUM_GPU_PER_NIC) if gdr_alltoall: os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank) numa_id = os.environ['BYTEPS_NUMA_ID'] print( f"GDR: set CUDA_VISIBLE_DEVICES={local_rank}, BYTEPS_NUMA_ID={numa_id}" ) os.environ['BYTEPS_PIN_MEMORY'] = "1" os.environ['BYTEPS_PIN_MEMORY_CPU'] = os.environ.get( 'BYTEPS_PIN_MEMORY_CPU', '1') os.environ['DMLC_NUM_CPU_DEV'] = "0" os.environ['DMLC_NUM_GPU_DEV'] = "1" os.environ['BYTEPS_USE_GDR_ALLREDUCE'] = os.environ.get( 'BYTEPS_USE_GDR_ALLREDUCE', '1') interface = "eth{}".format(nic_id) # Add all eth otherwise it may give out "Destination not reachable" error # or block in some communication. 
if os.environ.get('BYTEPS_WITH_ALL_NICS', '0') == '1': os.environ[ "UCX_NET_DEVICES"] = "mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,eth0,eth1,eth2,eth3" else: os.environ["UCX_NET_DEVICES"] = "mlx5_{}:1".format(nic_id) # scheduler connection info cmd = f'ip addr show {interface}' hostname = os.popen( cmd + ' | grep "\" | awk \'{ print $2 }\' | awk -F "/" \'{ print $1 }\'' ).read().strip() os.environ["UCX_RDMA_CM_SOURCE_ADDRESS"] = hostname os.environ["PSLITE_UCX_TLS"] = os.environ.get('PSLITE_UCX_TLS', 'rc_x,tcp,self,cuda') print( f"UCX: set PSLITE_UCX_TLS={os.environ['PSLITE_UCX_TLS']} {os.environ['UCX_NET_DEVICES']}" ) os.environ["DMLC_NODE_HOST"] = hostname os.environ["DMLC_ROLE"] = 'joint' os.environ["DMLC_ENABLE_UCX"] = os.environ.get('DMLC_ENABLE_UCX', '1') os.makedirs(socket_path, exist_ok=True) os.environ["DMLC_WORKER_ID"] = str(rank) os.environ["DMLC_NUM_WORKER"] = str(size) os.environ["DMLC_NUM_SERVER"] = str(size) os.environ["BYTEPS_UUID"] = uuid os.environ["BYTEPS_LOCAL_RANK"] = str(local_rank) os.environ["BYTEPS_SOCKET_PATH"] = socket_path os.environ["BYTEPS_OMP_THREAD_PER_GPU"] = os.environ.get( "BYTEPS_OMP_THREAD_PER_GPU", "1") os.environ["BYTEPS_FORCE_DISTRIBUTED"] = '1' os.environ["BYTEPS_TELEMETRY_ON"] = os.environ.get("BYTEPS_TELEMETRY_ON", '0') os.environ["BYTEPS_LOG_LEVEL"] = os.environ.get('BYTEPS_LOG_LEVEL', 'info') os.environ["BYTEPS_SERVER_DIRECT_RESPONSE"] = os.environ.get( 'BYTEPS_SERVER_DIRECT_RESPONSE', '2') os.environ["BYTEPS_UCX_FORCE_REQ_ORDER"] = '1' # performance tuning knobs os.environ["BYTEPS_KEY_HASH_FN"] = os.environ.get('BYTEPS_KEY_HASH_FN', 'djb2-colocate') os.environ["BYTEPS_UCX_SHORT_THRESH"] = os.environ.get( 'BYTEPS_UCX_SHORT_THRESH', '0') os.environ["PSLITE_UCX_RNDV_THRESH"] = os.environ.get( "PSLITE_UCX_RNDV_THRESH", '8192') os.environ["BYTEPS_WORKER_LOCAL_ROOT"] = os.environ.get( 'BYTEPS_WORKER_LOCAL_ROOT', '-1') # To enable async alltoall operations, we must reserve memory buffers on the receiver side. 
# bps allreduce stress test
def byteps_benchmark_ar(total_len,
                        total_niter=10000,
                        use_cpu=False,
                        op='pushpull'):
  """Stress-tests BytePS push_pull and reports the measured goodput.

  Repeatedly runs `bps.push_pull` on a float32 tensor of shape
  `[total_len, 1]`. Every 20 iterations the average latency of that window
  is converted to a goodput figure (32 bits per element) and recorded;
  rank 0 also prints it.

  Args:
    total_len: number of float32 elements in the tensor.
    total_niter: number of push_pull calls to issue.
    use_cpu: place the tensor and the op on "/cpu:0" instead of "/gpu:0".
    op: label that is only embedded in the tensor name.

  Returns:
    The recorded goodput samples (Gb/s) without the first one, which is
    dropped -- presumably to exclude warm-up. The list is empty when fewer
    than 40 iterations are run.
  """
  tf.compat.v1.enable_eager_execution()
  # Imported lazily so that the module can be loaded without BytePS.
  import byteps.tensorflow as bps
  rank, size = bps.rank(), bps.size()
  niter = 0
  print(
      f'===== start pushpull_benchmark {rank}/{size} total_len={total_len} =====',
      flush=True)
  device = tf.device("/gpu:0" if not use_cpu else "/cpu:0")
  with device:
    tensor = tf.ones([total_len, 1], dtype=tf.float32) * (rank + 1)
  t0 = time.time()
  interval = 20
  name = f'data_len_{total_len}_{op}_' + ('cpu' if use_cpu else 'gpu')
  comm_fn = bps.push_pull
  goodputs = []
  while niter < total_niter:
    with device:
      # Only the communication matters here; the reduced tensor is discarded.
      comm_fn(tensor, average=True, name=name)
    niter += 1
    if niter % interval == 0:
      t1 = time.time()
      # Average latency per call over the last window, in milliseconds.
      latency = (t1 - t0) / interval * 1000
      # bits / ms / 1e6 == Gb/s.
      goodput = total_len * 32 / latency / 1000000
      goodputs.append(goodput)
      if rank == 0:
        print(
            f'DONE iter={niter}, latency={latency:.3} ms, Goodput={goodput:.5} Gb/s, is_cpu={use_cpu}',
            flush=True)
      t0 = time.time()
  print(
      f'===== end pushpull_benchmark {rank}/{size} total_len={total_len} =====',
      flush=True)
  return goodputs[1:]
def _mean_goodput(goodputs):
  """Returns the arithmetic mean of `goodputs`, or NaN when it is empty."""
  if not goodputs:
    return float('nan')
  return sum(goodputs) / len(goodputs)


def bps_comm_benchmark():
  """Runs the BytePS communication benchmarks selected through the environment.

  Environment variables read:
    MONOLITH_BENCHMARK_BPS: one of "c2g", "g2g", "c2c", "g2c", "ar" or "all".
      The default "none" is not an accepted value, so the variable must be
      set explicitly or the assertion below fails.
    MONOLITH_BENCHMARK_ITERS: iterations per benchmark run (default 200).
    MONOLITH_BENCHMARK_BPS_AR_LEN / MONOLITH_BENCHMARK_BPS_A2A_LEN: tensor
      length for the allreduce / alltoall benchmarks (default 65536000).

  For "ar" the allreduce benchmark runs once on CPU and once on GPU. For the
  alltoall variants it runs three times, halving the length after each run.
  A `(total_len, mean goodput)` summary is printed per benchmark.

  The benchmark helpers sample once every 20 iterations and drop the first
  sample, so a run with fewer than 40 iterations yields no samples. The mean
  is reported as NaN in that case instead of raising ZeroDivisionError.
  """
  benchmark_bps = os.environ.get("MONOLITH_BENCHMARK_BPS", "none")
  benchmark_iters = int(os.getenv("MONOLITH_BENCHMARK_ITERS", "200"))
  gpus = tf.config.experimental.list_physical_devices('GPU')
  for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
  assert benchmark_bps in ("c2g", "g2g", "c2c", "g2c", "ar",
                           "all"), benchmark_bps
  if benchmark_bps == "all":
    benchmarks = ["c2g", "g2g", "c2c", "g2c", "ar"]
  else:
    benchmarks = [benchmark_bps]
  for benchmark in benchmarks:
    results = []
    # The benchmark name encodes source -> destination placement.
    dst_gpu = benchmark in ("c2g", "g2g")
    src_gpu = benchmark in ("g2c", "g2g")
    if benchmark == "ar":
      total_len = int(os.getenv("MONOLITH_BENCHMARK_BPS_AR_LEN", "65536000"))
      for use_cpu in (True, False):
        goodputs = byteps_benchmark_ar(total_len,
                                       total_niter=benchmark_iters,
                                       use_cpu=use_cpu)
        results.append((total_len, _mean_goodput(goodputs)))
    else:
      total_len = int(os.getenv("MONOLITH_BENCHMARK_BPS_A2A_LEN", "65536000"))
      for _ in range(3):
        goodputs = byteps_benchmark_a2a(total_len,
                                        total_niter=benchmark_iters,
                                        dst_gpu=dst_gpu,
                                        src_gpu=src_gpu)
        results.append((total_len, _mean_goodput(goodputs)))
        total_len = total_len // 2
    print(benchmark + "_summary:", results)
def get_mpi_rank():
  """Returns this process's global rank as published by Open MPI.

  Reads `OMPI_COMM_WORLD_RANK` from the environment. When the variable is
  absent a warning is logged and rank 0 is returned.
  """
  raw_rank = os.environ.get('OMPI_COMM_WORLD_RANK')
  if raw_rank is None:
    logging.warning(f"get_mpi_rank use default 0")
    return 0
  return int(raw_rank)
def enable_sync_training():
  """Tells whether synchronous training is requested and MPI launched us.

  Returns:
    A truthy value only when the `enable_sync_training` flag is set and
    `OMPI_COMM_WORLD_LOCAL_RANK` is present in the environment. When the
    flag is falsy, the flag value itself is returned. If reading the flag
    fails (for example because flags are not defined or parsed yet), False
    is returned.
  """
  try:
    return FLAGS.enable_sync_training and 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ
  except Exception:
    # Best effort: any failure to read the flag means "not enabled".
    # `Exception` (rather than a bare except) lets KeyboardInterrupt and
    # SystemExit propagate instead of being silently turned into False.
    return False
def update_session_config_for_gpu(session_config):
  """Adjusts the GPU options of `session_config` in place.

  Outside synchronous training only `allow_growth` is switched on. In
  synchronous training the visible device list is set to the MPI local rank
  (`OMPI_COMM_WORLD_LOCAL_RANK`, default "0"), `force_gpu_compatible` is
  enabled unless `MONOLITH_FORCE_GPU_COMPATIBLE` is set to something other
  than "1", and memory growth depends on whether BytePS GDR alltoall is in
  use.

  Args:
    session_config: session configuration object exposing `gpu_options`.
  """
  byteps_enabled = int(os.getenv("MONOLITH_WITH_BYTEPS", "0"))
  if not enable_sync_training():
    session_config.gpu_options.allow_growth = True
    return

  # It's recommended to set the visible device list in session config instead
  # of the CUDA_VISIBLE_DEVICES environment variable. Setting the
  # CUDA_VISIBLE_DEVICES variable may mislead NCCL as per the original
  # author's testing.
  # https://horovod.readthedocs.io/en/stable/tensorflow.html?highlight=visible_device_list
  # https://horovod.readthedocs.io/en/stable/troubleshooting.html?highlight=cuda_visible_devices#running-out-of-memory
  local_rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', "0")
  if os.environ.get('MONOLITH_FORCE_GPU_COMPATIBLE', '1') == '1':
    session_config.gpu_options.force_gpu_compatible = True
    logging.info("set force_gpu_compatible=True")

  gdr_requested = (
      os.environ.get('MONOLITH_WITH_BYTEPS_FWD_GDR', '0') == '1' or
      os.environ.get('MONOLITH_WITH_BYTEPS_BWD_GDR', '0') == '1')
  if byteps_enabled and gdr_requested:
    # If GDR alltoall is enabled, GPU memory needs to be registered for UCX
    # ahead of time. Therefore allow_growth is disabled and a fixed memory
    # fraction is reserved. The visible devices are limited to one device.
    session_config.gpu_options.allow_growth = False
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.4
  else:
    session_config.gpu_options.allow_growth = True
  session_config.gpu_options.visible_device_list = local_rank
class ReduceSum(Combiner):
  """Combiner that reduces the embeddings of each key row with `reduce_sum`."""

  def __init__(self):
    # This combiner reports a max_seq_length of 0 to the base class.
    super().__init__(0)

  def combine(self,
              key: tf.RaggedTensor,
              embedding: tf.Tensor,
              name: str = None):
    """Reduces `embedding` into one entry per row of `key`.

    Args:
      key: ragged tensor whose row partition assigns each value to a row.
      embedding: embeddings to reduce -- presumably one per value of `key`;
        verify against callers.
      name: optional name forwarded to the op.

    Returns:
      The result of `distribution_ops.reduce_sum`.
    """
    # Row id of every value, as a column vector of indices.
    row_ids = tf.expand_dims(ragged_utils.fused_value_rowids(key), -1)
    # Number of rows, wrapped into a rank-1 shape tensor.
    output_rows = tf.expand_dims(key.nrows(), 0)
    return distribution_ops.reduce_sum(row_ids,
                                       embedding,
                                       output_rows,
                                       name=name)
Tensor's shape is (batch, seq_length, dim) """ name = name or "FirstNCombiner" with tf.name_scope(name): if not isinstance(embedding, tf.Tensor): embedding = tf.convert_to_tensor(embedding) batch_size_tensor = key.nrows() key_sparse = key.to_sparse() indices = key_sparse.indices shape = tf.stack([ batch_size_tensor, tf.math.reduce_max([self.max_seq_length, key_sparse.dense_shape[1]]), embedding.shape.as_list()[1] ]) with device_utils.maybe_device_if_allowed('/device:GPU:0'): scattered = tf.scatter_nd(indices, embedding, shape) # We use slice here instead of array composition because of the shape problem. return tf.slice(scattered, [0, 0, 0], [-1, self.max_seq_length, -1]) ================================================ FILE: monolith/native_training/embedding_combiners_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class CombinerTest(tf.test.TestCase):
  """Tests the combiners of `embedding_combiners` on small ragged inputs."""

  def testReduceSum(self):
    # Row lengths [1, 2]: the first row holds one key, the second holds two.
    ragged_keys = tf.RaggedTensor.from_row_lengths([1, 2, 3], [1, 2])
    embeddings = [[1.0], [2.0], [3.0]]
    combiner = embedding_combiners.ReduceSum()
    combined = self.evaluate(combiner.combine(ragged_keys, embeddings))
    self.assertAllClose(combined, [[1.0], [5.0]])

  def testFirstN(self):
    # Rows of length 1, 2 and 3; only the first two entries of each are kept
    # and the short row is padded with zeros.
    ragged_keys = tf.RaggedTensor.from_row_lengths([1, 2, 3, 4, 5, 6],
                                                   [1, 2, 3])
    embeddings = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
    combiner = embedding_combiners.FirstN(2)
    combined = self.evaluate(combiner.combine(ragged_keys, embeddings))
    expected = [[[1.0], [0.0]], [[2.0], [3.0]], [[4.0], [5.0]]]
    self.assertAllClose(combined, expected)

  def testFirstNUnknownShape(self):
    # With placeholders the batch size is unknown, but the sequence length
    # and the embedding dimension of the static shape must still be known.
    ragged_keys = tf.compat.v1.ragged.placeholder(tf.int64, 1, [])
    embeddings = tf.compat.v1.placeholder(tf.float32, shape=[None, 6])
    combiner = embedding_combiners.FirstN(2)
    combined = combiner.combine(ragged_keys, embeddings)
    self.assertAllEqual(combined.shape, [None, 2, 6])
def _convert_to_proto(obj: object, proto: object):
  """Copies the non-None instance attributes of `obj` onto `proto`.

  `proto.SetInParent()` is called first, so the message is marked as set
  even when no attribute is copied. Attributes are copied by name, which
  means `obj` must use the same attribute names as the fields of `proto`.

  Args:
    obj: object whose `__dict__` provides the values.
    proto: target message; modified in place.
  """
  proto.SetInParent()
  for field_name, field_value in vars(obj).items():
    if field_value is None:
      # None means "leave the proto default untouched".
      continue
    setattr(proto, field_name, field_value)
@monolith_export
class AdamOptimizer(Optimizer):
  r"""Adam optimizer, see https://arxiv.org/abs/1412.6980

  With parameter x and gradient grad, the i-th update is

  .. math::

    m_{i+1} = \beta_1 * m_i + (1 - \beta_1) * grad

    v_{i+1} = \beta_2 * v_i + (1 - \beta_2) * grad^2

    w_{i+1} = w_i - \eta * \frac{m_i}{\sqrt{v_i + \epsilon}}

  Args:
    learning_rate (:obj:`float`): learning rate
    beta1 (:obj:`float`): exponential decay rate of the first moment estimate
    beta2 (:obj:`float`): exponential decay rate of the second moment estimate
    epsilon (:obj:`float`): offset that keeps the divisor away from zero
    warmup_steps (:obj:`int`): deprecated

  The remaining arguments (use_beta1_warmup, weight_decay_factor,
  use_nesterov) are stored and copied into the proto as they are; their
  exact effect is defined by the consumer of the proto.
  """

  def __init__(self,
               learning_rate=None,
               beta1=0.9,
               beta2=0.99,
               use_beta1_warmup=False,
               weight_decay_factor=0.0,
               use_nesterov=False,
               epsilon=0.01,
               warmup_steps=0):
    # Attribute names must match the field names of the `adam` proto
    # message, because as_proto() copies them over by name.
    self.learning_rate = learning_rate
    self.beta1 = beta1
    self.beta2 = beta2
    self.use_beta1_warmup = use_beta1_warmup
    self.weight_decay_factor = weight_decay_factor
    self.use_nesterov = use_nesterov
    self.epsilon = epsilon
    self.warmup_steps = warmup_steps

  def as_proto(self):
    """Returns an OptimizerConfig whose `adam` field holds these settings."""
    config = embedding_hash_table_pb2.OptimizerConfig()
    _convert_to_proto(self, config.adam)
    return config
class FTRLWithGroupSparsityOptimizer(Optimizer):
  """Optimizer settings that are exported as the `group_ftrl` proto field.

  Args:
    learning_rate: alpha.
    initial_accumulator_value: stored and exported unchanged.
    beta: stored and exported unchanged.
    warmup_steps: stored and exported unchanged.
    l1_regularization: lambda1, exported as `l1_regularization_strength`.
    l2_regularization: lambda2, exported as `l2_regularization_strength`.
  """

  def __init__(
      self,
      learning_rate=None,  # alpha
      initial_accumulator_value=None,
      beta=None,
      warmup_steps=0,
      l1_regularization=None,  # lambda1
      l2_regularization=None):  # lambda2
    self.learning_rate = learning_rate
    self.initial_accumulator_value = initial_accumulator_value
    self.beta = beta
    # The constructor uses short names; the attributes carry the names that
    # as_proto() copies onto the proto message.
    self.l1_regularization_strength = l1_regularization
    self.l2_regularization_strength = l2_regularization
    self.warmup_steps = warmup_steps

  def as_proto(self):
    """Returns an OptimizerConfig with `group_ftrl` filled from this object."""
    config = embedding_hash_table_pb2.OptimizerConfig()
    _convert_to_proto(self, config.group_ftrl)
    return config
@monolith_export
class FtrlOptimizer(Optimizer):
  """FTRL optimizer, see https://dl.acm.org/citation.cfm?id=2488200

  Args:
    learning_rate: alpha.
    initial_accumulator_value (:obj:`float`): initial value of the accumulator
    beta (:obj:`float`): the beta value from the paper
    warmup_steps: stored and exported unchanged.
    l1_regularization: lambda1, exported as `l1_regularization_strength`.
    l2_regularization: lambda2, exported as `l2_regularization_strength`.
  """

  def __init__(
      self,
      learning_rate=None,  # alpha
      initial_accumulator_value=None,
      beta=None,
      warmup_steps=0,
      l1_regularization=None,  # lambda1
      l2_regularization=None):  # lambda2
    self.learning_rate = learning_rate
    self.initial_accumulator_value = initial_accumulator_value
    self.beta = beta
    # Stored under the names that as_proto() copies onto the proto message.
    self.l1_regularization_strength = l1_regularization
    self.l2_regularization_strength = l2_regularization
    self.warmup_steps = warmup_steps

  def as_proto(self):
    """Returns an OptimizerConfig whose `ftrl` field holds these settings."""
    config = embedding_hash_table_pb2.OptimizerConfig()
    _convert_to_proto(self, config.ftrl)
    return config
@monolith_export
class ZerosInitializer(Initializer):
  """All-zeros initializer: the initial embedding values are set to zero."""

  def as_proto(self):
    """Returns a fresh `InitializerConfig` with `zeros` passed to
    `_convert_to_proto` together with this object."""
    config = embedding_hash_table_pb2.InitializerConfig()
    _convert_to_proto(self, config.zeros)
    return config


@monolith_export
class ConstantsInitializer(Initializer):
  """Constant initializer: the initial embedding values are set to a constant.

  Args:
    constant: the value stored on the instance as `constant`.
  """

  def __init__(self, constant: float):
    self.constant = constant

  def as_proto(self):
    """Returns a fresh `InitializerConfig` with `constants` passed to
    `_convert_to_proto` together with this object."""
    config = embedding_hash_table_pb2.InitializerConfig()
    _convert_to_proto(self, config.constants)
    return config


class RandomUniformInitializer(Initializer):
  """Uniform random initializer over the interval [minval, maxval].

  Args:
    minval, maxval (:obj:`float`): bounds of the initialization interval
  """

  def __init__(self, minval=None, maxval=None):
    self.minval = minval
    self.maxval = maxval

  def as_proto(self):
    """Returns a fresh `InitializerConfig` with `random_uniform` passed to
    `_convert_to_proto` together with this object."""
    config = embedding_hash_table_pb2.InitializerConfig()
    _convert_to_proto(self, config.random_uniform)
    return config


class BatchSoftmaxInitializer(Initializer):
  """Initializer that stores `init_step_interval` as its `constant` and is
  serialized through the `constants` field of `InitializerConfig`.

  Raises:
    ValueError: if `init_step_interval` is smaller than 1.
  """

  def __init__(self, init_step_interval: float):
    interval_is_valid = not init_step_interval < 1
    if not interval_is_valid:
      raise ValueError("init_step_interval should be >= 1, while got {}".format(
          init_step_interval))
    self.constant = init_step_interval

  def as_proto(self):
    """Returns a fresh `InitializerConfig` with `constants` passed to
    `_convert_to_proto` together with this object."""
    config = embedding_hash_table_pb2.InitializerConfig()
    _convert_to_proto(self, config.constants)
    return config


class Compressor(abc.ABC):
  """Abstract base for compressors: subclasses must provide `as_proto`."""

  @abc.abstractmethod
  def as_proto(self) -> embedding_hash_table_pb2.FloatCompressorConfig:
    pass


@monolith_export
class OneBitCompressor(Compressor):
  """Compressor settings emitted through `FloatCompressorConfig.one_bit`.

  Args:
    step_size: stored unchanged; defaults to 200.
    amplitude: stored unchanged; defaults to 0.05.
  """

  def __init__(self, step_size: int = 200, amplitude: float = 0.05):
    super().__init__()
    self.step_size = step_size
    self.amplitude = amplitude

  def as_proto(self):
    """Returns a fresh `FloatCompressorConfig`; `one_bit.step_size` is assigned
    explicitly before `_convert_to_proto` is invoked on `one_bit`."""
    config = embedding_hash_table_pb2.FloatCompressorConfig()
    one_bit = config.one_bit
    one_bit.step_size = self.step_size
    _convert_to_proto(self, one_bit)
    return config


@monolith_export
class FixedR8Compressor(Compressor):
  """Compressor settings emitted through `FloatCompressorConfig.fixed_r8`.

  Args:
    fixed_range: stored on the instance as the attribute `r`; defaults to 1.0.
  """

  def __init__(self, fixed_range=1.0):
    super().__init__()
    self.r = fixed_range

  def as_proto(self):
    """Returns a fresh `FloatCompressorConfig` with `fixed_r8` passed to
    `_convert_to_proto` together with this object."""
    config = embedding_hash_table_pb2.FloatCompressorConfig()
    _convert_to_proto(self, config.fixed_r8)
    return config
@monolith_export
class Fp16Compressor(Compressor):
  """When the model is served, embeddings are Fp16-encoded to save serving memory."""

  def as_proto(self):
    """Returns a fresh `FloatCompressorConfig` with `fp16` passed to
    `_convert_to_proto` together with this object."""
    config = embedding_hash_table_pb2.FloatCompressorConfig()
    _convert_to_proto(self, config.fp16)
    return config


@monolith_export
class Fp32Compressor(Compressor):
  """When the model is served, embeddings are Fp32-encoded."""

  def as_proto(self):
    """Returns a fresh `FloatCompressorConfig` with `fp32` passed to
    `_convert_to_proto` together with this object."""
    config = embedding_hash_table_pb2.FloatCompressorConfig()
    _convert_to_proto(self, config.fp32)
    return config


def CombineAsSegment(
    dim_size: int, initializer: Union[Initializer,
                                      embedding_hash_table_pb2.InitializerConfig],
    optimizer: Union[Optimizer, embedding_hash_table_pb2.OptimizerConfig],
    compressor: Union[Compressor, embedding_hash_table_pb2.FloatCompressorConfig]
) -> embedding_hash_table_pb2.EntryConfig.Segment:
  """Assembles an `EntryConfig.Segment` from its three parts.

  Each of `initializer`, `optimizer` and `compressor` may be either a wrapper
  object exposing `as_proto()` or an already-built proto message; wrappers are
  converted first, and the resulting message is copied into the segment.

  Args:
    dim_size: value assigned to `segment.dim_size`.
    initializer: source for `segment.init_config`.
    optimizer: source for `segment.opt_config`.
    compressor: source for `segment.comp_config`.

  Returns:
    The populated segment.
  """
  segment = embedding_hash_table_pb2.EntryConfig.Segment()
  segment.dim_size = dim_size
  destinations_and_sources = (
      (segment.init_config, initializer),
      (segment.opt_config, optimizer),
      (segment.comp_config, compressor),
  )
  for destination, source in destinations_and_sources:
    message = source.as_proto() if hasattr(source, 'as_proto') else source
    destination.CopyFrom(message)
  return segment


class HashTableConfig(abc.ABC):
  """Hash table settings are applied through an update function, because it is
  not known in advance which field has to be modified."""

  @abc.abstractmethod
  def mutate_table(
      self, table_config: embedding_hash_table_pb2.EmbeddingHashTableConfig):
    pass


class CuckooHashTableConfig(HashTableConfig):
  """Selects the cuckoo table type and optionally turns on feature eviction.

  Args:
    initial_capacity: written to `table_config.initial_capacity`.
    feature_evict_every_n_hours: eviction is enabled only when this is > 0.
  """

  def __init__(self, initial_capacity=1, feature_evict_every_n_hours=0):
    self._initial_capacity = initial_capacity
    self._feature_evict_every_n_hours = feature_evict_every_n_hours

  def mutate_table(
      self, table_config: embedding_hash_table_pb2.EmbeddingHashTableConfig):
    """Writes this configuration into `table_config` in place."""
    table_config.initial_capacity = self._initial_capacity
    table_config.cuckoo.SetInParent()
    evict_interval = self._feature_evict_every_n_hours
    if not evict_interval > 0:
      # Eviction fields stay untouched when no positive interval is configured.
      return
    table_config.enable_feature_eviction = True
    table_config.feature_evict_every_n_hours = evict_interval
class HashTableConfigInstance():
  """Bundles a table proto with its learning-rate sources for building a HashTable.

  Args:
    table_config: the `EmbeddingHashTableConfig` proto to wrap.
    learning_rate_fns: entries are either callables or plain values; see
      `call_learning_rate_fns`.
    extra_restore_names: shallow-copied when given; a falsy value becomes `[]`.
  """

  def __init__(self,
               table_config: embedding_hash_table_pb2.EmbeddingHashTableConfig,
               learning_rate_fns: List[Any],
               extra_restore_names=None):
    self._table_config = table_config
    self.extra_restore_names = copy.copy(extra_restore_names) or []
    self._learning_rate_fns = learning_rate_fns
    self._learning_rate_tensor = None

  # Used to check whether two slots share the same config.
  def __str__(self):
    serialized_table = self._table_config.SerializeToString()
    fn_descriptions = ", ".join([str(fn) for fn in self._learning_rate_fns])
    return "TableConfigPB: %s, LearningRateFns: [%s]" % (serialized_table,
                                                         fn_descriptions)

  @property
  def table_config(self):
    """The wrapped table proto."""
    return self._table_config

  @property
  def learning_rate_fns(self):
    """The learning-rate sources passed to the constructor."""
    return self._learning_rate_fns

  @property
  def learning_rate_tensor(self):
    """The tensor last given to `set_learning_rate_tensor`, or None."""
    return self._learning_rate_tensor

  def set_learning_rate_tensor(self, learning_rate_tensor: tf.Tensor):
    """Stores `learning_rate_tensor` for later retrieval."""
    self._learning_rate_tensor = learning_rate_tensor

  def call_learning_rate_fns(self) -> tf.Tensor:
    """Evaluates every learning-rate source and stacks the results.

    Callable entries are invoked; non-callable entries are cast to float32.

    Raises:
      Exception: if there are no learning-rate sources.
    """
    with tf.name_scope("learning_rate"):
      evaluated = [
          fn() if callable(fn) else tf.cast(fn, dtype=tf.float32)
          for fn in self._learning_rate_fns
      ]
      if not len(evaluated) > 0:
        raise Exception("Learning_rate_fns must be not empty.")
      return tf.stack(evaluated)

  def call_learning_rate_fns_fewer_ops(self) -> List[tf.Tensor]:
    """Evaluates every learning-rate source and returns them as a list.

    Callable entries are invoked; non-callable entries are returned as-is
    (no cast and no stacking, unlike `call_learning_rate_fns`).

    Raises:
      Exception: if there are no learning-rate sources.
    """
    with tf.name_scope("learning_rate"):
      evaluated = [
          fn() if callable(fn) else fn for fn in self._learning_rate_fns
      ]
      if len(evaluated) == 0:
        raise Exception("Learning_rate_fns must be not empty.")
      return evaluated
def _default_learning_rate_fn():
  """Builds the polynomial-decay learning-rate function shared by the tests."""
  return learning_rate_functions.PolynomialDecay(initial_learning_rate=0.01,
                                                 decay_steps=20,
                                                 end_learning_rate=0.05)


class EntryTest(unittest.TestCase):
  """Smoke tests: mostly check that the config wrappers can be constructed
  and serialized without raising."""

  def test_optimizers(self):
    entry.SgdOptimizer(0.01).as_proto()
    entry.AdagradOptimizer(0.01, 0.1).as_proto()
    entry.AdagradOptimizer(0.01, 0.1, 10).as_proto()
    entry.FtrlOptimizer(0.01, 0.1, 1).as_proto()
    entry.DynamicWdAdagradOptimizer(0.01, 0.1, 1).as_proto()
    entry.AdadeltaOptimizer(0.01, 0.0, 0.9, 0.01).as_proto()
    entry.AdamOptimizer(0.01, 0.9, 0.99, False, 0.0, False, 0.01).as_proto()
    entry.AmsgradOptimizer(0.01, 0.9, 0.99, 0.0, False, 0.01).as_proto()
    entry.MomentumOptimizer(0.01, 0.0, False, 0.9).as_proto()
    entry.MovingAverageOptimizer(0.9).as_proto()
    entry.RmspropOptimizer(0.01, 0.0, 0.9).as_proto()
    entry.RmspropV2Optimizer(0.01, 0.0, 0.9).as_proto()
    entry.BatchSoftmaxOptimizer(0.01).as_proto()

  def test_initializer(self):
    entry.ZerosInitializer().as_proto()
    entry.RandomUniformInitializer(-0.5, 0.5).as_proto()
    entry.BatchSoftmaxInitializer(1.0).as_proto()

  def test_batch_softmax_initializer_rejects_small_interval(self):
    # The constructor validates that init_step_interval is at least 1.
    with self.assertRaises(ValueError):
      entry.BatchSoftmaxInitializer(0.5)

  def test_compressor(self):
    entry.Fp16Compressor().as_proto()
    entry.Fp32Compressor().as_proto()
    entry.FixedR8Compressor().as_proto()
    entry.OneBitCompressor().as_proto()

  def test_combine(self):
    segment = entry.CombineAsSegment(5, entry.ZerosInitializer(),
                                     entry.SgdOptimizer(),
                                     entry.Fp16Compressor())
    self.assertEqual(segment.dim_size, 5)

  def test_hashtable_config(self):
    entry.CuckooHashTableConfig()

  def test_hashtable_config_entrance(self):
    # unittest assertions are used instead of bare `assert` so the checks
    # survive `python -O` and report both operands on failure.
    table_config1 = embedding_hash_table_pb2.EmbeddingHashTableConfig()
    config1 = entry.HashTableConfigInstance(table_config1, [0.1])
    table_config2 = embedding_hash_table_pb2.EmbeddingHashTableConfig()
    config2 = entry.HashTableConfigInstance(table_config2, [0.1])
    self.assertEqual(str(config1), str(config2))

    table_config3 = embedding_hash_table_pb2.EmbeddingHashTableConfig()
    config3 = entry.HashTableConfigInstance(table_config3,
                                            [_default_learning_rate_fn()])
    table_config4 = embedding_hash_table_pb2.EmbeddingHashTableConfig()
    config4 = entry.HashTableConfigInstance(table_config4,
                                            [_default_learning_rate_fn()])
    self.assertEqual(str(config3), str(config4))
    self.assertNotEqual(str(config1), str(config3))


if __name__ == "__main__":
  unittest.main()
def setup_hdfs_env():
  """Placeholder for HDFS environment setup; this build does nothing."""
  pass


def generate_psm_from_uuid(s):
  """Returns `s` unchanged."""
  return s


def get_zk_auth_data():
  """Builds ZooKeeper digest auth data from the `ZK_AUTH` environment variable.

  Returns:
    `[("digest", <value of ZK_AUTH>)]` when the variable is set to a non-empty
    string, otherwise None.
  """
  zk_auth = os.getenv('ZK_AUTH', None)
  if zk_auth:
    # ZK_AUTH is a credential: record that it is in use, but never echo its
    # value to stdout or to the logs.
    logging.info("ZK_AUTH is set; using digest authentication")
    return [("digest", zk_auth)]
  return None
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import logging, flags from dataclasses import dataclass from dataclasses_json import dataclass_json import os import copy import json import numpy as np import collections import getpass from typing import Dict, List from kazoo.client import KazooClient from typing import Optional, Union, get_type_hints import tensorflow as tf from monolith.native_training import env_utils from monolith.agent_service.utils import AgentConfig from monolith.agent_service.backends import ZKBackend from monolith.native_training.zk_utils import default_zk_servers from monolith.agent_service.replica_manager import ReplicaWatcher from monolith.native_training.cpu_training import CpuTraining, create_exporter from monolith.native_training.runner_utils import RunnerConfig, monolith_discovery from monolith.native_training.cpu_training import local_train_internal from monolith.native_training.cpu_training import distributed_train from monolith.native_training.cpu_training import distributed_sync_train from monolith.native_training.service_discovery import ServiceDiscoveryType from monolith.native_training.monolith_export import monolith_export from monolith.native_training.model_dump.dump_utils import DumpUtils from monolith.core.hyperparams import InstantiableParams from monolith.native_training.data.item_pool_hook import ItemPoolSaveRestoreHook from monolith.native_training.utils import set_metric_prefix from monolith.native_training.data.parsers import get_default_parser_ctx, ParserCtx from monolith.native_training.model_export.export_context import \ is_exporting, 
@monolith_export
class EstimatorSpec(
    collections.namedtuple(
        'EstimatorSpec',
        ['label', 'pred', 'head_name', 'loss', 'optimizer', 'classification'])):
  """EstimatorSpec is the data structure returned by model_fn.

  Args:
    label (:obj:`Union[tf.Tensor, List[tf.Tensor]]`): sample labels; a list may be used for multi-head models
    pred (:obj:`Union[tf.Tensor, List[tf.Tensor]]`): predictions; a list may be used for multi-head models
    head_name (:obj:`Union[str, List[str]]`): prediction name(s); a list may be used for multi-head models
    loss (:obj:`tf.Tensor`): the loss
    optimizer (:obj:`tf.Optimizer`): optimizer for the dense part
    classification (:obj:`Union[bool, List[bool]]`): whether this is a classification model; a list may be used for multi-head models
  """

  def __new__(cls,
              label,
              pred,
              head_name=None,
              loss=None,
              optimizer=None,
              classification=True):
    return super(EstimatorSpec, cls).__new__(cls,
                                             label=label,
                                             pred=pred,
                                             head_name=head_name,
                                             loss=loss,
                                             optimizer=optimizer,
                                             classification=classification)

  def _replace(self, **kwds):
    """Return a new EstimatorSpec replacing specified fields with new values.

    Keyword arguments that do not name a field are ignored.

    Raises:
      ValueError: if `mode` is passed. This spec has no `mode` field, so the
        previous `self.mode` comparison could only fail with AttributeError.
    """
    if 'mode' in kwds:
      raise ValueError('mode of EstimatorSpec cannot be changed.')
    # For every field take the override from kwds, falling back to the current
    # value; field order matches the positional order of __new__.
    new_fields = map(kwds.pop, self._fields, list(self))
    return EstimatorSpec(*new_fields)
@monolith_export
@dataclass_json
@dataclass
class RunConfig:
  """Estimator-related configuration: the single entry point for parameters
  that live outside the user's model.

  Args:
    is_local (:obj:`bool`): whether to run in local mode, default False
    num_ps (:obj:`int`): number of PS
    num_workers (:obj:`int`): number of workers
    chief_timeout_secs (:obj:`int`): chief timeout, default 1800 seconds
    operation_timeout_in_ms (:obj:`int`): operation timeout in milliseconds.
      NOTE(review): the original doc stated a default of 600000 (600s), but the
      field default below is -1 - confirm which value applies downstream.
    session_creation_timeout_secs (:obj:`int`): session creation timeout, default 7200 seconds
    enable_fused_layout (:obj:`bool`): whether to enable layout fusion to speed up computation
    partial_recovery (:obj:`bool`): whether to enable partial recovery
    max_retry_times (:obj:`int`): maximum number of restarts on failover, default 6
    retry_wait_in_secs (:obj:`int`): interval between restarts on failover, default 5
    bzid (:obj:`str`): serving business id, used to locate the matching online PS for parameter sync
    base_name (:obj:`str`): serving model name, used to locate the matching online PS for parameter sync
    ps_replica_num (:obj:`int`): number of serving PS replicas
    enable_parameter_sync (:obj:`bool`): whether to enable parameter sync, default False
    model_dir (:obj:`str`): model dump directory
    restore_dir (:obj:`str`): directory to load the model from; only needed when it differs
      from the dump directory, by default the model is loaded from model_dir
    restore_ckpt (:obj:`str`): checkpoint version to load, by default the latest one
    save_checkpoints_secs (:obj:`int`): save a checkpoint every this many seconds
    save_checkpoints_steps (:obj:`int`): save a checkpoint every this many steps
    max_rpc_deadline_millis (:obj:`int`): RPC timeout, default 30 seconds
    dense_only_save_checkpoints_secs (:obj:`int`): save a dense-only checkpoint every this many seconds
    dense_only_save_checkpoints_steps (:obj:`int`): save a dense-only checkpoint every this many steps
    checkpoints_max_to_keep (:obj:`int`): maximum number of checkpoints to keep
    warmup_file (:obj:`str`): serving warmup file name
    enable_local_profiling (:obj:`bool`): whether to enable profiling for local tests
    use_native_multi_hash_table (:obj:`bool`): do not set this variable; it is scheduled for removal on 2023-1-1
    clear_nn (:obj:`bool`): whether to randomly re-initialize the dense part when reloading a model, default false.
    continue_training (:obj:`bool`): when clear_nn is true, whether global_step is kept, default false.
    reload_alias_map (:obj:`dict`): when variable names differ at load time for some reason,
      reload_alias_map gives the mapping between new and old names
    enable_alias_map_auto_gen: whether alias_map should be generated automatically
    save_summary_steps: save a summary every this many global steps
    log_step_count_steps: print the loss every this many global steps
    disable_native_metrics: whether to disable TensorFlow's metrics (e.g. AUC, MSE computation), default True
  """

  # basic
  is_local: bool = False
  num_ps: int = 0
  num_workers: int = 1
  chief_timeout_secs: int = 1800
  operation_timeout_in_ms: int = -1
  session_creation_timeout_secs: int = 7200
  enable_fused_layout: bool = False
  enable_model_dump: bool = False
  # failover
  partial_recovery: bool = False
  max_retry_times: int = 6
  retry_wait_in_secs: int = 5
  # for params sync
  bzid: str = None
  base_name: str = None
  ps_replica_num: int = None
  enable_parameter_sync: bool = False
  # checkpoint and export
  model_dir: str = ""
  restore_dir: str = None
  restore_ckpt: str = None
  save_checkpoints_secs: int = None
  save_checkpoints_steps: int = None
  max_rpc_deadline_millis: int = 30000
  dense_only_save_checkpoints_secs: int = None
  dense_only_save_checkpoints_steps: int = None
  checkpoints_max_to_keep: int = 10
  warmup_file: str = './warmup_file'
  enable_local_profiling: bool = False
  use_native_multi_hash_table: bool = None
  clear_nn: bool = False
  continue_training: bool = False
  reload_alias_map: Dict[str, int] = None
  enable_alias_map_auto_gen: bool = None
  kafka_topics: str = None
  kafka_group_id: str = None
  kafka_servers: str = None
  disable_native_metrics: bool = True
  save_summary_steps: int = 100
  log_step_count_steps: int = 100

  def to_runner_config(self) -> RunnerConfig:
    """Converts this user-facing config into a `RunnerConfig`.

    A fixed set of fields is passed to the `RunnerConfig` constructor directly.
    Every other annotated field is copied only when `RunnerConfig` has an
    attribute of the same name, the value is not None, and the value differs
    from both the `RunConfig` class default and the value already on `conf`.

    Raises:
      RuntimeError: if `enable_parameter_sync` is set but `conf` exposes
        neither `enable_realtime_training` nor `enable_parameter_sync`.
    """
    conf = RunnerConfig(
        restore_dir=self.restore_dir,
        restore_ckpt=self.restore_ckpt,
        model_dir=self.model_dir,
        enable_fused_layout=self.enable_fused_layout,
        enable_model_dump=self.enable_model_dump,
        warmup_file=self.warmup_file,
        enable_alias_map_auto_gen=self.enable_alias_map_auto_gen,
        kafka_topics=self.kafka_topics,
        kafka_group_id=self.kafka_group_id,
        kafka_servers=self.kafka_servers,
        save_summary_steps=self.save_summary_steps,
        log_step_count_steps=self.log_step_count_steps,
        disable_native_metrics=self.disable_native_metrics)
    for name, _ in get_type_hints(RunConfig).items():
      value = getattr(self, name)
      if hasattr(conf, name) and value is not None:
        default = getattr(RunConfig, name)
        # This must be an `and`: a RunnerConfig value can be written by the command line,
        # so a RunConfig default value must never overwrite a command line value.
        if value != default and getattr(conf, name) != value:
          setattr(conf, name, value)

    # In case the US team uses CONSUL: fall back to ZK discovery.
    if conf.discovery_type == ServiceDiscoveryType.CONSUL:
      conf.discovery_type = ServiceDiscoveryType.ZK

    if not conf.enable_gpu_training:
      # set default value for embedding prefetch/postpush
      if conf.embedding_prefetch_capacity <= 0:
        conf.embedding_prefetch_capacity = 1
      if not conf.enable_embedding_postpush:
        conf.enable_embedding_postpush = True

    # [todo] remove this when enable_realtime_training changed to enable_parameter_sync
    if self.enable_parameter_sync:
      if hasattr(conf, 'enable_realtime_training'):
        conf.enable_realtime_training = True
      elif hasattr(conf, 'enable_parameter_sync'):
        conf.enable_parameter_sync = True
      else:
        raise RuntimeError("enable_parameter_sync not set!")

    return conf

  def __post_init__(self):
    """Registers this config with `DumpUtils` right after construction.

    The JSON form of the whole config is passed to `add_config`, and the names
    of all fields whose value differs from the class default are passed to
    `add_user_params`.
    """
    ser_data = self.to_json()
    DumpUtils().add_config(ser_data)

    # get user params
    params = vars(self)
    user_params = []
    for name, _ in get_type_hints(RunConfig).items():
      default_value = getattr(RunConfig, name)
      if default_value != params[name]:
        logging.info("save user param {} with value {}".format(
            name, params[name]))
        user_params.append(name)
    DumpUtils().add_user_params(user_params)
@monolith_export
class Estimator(object):
  """Estimator unifies local mode and distributed mode. It also helps to
  initialize/save/restore variables, run hooks, write summaries and so on.

  Args:
    model (:obj:`Model`): the NativeModel object
    conf (:obj:`RunConfig`): configuration needed to run the model
    warm_start_from (:obj:`str`): a warmup file can be stored when a saved_model is saved;
      warm_start_from gives the location of that warmup file (the directory name is enough)
  """

  def __init__(self,
               model,
               conf: Union[RunConfig, RunnerConfig],
               warm_start_from=None):
    # A RunConfig is converted; anything else is used as the runner config as-is.
    self._runner_conf = conf.to_runner_config() if isinstance(
        conf, RunConfig) else conf
    self._model = model
    self._task = None
    self._warm_start_from = warm_start_from
    self._sync_backend = None
    self._kazoo_client = None
    if isinstance(conf, RunConfig):
      self._enable_loacl_profiling = conf.enable_local_profiling
    else:
      self._enable_loacl_profiling = False
    logging.info(self._runner_conf)

    if self._runner_conf.is_local:
      # local mode cannot access deep_insight
      self._model.metrics.enable_deep_insight = False
    else:
      self._model.metrics.enable_deep_insight = True

    # Runner-level deep_insight settings override the model's own when set.
    if self._runner_conf.deep_insight_name:
      self._model.metrics.deep_insight_name = self._runner_conf.deep_insight_name
    if self._runner_conf.deep_insight_target:
      self._model.metrics.deep_insight_target = self._runner_conf.deep_insight_target
    if self._runner_conf.deep_insight_sample_ratio:
      self._model.metrics.deep_insight_sample_ratio = self._runner_conf.deep_insight_sample_ratio

    if self._runner_conf.enable_realtime_training and self._runner_conf.server_type == 'ps':
      assert self._runner_conf.bzid, "Business id cannot be none while realtime training."
      assert self._runner_conf.base_name, "Base name cannot be none while realtime training."
      zk_servers = self._runner_conf.zk_server or os.environ.get(
          'zk_servers', default_zk_servers())
      if self._runner_conf.unified_serving:
        self._sync_backend = ZKBackend(self._runner_conf.bzid,
                                       zk_servers=zk_servers)
      else:
        # base_name has already been asserted above, so it is not re-checked.
        self._kazoo_client = MonolithKazooClient(hosts=zk_servers)
        self._kazoo_client.start()
        agent_config = AgentConfig(bzid=self._runner_conf.bzid,
                                   base_name=self._runner_conf.base_name,
                                   deploy_type='ps',
                                   num_ps=self._runner_conf.num_ps,
                                   dc_aware=self._runner_conf.dc_aware)
        replica_watcher = ReplicaWatcher(
            self._kazoo_client,
            agent_config,
            zk_watch_address_family=self._runner_conf.zk_watch_address_family)
        self._sync_backend = replica_watcher.to_sync_wrapper()

    if self._runner_conf.params_override:
      logging.info("Override: {}".format(self._runner_conf.params_override))
      params_override_dict = json.loads(self._runner_conf.params_override)
      if hasattr(model, 'p'):
        model.p.set(**params_override_dict)
      elif hasattr(model, 'params'):
        model.params.set(**params_override_dict)
      else:
        logging.warning('params_override error!')

    # HDFS setup is best-effort: a failure is logged and construction goes on.
    try:
      if not os.environ.get("HADOOP_HDFS_HOME"):
        env_utils.setup_hdfs_env()
    except Exception as e:
      logging.error('setup_hdfs_env fail {}!'.format(e))

    os.environ["TF_GRPC_WORKER_CACHE_THREADS"] = str(
        self._runner_conf.tf_grpc_worker_cache_threads)
    os.environ["MONOLITH_GRPC_WORKER_SERVICE_HANDLER_MULTIPLIER"] = str(
        self._runner_conf.monolith_grpc_worker_service_handler_multiplier)

    # private attr
    self.__est: Optional[tf.estimator.Estimator] = None

  @property
  def _sess_config(self):
    return self._est._session_config

  @property
  def model_dir(self):
    return self._runner_conf.model_dir

  @property
  def config(self):
    return self._est._config

  @property
  def _est(self):
    """Lazily builds (and caches) the tf estimator together with `self._task`."""
    if self.__est is None:
      model = copy.deepcopy(self._model)
      model.mode = tf.estimator.ModeKeys.PREDICT
      self._task = CpuTraining(config=self._runner_conf,
                               task=model.instantiate())
      # the default estimate for predict/export_saved_model/import_saved_model
      # NOTE(review): TensorFlow reads `TF_CONFIG`; confirm `TF_CONF` is the intended key.
      if 'TF_CONF' in os.environ:
        del os.environ['TF_CONF']
      self.__est = tf.estimator.Estimator(
          self._task.create_model_fn(),
          model_dir=self._runner_conf.model_dir,
          config=tf.estimator.RunConfig(
              log_step_count_steps=self._runner_conf.log_step_count_steps),
          warm_start_from=self._warm_start_from)
    return self.__est

  def _init_fountain_env(self):
    """Copies the fountain settings of the runner config onto the model when
    the model uses fountain and both runner-level values are set."""
    if self._model.train.use_fountain and bool(
        self._runner_conf.fountain_zk_host) and bool(
            self._runner_conf.fountain_model_name):
      logging.info("Override Fountain Params:{}; {}".format(
          self._runner_conf.fountain_model_name,
          self._runner_conf.fountain_zk_host))
      self._model.train.fountain_zk_host = self._runner_conf.fountain_zk_host
      self._model.train.fountain_model_name = self._runner_conf.fountain_model_name

  def close(self):
    """Stops the sync backend and the kazoo client, if present.

    Every step is best-effort: exceptions are logged and swallowed so that a
    failure in one shutdown step does not prevent the following ones.
    """
    if self._sync_backend is not None:
      try:
        self._sync_backend.stop()
      except Exception as e:
        logging.error(e)

    if self._kazoo_client is not None:
      try:
        self._kazoo_client.stop()
      except Exception as e:
        logging.info(e)
      try:
        self._kazoo_client.close()
      except Exception as e:
        logging.info(e)

  def get_variable_value(self, name):
    return self._est.get_variable_value(name)

  def get_variable_names(self):
    return self._est.get_variable_names()

  def latest_checkpoint(self):
    return self._est.latest_checkpoint()

  def train(self, steps=None, max_steps=None, hooks=None):
    """Trains a deep copy of the model, locally or distributed, then calls `close()`.

    Args:
      steps: when not None, written to `model.train.steps`.
      max_steps: when not None, written to `model.train.max_steps`.
      hooks: None or a list of `tf.estimator.SessionRunHook`.
    """
    assert hooks is None or isinstance(hooks, list) and \
           all(isinstance(o, tf.estimator.SessionRunHook) for o in hooks)
    set_metric_prefix("monolith.training.{}".format(
        self._runner_conf.deep_insight_name))
    model = copy.deepcopy(self._model)
    if not isinstance(model, InstantiableParams):
      model.file_name = self._model.file_name
    model.mode = tf.estimator.ModeKeys.TRAIN
    if steps is not None:
      model.train.steps = steps
    if max_steps is not None:
      model.train.max_steps = max_steps
    self._init_fountain_env()

    if self._runner_conf.is_local:
      if not self._runner_conf.model_dir:
        model_dir = "/tmp/{}/{}".format(getpass.getuser(), model.name)
      else:
        model_dir = self._runner_conf.model_dir
      DumpUtils().record_params(model)
      self.__est = local_train_internal(model,
                                        self._runner_conf,
                                        model_dir=model_dir,
                                        steps=steps,
                                        profiling=self._enable_loacl_profiling,
                                        user_hooks=hooks)
      # NOTE(review): this uses the configured model_dir, not the /tmp fallback above.
      DumpUtils().dump(f'{self._runner_conf.model_dir}/model_dump')
    else:
      DumpUtils().enable = False
      if self._sync_backend is not None:
        self._sync_backend.start()
        self._sync_backend.subscribe_model(self._runner_conf.model_name or
                                           model.metrics.deep_insight_name)

      logging.info("Environment vars: %s", os.environ)
      logging.info("Flags: %s", flags.FLAGS.flag_values_dict())
      logging.info(f'{model.p}')
      if self._runner_conf.enable_full_sync_training:
        init_sync_train_and_update_conf(self._runner_conf)
        self.__est = distributed_sync_train(self._runner_conf,
                                            params=model,
                                            sync_backend=self._sync_backend,
                                            user_hooks=hooks)
      else:
        with monolith_discovery(self._runner_conf) as discovery:
          if self._runner_conf.enable_gpu_training:
            device_utils.enable_gpu_training()
            model.train.use_gpu_emb_table = False
            if self._runner_conf.enable_partial_sync_training and self._runner_conf.server_type == "worker":
              try_init_cuda()
            self._runner_conf.device_fn = device_utils.get_device_fn()
            model.train.slow_start_steps = 0
          self.__est = distributed_train(config=self._runner_conf,
                                         discovery=discovery,
                                         params=model,
                                         sync_backend=self._sync_backend,
                                         user_hooks=hooks)
    self.close()

  def evaluate(self, steps=None, hooks=None):
    """Evaluates a deep copy of the model, locally or distributed, then calls `close()`.

    Args:
      steps: when not None, written to `model.train.steps`.
      hooks: None or a list of `tf.estimator.SessionRunHook`.
    """
    assert hooks is None or isinstance(hooks, list) and \
           all(isinstance(o, tf.estimator.SessionRunHook) for o in hooks)
    model = copy.deepcopy(self._model)
    model.mode = tf.estimator.ModeKeys.EVAL
    if not isinstance(model, InstantiableParams):
      model.file_name = self._model.file_name
    self._runner_conf.mode = tf.estimator.ModeKeys.EVAL
    if steps is not None:
      model.train.steps = steps
    self._init_fountain_env()

    if self._runner_conf.is_local:
      DumpUtils().record_params(model)
      if not self._runner_conf.model_dir:
        model_dir = "/tmp/{}/{}".format(getpass.getuser(), model.name)
      else:
        model_dir = self._runner_conf.model_dir
      self.__est = local_train_internal(model,
                                        self._runner_conf,
                                        model_dir=model_dir,
                                        steps=steps,
                                        profiling=self._enable_loacl_profiling,
                                        user_hooks=hooks)
      DumpUtils().dump(f'{self._runner_conf.model_dir}/model_dump')
    else:
      DumpUtils().enable = False
      logging.info(f'{model.p}')
      logging.info("Environment vars: %s", os.environ)
      logging.info("Flags: %s", flags.FLAGS.flag_values_dict())
      if self._runner_conf.enable_full_sync_training:
        init_sync_train_and_update_conf(self._runner_conf)
        self.__est = distributed_sync_train(self._runner_conf,
                                            params=model,
                                            sync_backend=self._sync_backend,
                                            user_hooks=hooks)
      else:
        with monolith_discovery(self._runner_conf) as discovery:
          if self._runner_conf.enable_gpu_training:
            device_utils.enable_gpu_training()
            model.train.use_gpu_emb_table = False
            if self._runner_conf.enable_partial_sync_training and self._runner_conf.server_type == "worker":
              try_init_cuda()
            self._runner_conf.device_fn = device_utils.get_device_fn()
          # The user hooks are forwarded here as well, consistently with train()
          # and with the other evaluate branches; they used to be dropped.
          self.__est = distributed_train(config=self._runner_conf,
                                         discovery=discovery,
                                         params=model,
                                         sync_backend=self._sync_backend,
                                         user_hooks=hooks)
    self.close()

  def predict(self,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    """Runs prediction through the underlying tf estimator.

    Returns:
      The value returned by `tf.estimator.Estimator.predict`. That call yields
      predictions lazily, so the caller must iterate over the result; it used
      to be discarded here, which meant nothing was ever predicted.
    """
    est = self._est  # create tf estimator
    # NOTE(review): if train()/evaluate() already cached an estimator, `self._task`
    # may still be None at this point - confirm the intended call order.
    input_fn = self._task.create_input_fn(tf.estimator.ModeKeys.PREDICT)
    predictions = est.predict(input_fn, predict_keys, hooks, checkpoint_path,
                              yield_single_examples)
    self.close()
    return predictions

  def export_saved_model(self,
                         batch_size=64,
                         name=None,
                         dense_only: bool = False,
                         enable_fused_layout: bool = False):
    """Exports a saved_model under `<model_dir>/<model.serving.export_dir_base>`.

    Works on deep copies of both the model and the runner config, so the
    estimator's own state is not modified.

    Args:
      batch_size: written to `model.train.per_replica_batch_size`.
      name: model name for the export; "demo_export" when falsy.
      dense_only: forwarded to `create_exporter`.
      enable_fused_layout: set on the copied runner config and on `ParserCtx`.
    """
    model = copy.deepcopy(self._model)
    runner_conf = copy.deepcopy(self._runner_conf)
    runner_conf.enable_fused_layout = enable_fused_layout
    model.name = name or "demo_export"
    model.train.per_replica_batch_size = batch_size
    model.mode = tf.estimator.ModeKeys.PREDICT

    model_dir = runner_conf.model_dir
    export_dir_base = os.path.join(model_dir, model.serving.export_dir_base)
    warmup_file = runner_conf.warmup_file
    with ParserCtx(enable_fused_layout=enable_fused_layout):
      task = CpuTraining(config=runner_conf, task=model.instantiate())
      exporter = create_exporter(task, model_dir, warmup_file, export_dir_base,
                                 dense_only)
      serving_input_receiver_fn = task.create_serving_input_receiver_fn()
      exporter.export_saved_model(serving_input_receiver_fn)
@monolith_export def import_saved_model(saved_model_path: str, input_name: str = "instances", output_name: str = 'output', signature: str = None): """导出saved_model Args: saved_model_path (:obj:`str`): saved_model路径 """ class saved_model(object): def __init__(self, saved_model_path, signature, inputs, outputs): basename = os.path.basename(saved_model_path) if not basename.isnumeric(): versions = [] for subitem in tf.io.gfile.listdir(saved_model_path): if subitem.isnumeric(): versions.append(int(subitem)) if versions: versions.sort() saved_model_path = os.path.join(saved_model_path, str(versions[-1])) else: raise RuntimeError(f"no models in dir {saved_model_path}") self._saved_model_path = saved_model_path if signature: self._signature = signature else: self._signature = tf.compat.v1.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY if inputs: self._inputs = inputs if isinstance(inputs, (list, tuple)) else [inputs] else: self._inputs = None if outputs: self._outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs] else: self._outputs = None def __enter__(self): class infer(object): def __init__(self, graph, sess, placeholders, output_dict, output_name_map): self._graph = graph self._sess = sess self._placeholders = placeholders self._output_dict = output_dict self._output_name_map = output_name_map def __call__(self, features: Dict[str, np.ndarray]) -> List[np.ndarray]: with self._graph.as_default(), self._sess.as_default(): if len(self._placeholders) == 1: placeholders = next(iter(self._placeholders.values())) result = sess.run(self._output_dict, feed_dict={placeholders: features}) else: result = sess.run(self._output_dict, feed_dict={ self._placeholders[name]: feature for name, feature in features.items() }) return { self._output_name_map[key]: tensor for key, tensor in result.items() } tag = tf.compat.v1.saved_model.tag_constants.SERVING graph = tf.compat.v1.Graph() sess = tf.compat.v1.Session(graph=graph) with graph.as_default(), sess.as_default(): 
imported = tf.compat.v1.saved_model.load(sess, {tag}, self._saved_model_path) print(imported.signature_def, flush=True) signature_def = imported.signature_def[self._signature] placeholders: Dict[str, tf.compat.v1.placeholder] = {} for input_name in self._inputs: input_ph_name = signature_def.inputs[input_name].name input_ph = graph.get_tensor_by_name(input_ph_name) placeholders[input_name] = input_ph output_dict, output_name_map = {}, {} if self._outputs: for output_name in self._outputs: output_tensor_name = signature_def.outputs[output_name].name output_tensor = graph.get_tensor_by_name(output_tensor_name) if output_tensor_name.endswith(':0'): output_tensor_name = output_tensor_name[0:-2] output_dict[output_tensor_name] = output_tensor output_name_map[output_tensor_name] = output_name else: for output_name, tensor in signature_def.outputs.items(): output_tensor_name = tensor.name output_tensor = graph.get_tensor_by_name(output_tensor_name) if output_tensor_name.endswith(':0'): output_tensor_name = output_tensor_name[0:-2] output_dict[output_tensor_name] = output_tensor output_name_map[output_tensor_name] = output_name logging.info('import_saved_model finished') return infer(graph, sess, placeholders, output_dict, output_name_map) def __exit__(self, exc_type, exc_val, exc_tb): logging.info('exit import_saved_model') return saved_model(saved_model_path, signature, input_name, output_name) ================================================ FILE: monolith/native_training/estimator_dist_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import socket import time import unittest from multiprocessing import Process import tensorflow as tf from monolith.native_training.runner_utils import RunnerConfig from monolith.native_training.model import TestFFMModel from monolith.native_training.service_discovery import TfConfigServiceDiscovery from monolith.native_training.estimator import Estimator from monolith.native_training.utils import get_test_tmp_dir model_name = 'dist_test' model_dir = "{}/{}/ckpt".format(get_test_tmp_dir(), model_name) export_base = "{}/{}/saved_model".format(get_test_tmp_dir(), model_name) _EXIT_SUCCESS = 0 def get_free_port(): """TODO(fitzwang) this function is not safe in preemption env""" sock = socket.socket() sock.bind(('', 0)) ip, port = sock.getsockname() sock.close() return port def get_cluster(ps_num, worker_num): cluster = { 'ps': ['localhost:{}'.format(get_free_port()) for _ in range(ps_num)], 'worker': [ 'localhost:{}'.format(get_free_port()) for _ in range(worker_num - 1) ], 'chief': ['localhost:{}'.format(get_free_port())] } return cluster def get_saved_model_path(export_base): try: candidates = [] for f in os.listdir(export_base): if not (f.startswith('temp') or f.startswith('tmp')): fname = os.path.join(export_base, f) if os.path.isdir(fname): candidates.append(fname) candidates.sort() return candidates[-1] except: return "" class EstimatorTrainTest(unittest.TestCase): """The tests here are for testing complilation.""" params = None @classmethod def setUpClass(cls) -> None: if tf.io.gfile.exists(model_dir): tf.io.gfile.rmtree(model_dir) params = 
TestFFMModel.params() params.metrics.enable_deep_insight = False params.train.per_replica_batch_size = 64 cls.params = params def train(self): ps_num, worker_num = 2, 3 cluster = get_cluster(ps_num, worker_num) def start_server(server_type, index): task = {'type': server_type, 'index': index} tf_confg = {'cluster': cluster, 'task': task} discovery = TfConfigServiceDiscovery(tf_confg) dct_config = RunnerConfig(index=discovery.index, model_dir=model_dir, ps_num=ps_num, worker_num=worker_num, server_type=discovery.server_type) estimator = Estimator(self.params, dct_config, discovery) estimator.train(steps=10) threads = [] for i in range(ps_num): thread = Process(target=start_server, args=('ps', i)) thread.start() threads.append(thread) for i in range(worker_num): if i == 0: thread = Process(target=start_server, args=('chief', i)) else: thread = Process(target=start_server, args=('worker', i - 1)) thread.start() threads.append(thread) if i == 0: time.sleep(1) for thread in threads: thread.join() assert thread.exitcode == _EXIT_SUCCESS def eval(self): ps_num, worker_num = 2, 3 cluster = get_cluster(ps_num, worker_num) def start_server(server_type, index): task = {'type': server_type, 'index': index} tf_confg = {'cluster': cluster, 'task': task} discovery = TfConfigServiceDiscovery(tf_confg) dct_config = RunnerConfig(index=discovery.index, model_dir=model_dir, ps_num=ps_num, worker_num=worker_num, server_type=discovery.server_type) estimator = Estimator(self.params, dct_config, discovery) estimator.evaluate(steps=10) threads = [] for i in range(ps_num): thread = Process(target=start_server, args=('ps', i)) thread.start() threads.append(thread) for i in range(worker_num): if i == 0: thread = Process(target=start_server, args=('chief', i)) else: thread = Process(target=start_server, args=('worker', i - 1)) thread.start() threads.append(thread) if i == 0: time.sleep(1) for thread in threads: thread.join() assert thread.exitcode == _EXIT_SUCCESS def test_dist(self): 
self.train() self.eval() if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/native_training/estimator_mode_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import time import socket import unittest import copy import subprocess from absl import flags, logging from typing import Dict, List import tensorflow as tf from tensorflow.python.framework import test_util from monolith.native_training.runner_utils import RunnerConfig from monolith.native_training.estimator import Estimator, import_saved_model, RunConfig from monolith.native_training.utils import get_test_tmp_dir from monolith.native_training.tasks.sparse_dense_gpu.model_test import gen_input_file, MultiHeadModel FLAGS = flags.FLAGS #copy from monolith/native_training/cpu_training_test.py class DistributedTrainTest(tf.test.TestCase): _DISTRIBUTED_TRAIN_BINARY = "monolith/native_training/tasks/sparse_dense_gpu/model" @classmethod def setUpClass(cls) -> None: FLAGS.dataset_input_patterns = f"/tmp/estimator_mode_test_eb.pb" gen_input_file(FLAGS.dataset_input_patterns) # maybe file is not enough to train def link_some_file(suffix): if not os.path.exists(FLAGS.dataset_input_patterns + suffix): os.symlink(FLAGS.dataset_input_patterns, FLAGS.dataset_input_patterns + suffix) for i in range(10): link_some_file(str(i)) FLAGS.dataset_input_patterns += "{INT(0,99)}" 
def find_free_port(self, count): port_list = [] while len(port_list) < count: sock = socket.socket() sock.bind(('', 0)) port = sock.getsockname()[1] if port not in port_list: port_list.append(port) return port_list def _run_test(self, task_name: str, other_args: List, num_ps: int, num_workers: int, num_dsworkers: int, other_env: Dict = {}, worker_args: List = [], use_mpi_run=False): cur_modir = "{}/{}/ckpt".format(get_test_tmp_dir(), task_name) if tf.io.gfile.exists(cur_modir): tf.io.gfile.rmtree(cur_modir) os.makedirs(cur_modir) logging.info(f"show cur_modir: {cur_modir}") args_tmpl = [ self._DISTRIBUTED_TRAIN_BINARY, "--mode=train", f"--model_dir={cur_modir}", f"--num_ps={num_ps}", f"--num_workers={num_workers}", f"--uuid={self._testMethodName}", f"--dataset_input_patterns={FLAGS.dataset_input_patterns}", f"--dataset_input_use_snappy=False", #f"--dataset_input_use_parquet=True", "--lagrangex_header=True", "--sort_id=False", "--kafka_dump=False", "--kafka_dump_prefix=False", "--data_type=ExampleBatch", "--discovery_type=mlp", "--operation_timeout_in_ms=10000", "--disable_native_metrics=True", "--dataset_use_dataservice=True" if num_dsworkers else "--dataset_use_dataservice=False", "--cluster_type=stable", ] + other_args my_env = os.environ.copy() my_env.update(other_env) def fill_host_env(role_name, num_role, cur_port_list): if num_role <= 0: return all_host = [] all_addr = [] role_name = role_name.upper() for i in range(num_role): all_host.append("localhost") cur_port = cur_port_list[i] all_addr.append(f"localhost:{cur_port}") my_env[f"MLP_{role_name}_{i}_PORT"] = f"{cur_port}" my_env[f"MLP_{role_name}_{i}_HOST"] = f"localhost" my_env[f"MLP_{role_name}_{i}_PRIMARY_HOST"] = f"localhost" my_env[f"MLP_{role_name}_NUM"] = f"{num_role}" my_env[f"MLP_{role_name}_ALL_HOSTS"] = f"{','.join(all_host)}" my_env[f"MLP_{role_name}_ALL_PRIMARY_HOSTS"] = my_env[ f"MLP_{role_name}_ALL_HOSTS"] my_env[f"MLP_{role_name}_ALL_ADDRS"] = f"{','.join(all_addr)}" 
my_env[f"MLP_{role_name}_ALL_PRIMARY_ADDRS"] = my_env[ f"MLP_{role_name}_ALL_ADDRS"] #data_service_dispachter num_dispatcher = 0 if num_dsworkers: num_dispatcher = 1 all_port = self.find_free_port(num_ps + num_workers + num_dsworkers + num_dispatcher) ps_port = all_port[:num_ps] worker_port = all_port[num_ps:num_ps + num_workers] dsworker_port = all_port[num_ps + num_workers:-num_dispatcher] dispatcher_port = all_port[-num_dispatcher:] fill_host_env('ps', num_ps, ps_port) fill_host_env('worker', num_workers, worker_port) fill_host_env('dispatcher', num_dispatcher, dispatcher_port) fill_host_env('dsworker', num_dsworkers, dsworker_port) processes = {} log_files = [] def start_process(role_name, num_role, cur_port_list, use_mpi_run=False): if use_mpi_run: hostfile = f"{cur_modir}/../hostfile" f = open(hostfile, "w") f.write(f"localhost slots={num_role}") f.close() args = copy.copy(args_tmpl) args.append(f"--server_type={role_name}") if role_name == "worker": args += worker_args cur_env = copy.deepcopy(my_env) cur_env["MLP_ROLE"] = role_name cur_env["MLP_PORT"] = f"{cur_port_list[0]}" cur_env["MLP_SSH_PORT"] = f"{worker_port[0]}" cur_env["MONOLITH_WITH_HOROVOD"] = f"1" cur_env["MONOLITH_WITH_HOROVOD_FID_G2G"] = f"1" cur_env["MONOLITH_WITH_ALLREDUCE_FUSION"] = f"one" #cur_env["MONOLITH_GPU_FEATURE_FACTORY_FUSION_LEVEL"] = f"1" #cur_env["HOROVOD_MPI_THREADS_DISABLE"] = f"1" #cur_env["GPU_AFFINITY_NIC_ADDRESS"] = f"1" #cur_env["NCCL_SOCKET_IFNAME"] = f"eth0" #cur_env["NCCL_P2P_LEVEL"] = f"1" mpi_run_args = [ "mpirun", "--map-by", f"ppr:{num_role}:node", "-np", f"{num_role}", "--hostfile", hostfile, "--allow-run-as-root", "-oversubscribe", "--tag-output", "--report-bindings", #"--mca", "btl_tcp_if_include", "eth0", "--mca", "oob_tcp_if_include", "eth0" ] for k, v in cur_env.items(): mpi_run_args.append("-x") mpi_run_args.append(f"{k}={v}") args = mpi_run_args + args process = subprocess.Popen(args) logging.info(f"start a process for {role_name}:{range(num_role)}") 
processes[f"{role_name}:{0}"] = process for i in range(1, num_role): processes[f"{role_name}:{i}"] = None else: for i in reversed(range(num_role)): log_file = open(cur_modir + f"/../{role_name}_{i}.log", 'w') log_files.append(log_file) args = copy.copy(args_tmpl) args.append(f"--server_type={role_name}") args.append("--index={}".format(i)) cur_env = copy.deepcopy(my_env) cur_env["MLP_ROLE"] = role_name cur_env["MLP_ROLE_INDEX"] = f"{i}" cur_env["MLP_PORT"] = f"{cur_port_list[i]}" cur_env["MLP_SSH_PORT"] = f"{worker_port[0]}" #if i == 0 and role_name == "worker": # time.sleep(5) ''' shell_commond = "" for k, v in cur_env.items(): if "BASH_FUNC_" in k: continue shell_commond += f"{k}={v} " shell_commond += f" bazel-bin/{args[0]} " for arg in args[1:]: shell_commond += f"{arg} " logging.info( f"start a shell for {role_name}:{i} \n {shell_commond}") ''' process = subprocess.Popen(args, env=cur_env) logging.info(f"start a process for {role_name}:{i}") processes[f"{role_name}:{i}"] = process start_process('dispatcher', num_dispatcher, dispatcher_port) start_process('dsworker', num_dsworkers, dsworker_port) start_process('ps', num_ps, ps_port) start_process('worker', num_workers, worker_port, use_mpi_run=use_mpi_run) print(" ".join(args_tmpl), num_ps, num_workers, num_dsworkers) def wait_for_process(role_name, num_role, timeout=10, ignore_timeout=False): for i in range(num_role): role = f"{role_name}:{i}" if role not in processes: continue process = processes[role] if process is None: continue if not ignore_timeout: self.assertEqual(process.wait(timeout=timeout), 0) else: try: self.assertEqual(process.wait(timeout=timeout), 0) except subprocess.TimeoutExpired as e: logging.warning(f"exit process for {role} timeout") process.terminate() processes.pop(role) logging.info(f"exit process for {role}") wait_for_process('worker', 1, 250) wait_for_process('worker', num_workers, timeout=10, ignore_timeout=True) wait_for_process('ps', num_ps, timeout=1) wait_for_process('dsworker', 
num_dsworkers, timeout=1, ignore_timeout=True) #maybe chief port not free wait_for_process('dispatcher', num_dispatcher, timeout=1, ignore_timeout=True) #maybe chief port not free for log_file in log_files: log_file.flush() log_file.close() tf.io.gfile.rmtree(cur_modir) def run_cpu(self, name, other_args): # TODO cpu mode run gpu have error if test_util.is_gpu_available(cuda_only=True): return args = [ "--enable_gpu_training=False", "--enable_sync_training=False", "--embedding_prefetch_capacity=1", "--enable_embedding_postpush=True", "--chief_timeout_secs=20", ] + other_args num_ps = 2 num_workers = 2 num_dsworkers = 0 self._run_test(f"full_cpu_{name}", args, num_ps, num_workers, num_dsworkers) def test_cpu0(self): args = [ "--enable_fused_layout=False", "--use_native_multi_hash_table=False", ] self.run_cpu('0', args) def test_cpu1(self): args = [ "--enable_fused_layout=False", "--use_native_multi_hash_table=True", ] self.run_cpu('1', args) def test_cpu2(self): args = [ "--enable_fused_layout=True", "--use_native_multi_hash_table=True", ] self.run_cpu('2', args) def test_cpu3(self): args = [ "--enable_fused_layout=True", "--use_native_multi_hash_table=False", ] self.run_cpu('3', args) def sparse_dense_run(self, name, other_args): if not test_util.is_gpu_available(cuda_only=True): return gpus = tf.config.list_physical_devices('GPU') args = [ "--enable_gpu_training=True", "--enable_sync_training=True", "--enable_partial_sync_training=True", "--embedding_prefetch_capacity=1", "--enable_embedding_postpush=True", '--params_override={"train.max_steps": 10}', ] + other_args worker_args = [] num_ps = 2 num_workers = min(2, len(gpus)) num_dsworkers = 1 other_env = {} self._run_test(f"sparse_dense_{name}", args, num_ps, num_workers, num_dsworkers, other_env=other_env, worker_args=worker_args, use_mpi_run=True) def test_sparse_dense0(self): args = [ "--enable_fused_layout=True", "--use_native_multi_hash_table=False", ] self.sparse_dense_run('0', args) def 
test_sparse_dense1(self): args = [ "--enable_fused_layout=True", "--use_native_multi_hash_table=True", ] self.sparse_dense_run('1', args) def test_sparse_dense2(self): args = [ "--enable_fused_layout=False", "--use_native_multi_hash_table=False", ] self.sparse_dense_run('2', args) def test_sparse_dense3(self): args = [ "--enable_fused_layout=False", "--use_native_multi_hash_table=True", ] self.sparse_dense_run('3', args) def full_gpu_run(self, name, other_args): if not test_util.is_gpu_available(cuda_only=True): return gpus = tf.config.list_physical_devices('GPU') args = [ "--enable_gpu_training=True", "--enable_sync_training=True", "--reorder_fids_in_data_pipeline=True", "--filter_type=probabilistic_filter", "--embedding_prefetch_capacity=1", "--enable_async_optimize=False", '--params_override={"train.max_steps": 10}', ] + other_args worker_args = [] num_ps = 0 num_workers = min(2, len(gpus)) num_dsworkers = 1 other_env = {} self._run_test(f"full_gpu_{name}", args, num_ps, num_workers, num_dsworkers, other_env=other_env, worker_args=worker_args, use_mpi_run=True) def test_full_gpu_0(self): args = [ "--enable_fused_layout=True", "--use_native_multi_hash_table=False", ] self.full_gpu_run('0', args) def test_full_gpu_1(self): args = [ "--enable_fused_layout=True", "--use_native_multi_hash_table=True", ] self.full_gpu_run('1', args) def test_full_gpu_2(self): args = [ "--enable_fused_layout=False", "--use_native_multi_hash_table=False", ] self.full_gpu_run('2', args) def test_full_gpu_3(self): args = [ "--enable_fused_layout=False", "--use_native_multi_hash_table=True", ] self.full_gpu_run('3', args) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/estimator_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import time import unittest import tensorflow as tf from monolith.native_training.runner_utils import RunnerConfig from monolith.native_training.input import generate_ffm_example from monolith.native_training.model import TestFFMModel,\ _VOCAB_SIZES, _NUM_EXAMPLES from monolith.native_training.estimator import Estimator, import_saved_model from monolith.native_training.utils import get_test_tmp_dir model_name = 'estimator_test' model_dir = "{}/{}/ckpt".format(get_test_tmp_dir(), model_name) export_base = "{}/{}/ckpt/exported_models".format(get_test_tmp_dir(), model_name) def get_saved_model_path(export_base): try: candidates = [] for f in os.listdir(export_base): if not (f.startswith('temp') or f.startswith('tmp')): fname = os.path.join(export_base, f) if os.path.isdir(fname): candidates.append(fname) candidates.sort() return candidates[-1] except: return "" class EstimatorTrainTest(unittest.TestCase): """The tests here are for testing complilation.""" params = None @classmethod def setUpClass(cls) -> None: if tf.io.gfile.exists(model_dir): tf.io.gfile.rmtree(model_dir) params = TestFFMModel.params() params.metrics.enable_deep_insight = False params.train.per_replica_batch_size = 64 params.serving.export_dir_base = export_base params.serving.shared_embedding = True cls.params = params cls.conf = RunnerConfig(is_local=True, num_ps=0, model_dir=model_dir, use_native_multi_hash_table=False) def train(self): estimator = 
Estimator(self.params, self.conf) estimator.train(steps=10) def eval(self): estimator = Estimator(self.params, self.conf) estimator.evaluate(steps=10) def predict(self): estimator = Estimator(self.params, self.conf) estimator.predict() def export_saved_model(self): estimator = Estimator(self.params, self.conf) estimator.export_saved_model() def import_saved_model(self): saved_model_path = get_saved_model_path(export_base) print('saved_model_path', saved_model_path, flush=True) with import_saved_model(saved_model_path=saved_model_path) as infer: # There are some bugs here since functions to restore tables are not called. Will # resolve this by using resource concept in the future. start = time.time() num_ins = 0 for i in range(10): features = [ generate_ffm_example(_VOCAB_SIZES) for _ in range(_NUM_EXAMPLES) ] num_ins += len(features) infer(features) end = time.time() print(start, end, num_ins, 1000 * (end - start) / 10) def test_local(self): self.train() self.eval() self.predict() self.export_saved_model() self.import_saved_model() if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/feature.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import abc
from collections import namedtuple
import copy
import enum
from dataclasses import dataclass, asdict, field
from typing import Callable, Dict, Iterable, List, Tuple, Set, NamedTuple, Union
import sys
import os

from absl import logging
import tensorflow as tf

from monolith.native_training import device_utils
from monolith.native_training import distribution_ops
from monolith.native_training import embedding_combiners
from monolith.native_training import entry
from monolith.native_training import learning_rate_functions
from monolith.native_training import ragged_utils
from monolith.native_training.runtime.hash_table import \
    embedding_hash_table_pb2
from monolith.native_training.model_export.export_context import is_exporting
from monolith.native_training.monolith_export import monolith_export
from monolith.native_training import prefetch_queue

_FEATURE_STRAT_END_KEY = "{}:{}_{}"

# Default expire time is 100 years.
DEFAULT_EXPIRE_TIME = 36500


class FeatureEmbTable(abc.ABC):
  """Used by framework. Do not use in the user code directly. Instead, using FeatureSlot."""

  def add_feature_slice(self,
                        segment: embedding_hash_table_pb2.EntryConfig.Segment,
                        learning_rate_fn=None):
    """
    Add one feature slice for this embedding table.
    """
    pass

  @abc.abstractmethod
  def embedding_lookup(self, feature_name: str, start: int,
                       end: int) -> tf.Tensor:
    """
    Returns combined embedding tensors for the given feature name.
    """
    pass

  def set_feature_metadata(self, feature_name: str,
                           combiner: embedding_combiners.Combiner):
    pass


class FeatureSlice(NamedTuple):
  """Represents a slice of a feature slot."""
  feature_slot: "FeatureSlot"
  start: int
  end: int


@dataclass
class FeatureSlotConfig:
  """Configuration of one feature slot (bias part, default vector part, table).

  `name` falls back to `str(slot_id)` when it is not given.
  """
  name: str = None
  has_bias: bool = False
  bias_initializer: entry.Initializer = entry.ZerosInitializer()
  bias_optimizer: entry.Optimizer = entry.FtrlOptimizer(
      initial_accumulator_value=1e-6, beta=1.0)
  bias_compressor: entry.Compressor = entry.Fp32Compressor()
  bias_learning_rate_fn: Callable = None
  default_vec_initializer: entry.Initializer = entry.RandomUniformInitializer()
  default_vec_optimizer: entry.Optimizer = entry.AdagradOptimizer(
      initial_accumulator_value=1.0)
  default_vec_compressor: entry.Compressor = entry.Fp16Compressor()
  default_vec_learning_rate_fn: Callable = None
  hashtable_config: entry.HashTableConfig = entry.CuckooHashTableConfig()
  slot_id: int = None
  occurrence_threshold: int = 0
  expire_time: int = DEFAULT_EXPIRE_TIME

  def __post_init__(self):
    if not self.name:
      self.name = str(self.slot_id)


@monolith_export
class FeatureSlot:
  """Maintains the relation between a feature and its hash table while hiding
  the hash table details. A FeatureSlot can be seen as the user-facing view of
  a hash table.

  Args:
    table (:obj:`FeatureEmbTable`): The underlying hash table
    config (:obj:`FeatureSlotConfig`): The feature configuration

  """

  def __init__(self, table: FeatureEmbTable, config: FeatureSlotConfig):
    self._table = table
    self._config = config
    self._current_dim_size = 0
    self._feature_columns = set()

    if self._config.has_bias:
      # The bias is a 1-dim slice and is always the first slice of the slot.
      self._bias_slice = self.add_feature_slice(
          1, self._config.bias_initializer, self._config.bias_optimizer,
          self._config.bias_compressor, self._config.bias_learning_rate_fn)

  def add_feature_slice(self,
                        dim_size: int,
                        initializer: entry.Initializer = None,
                        optimizer: entry.Optimizer = None,
                        compressor: entry.Compressor = None,
                        learning_rate_fn=None) -> FeatureSlice:
    """
    Adds an embedding of length |dim_size| to the hash table, using
    |initializer| as its initializer, |optimizer| as its optimizer and
    |compressor| as the compressor applied in serving. Arguments left as None
    fall back to the `default_vec_*` values of the slot config.
    Returns a feature slice to be used by FeatureColumn.

    Args:
      dim_size (:obj:`int`): Length of this embedding slice
      initializer (:obj:`entry.Initializer`): Initializer of this embedding slice
      optimizer (:obj:`entry.Optimizer`): Optimizer of this embedding slice
      compressor (:obj:`entry.Compressor`): Compressor of this embedding slice
      learning_rate_fn (:obj:`Callable`): If not None, overrides the learning rate defined in the optimizer

    """
    initializer = initializer or self._config.default_vec_initializer
    optimizer = optimizer or self._config.default_vec_optimizer
    compressor = compressor or self._config.default_vec_compressor
    learning_rate_fn = learning_rate_fn or self._config.default_vec_learning_rate_fn
    segment = entry.CombineAsSegment(dim_size, initializer, optimizer,
                                     compressor)
    self._table.add_feature_slice(segment, learning_rate_fn=learning_rate_fn)
    # Slices are laid out back to back: [start, end) in embedding dims.
    s = FeatureSlice(self, self._current_dim_size,
                     self._current_dim_size + dim_size)
    self._current_dim_size = self._current_dim_size + dim_size
    return s

  def get_bias_slice(self):
    assert self._config.has_bias
    return self._bias_slice

  def _add_feature_column(self, fc):
    self._feature_columns.add(fc)
    self._table.set_feature_metadata(fc.feature_name, fc.combiner)

  def _fc_embedding_lookup(self, feature_name: str, s: FeatureSlice):
    return self._table.embedding_lookup(feature_name, s.start, s.end)

  def get_feature_columns(self):
    return self._feature_columns

  @property
  def slot(self):
    return int(self._config.name)

  @property
  def name(self):
    return self._config.name


@monolith_export
class FeatureColumn:
  """Links a FeatureSlot with an input feature.

  Args:
    feature_slot (:obj:`FeatureSlot`): The FeatureSlot this column belongs to
    feature_name (:obj:`str`): Name of the linked feature (as found in the result returned by input_fn)
    combiner: Combiner used for this feature; defaults to `reduce_sum()`

  """

  @classmethod
  def reduce_sum(cls):
    return embedding_combiners.ReduceSum()

  @classmethod
  def reduce_mean(cls):
    return embedding_combiners.ReduceMean()

  @classmethod
  def first_n(cls, seq_length: int):
    return embedding_combiners.FirstN(seq_length)

  def __init__(self,
               feature_slot: FeatureSlot,
               feature_name: str,
               combiner=None):
    self._feature_name = feature_name
    self._feature_slot = feature_slot
    self._combiner = combiner or self.reduce_sum()
    self._size_tensor = None
    # Read by get_all_embedding_slices(). Nothing in this class fills it;
    # presumably it is populated from outside - TODO confirm. Defined here so
    # the getter does not raise AttributeError before that happens.
    self._embedding_slices = {}
    feature_slot._add_feature_column(self)

  def embedding_lookup(self, s: FeatureSlice) -> tf.Tensor:
    """Returns the result of looking up feature_name in feature_slot for
    slice |s|.
    """
    assert s.feature_slot == self._feature_slot, "Slice must come from the dedicated feature slot."
    return self._feature_slot._fc_embedding_lookup(self._feature_name, s)

  def get_all_embeddings_concat(self) -> tf.Tensor:
    """
    Returns concatenated all embeddings owned by this column.
    Used in calculate gradients
    """
    return self._feature_slot._table.embedding_lookup(self._feature_name, None,
                                                      None)

  def get_all_embedding_slices(self) -> List[tf.Tensor]:
    """
    Returns all embedding slices owned by this column, i.e. the values of
    `_embedding_slices` whose key contains this column's feature name.
    Used in computing gradients
    """
    return [
        v for k, v in self._embedding_slices.items() if self._feature_name in k
    ]

  @property
  def feature_name(self):
    return self._feature_name

  @property
  def feature_slot(self) -> FeatureSlot:
    return self._feature_slot

  @property
  def combiner(self) -> embedding_combiners.Combiner:
    return self._combiner

  def get_bias(self) -> tf.Tensor:
    """For ByteDance internal use. Do not use directly."""
    bias_slice = self._feature_slot.get_bias_slice()
    return self._feature_slot._fc_embedding_lookup(self._feature_name,
                                                   bias_slice)

  def set_size_tensor(self, row_lengths: tf.Tensor):
    assert isinstance(self._combiner, embedding_combiners.FirstN
                     ), "This function is only supported in a sequence feature."
    seq_length = self._combiner.max_seq_length
    # Convert row_lengths to [B, max_seq_length] Tensor, in which
    # the first row_length elements of each row are 1, and the rest are
    # 0. This is used as the size_tensor
    batch_size = tf.size(row_lengths)  # 0-D Tensor
    boolean_mask = tf.less(
        tf.reshape(
            tf.tile(tf.range(0, seq_length), [batch_size]),
            [batch_size, -1],
        ), tf.expand_dims(row_lengths, 1))  # [B, max_seq_length] Tensor
    self._size_tensor = tf.cast(boolean_mask, tf.int32, name='size_tensor')

  def get_size_tensor(self):
    return self._size_tensor


FeatureColumnV1 = FeatureColumn

SliceConfig = namedtuple("SliceConfig", ["segment", "learning_rate_fn"])


class TableConfig(NamedTuple):
  slice_configs: List[SliceConfig]
  feature_names: Set[str]
  unmerged_slice_dims: List[int]
  hashtable_config: entry.HashTableConfig
  feature_to_combiners: Dict[str, embedding_combiners.Combiner]


class FeatureFactory(abc.ABC):
  """Used to get features in the model_fn."""

  def __init__(self):
    self.slot_to_occurrence_threshold = {}
    self.slot_to_expire_time = {}

  @abc.abstractmethod
  def create_feature_slot(self, config: FeatureSlotConfig) -> FeatureSlot:
    """Creates a feature slot by config."""

  def apply_gradients(self,
                      grads_and_vars: Iterable[Tuple[tf.Tensor, tf.Tensor]],
                      req_time: tf.Tensor = None,
                      scale: tf.Tensor = 1) -> tf.Operation:
    """
    Applies the gradients to Features owned by this factory.
    The reason why we do not make per table based apply_gradients is because of performance reason.
    In the runtime, we may do a batch lookup.

    Args:
      grads_and_vars - vars must be the all_embedding_concat from each FeatureColumn.
    """
    raise NotImplementedError(
        "apply_gradients is not supported in this factory.")


class DummyFeatureEmbTable(FeatureEmbTable):
  """It is used to collect config of table from model_fn."""

  def __init__(self, batch_size, hashtable_config):
    self._batch_size = batch_size
    self._hashtable_config = hashtable_config
    self._slices = []
    self._merged_slices = []
    self._feature_names = set()
    self._feature_to_combiner = {}
    self._dim_size = 0

  def add_feature_slice(self,
                        segment: embedding_hash_table_pb2.EntryConfig.Segment,
                        learning_rate_fn=None):
    # The learning_rate_fn can be an instance of LearningRateFunction or a float
    # value. By default, set the learning_rate_fn according to the optimizer config.
    if learning_rate_fn is None:
      opt_config = getattr(segment.opt_config,
                           segment.opt_config.WhichOneof("type"))
      if hasattr(opt_config, 'warmup_steps') and opt_config.warmup_steps > 0:
        learning_rate_fn = learning_rate_functions.PolynomialDecay(
            initial_learning_rate=0.0,
            decay_steps=opt_config.warmup_steps,
            end_learning_rate=opt_config.learning_rate)
      else:
        learning_rate_fn = opt_config.learning_rate
    self._dim_size += segment.dim_size
    self._slices.append(SliceConfig(segment, learning_rate_fn))

  def embedding_lookup(self, feature_name: str, start: int,
                       end: int) -> tf.Tensor:
    if start is None and end is None:
      # This is the special case for gradients.
      start = 0
      end = self._dim_size
    # TODO(leqi.zou): Maybe we should add a dict here to make sure for the
    # same look up we should return same result.
    emb_ph = tf.compat.v1.placeholder(tf.float32,
                                      shape=[self._batch_size, end - start])
    key = tf.compat.v1.ragged.placeholder(tf.int64, 1, [])
    combiner = self._feature_to_combiner[feature_name]
    combined = combiner.combine(
        key,
        emb_ph,
        name=f'{combiner.__class__.__name__}_{feature_name}_{start}_{end}')
    if self._batch_size:
      # Pin the leading dimension to the configured batch size.
      shape = combined.shape.as_list()
      shape[0] = self._batch_size
      combined = tf.reshape(combined, shape)
    return combined

  def set_feature_metadata(self, feature_name: str,
                           combiner: embedding_combiners.Combiner):
    self._feature_names.add(feature_name)
    self._feature_to_combiner[feature_name] = combiner

  def get_table_config(self) -> TableConfig:
    """Returns merged slices of FeatureEmbTable"""
    self._merged_slices = self._merge_slices()
    # Note(youlong.cheng): This is mainly for tf.split after pooling embedding.
    # The alternative way uses strided_slice causes duplicated backward
    # calcualtion and unncessary memory write.
    unmerged_slice_dims = [config.segment.dim_size for config in self._slices]
    return TableConfig(self._merged_slices, list(self._feature_names),
                       unmerged_slice_dims, self._hashtable_config,
                       self._feature_to_combiner)

  def get_feature_names(self):
    return self._feature_names

  def _merge_slices(self):
    """Combines the slices which only differ in dim_size."""
    merged = []
    # Using deepcopy to prevent modifing the proto in self._slices.
class DummyFeatureFactory(FeatureFactory):
  """Factory that records table configuration instead of training anything."""

  def __init__(self, batch_size):
    super().__init__()
    # Table name -> DummyFeatureEmbTable, one entry per created slot.
    self._tables = {}
    self._batch_size = batch_size

  def create_feature_slot(self, config: FeatureSlotConfig):
    """Registers a config-collecting table for `config` and wraps it in a slot.

    When `config.slot_id` is set, the slot's occurrence threshold and expire
    time are also recorded on this factory; otherwise a warning is logged.

    Raises:
      NameError: if a table named `config.name` was already created.
    """
    table_name = config.name
    if table_name in self._tables:
      raise NameError(
          "Duplicate names for the table. Name: {}".format(table_name))

    table = DummyFeatureEmbTable(self._batch_size, config.hashtable_config)
    self._tables[table_name] = table

    slot_id = config.slot_id
    if slot_id is None:
      logging.warning(
          "feature[{}] slot is None. pls check feature_list.conf".format(
              table_name))
    else:
      self.slot_to_occurrence_threshold[slot_id] = config.occurrence_threshold
      self.slot_to_expire_time[slot_id] = config.expire_time
    return FeatureSlot(table, config)

  def apply_gradients(self, *args, **kwargs) -> tf.Operation:
    """Ignores all arguments and returns a no-op."""
    return tf.no_op()

  def get_table_name_to_table_config(self) -> Dict[str, TableConfig]:
    """Returns the collected TableConfig of every table, keyed by table name.

    Raises:
      RuntimeError: if any table ended up without a single slice.
    """
    collected = {}
    for table_name, table in self._tables.items():
      table_config = table.get_table_config()
      if len(table_config.slice_configs) == 0:
        raise RuntimeError(f'{table_name} has no slice, pls. check!')
      collected[table_name] = table_config
    return collected
def create_embedding_slices(
    name_to_embeddings: Dict[str, tf.Tensor],
    name_to_embedding_ids: Dict[str, tf.RaggedTensor],
    feature_to_combiner: Dict[str, embedding_combiners.Combiner],
    feature_to_unmerged_slice_dims: Dict[str, List[int]]
) -> Dict[str, tf.Tensor]:
  """Combines every feature's embeddings and splits them into slice tensors.

  Features whose combiner is a ReduceSum are queued on a
  _FeatureFactoryFusionHelper and reduced/split together afterwards; every
  other combiner is applied right away and followed by a plain tf.split.

  Args:
    name_to_embeddings: feature name -> embedding tensor.
    name_to_embedding_ids: feature name -> ragged ids of that feature.
    feature_to_combiner: feature name -> combiner to apply.
    feature_to_unmerged_slice_dims: feature name -> width of each slice.

  Returns:
    A dict keyed by _FEATURE_STRAT_END_KEY.format(name, start, end) whose
    values are the slice tensors covering columns [start, end) of the feature.
  """
  fusion_helper = _FeatureFactoryFusionHelper()
  slices_by_feature = {}

  for name, embeddings in name_to_embeddings.items():
    ragged_ids = name_to_embedding_ids[name]
    combiner = feature_to_combiner[name]
    if not isinstance(combiner, embedding_combiners.ReduceSum):
      # Generic combiner: combine now, then split along the last axis.
      combined_emb = combiner.combine(
          ragged_ids,
          embeddings,
          name=f'{combiner.__class__.__name__}_{name}_vv')
      with device_utils.maybe_device_if_allowed('/device:GPU:0'):
        slices_by_feature[name] = tf.split(
            combined_emb, feature_to_unmerged_slice_dims[name], axis=-1)
      continue
    # ReduceSum: defer to the helper so that the reduce and the split can be
    # fused. It later yields a list of tensors of shape [None, split_dim[i]],
    # where None stands for the batch size.
    fusion_helper.append(
        name,
        ragged_ids,
        embeddings,  # to combiner
        feature_to_unmerged_slice_dims[name])  # to split

  with device_utils.maybe_device_if_allowed('/device:GPU:0'):
    # In a long term, this optimization should be on graph-transform level at
    # runtime.
    # NOTE(review): is_exporting() is consulted once here; this assumes it has
    # no side effects and is stable within the call.
    if is_exporting():
      reduce_fn = fusion_helper.reduce_and_split
    elif device_utils.within_placement_context_of("GPU"):
      fusion_level = int(
          os.getenv("MONOLITH_GPU_FEATURE_FACTORY_FUSION_LEVEL", '1'))
      if fusion_level == 1:
        reduce_fn = fusion_helper.fused_reduce_then_split
      else:
        reduce_fn = fusion_helper.reduce_and_split
    else:
      reduce_fn = fusion_helper.fused_reduce_and_split
    slices_by_feature.update(reduce_fn())

  # Publish each slice under its (name, start, end) key for later lookup.
  embedding_slices = {}
  for name, slices in slices_by_feature.items():
    offset = 0
    for idx, dim in enumerate(feature_to_unmerged_slice_dims[name]):
      upper = offset + dim
      embedding_slices[_FEATURE_STRAT_END_KEY.format(name, offset,
                                                     upper)] = slices[idx]
      offset = upper
  return embedding_slices
class EmbeddingLayoutFactory(object):
  """Feature factory backed by a hash table and pre-built layout embeddings.

  Slots created here carry an EmbeddingLayoutFakeTable. Layout tensors are
  exposed through `get_layout` / `flattened_layout`, and gradients are
  forwarded to `hash_table.apply_gradients`.
  """

  def __init__(self,
               hash_table: 'PartitionedHashTable',
               layout_embeddings: Dict[str, Union[tf.Tensor, List[tf.Tensor]]],
               auxiliary_bundle: Dict[str, tf.Tensor] = None,
               async_function_mgr: prefetch_queue.AsyncFunctionMgr = None,
               async_push: bool = False):
    """Stores the collaborators; nothing is built at construction time.

    Args:
      hash_table: object providing `apply_gradients` and `flatten_layout`.
      layout_embeddings: layout name -> tensor or list of tensors.
      auxiliary_bundle: optional dict forwarded to the hash table; its
        "req_time" entry is the fallback request time in `apply_gradients`.
      async_function_mgr: forwarded unchanged to the hash table.
      async_push: forwarded unchanged to the hash table.
    """
    self.hash_table = hash_table
    self.layout_embeddings = layout_embeddings
    self.auxiliary_bundle = auxiliary_bundle
    self._async_function_mgr = async_function_mgr
    self._async_push = async_push

  def create_feature_slot(self, config: FeatureSlotConfig) -> FeatureSlot:
    """Returns a FeatureSlot for `config` backed by a fake table."""
    table = EmbeddingLayoutFakeTable()
    return FeatureSlot(table, config)

  def apply_gradients(self,
                      grads_and_vars: Iterable[Tuple[tf.Tensor, tf.Tensor]],
                      req_time: tf.Tensor = None,
                      grad_scale: tf.Tensor = None):
    """Forwards layout gradients to the hash table.

    Args:
      grads_and_vars: (gradient, layout tensor) pairs.
      req_time: request time; when None, the "req_time" entry of the
        auxiliary bundle is used if a bundle was provided.
      grad_scale: forwarded unchanged to the hash table.

    Returns:
      Whatever `hash_table.apply_gradients` returns.
    """
    # Use an explicit None check: the truth value of a tf.Tensor cannot be
    # evaluated while a graph is being built, and auxiliary_bundle defaults
    # to None so it may not be dereferenced unconditionally.
    if req_time is None and self.auxiliary_bundle is not None:
      req_time = self.auxiliary_bundle.get("req_time")
    return self.hash_table.apply_gradients(
        layout_grads_and_vars=grads_and_vars,
        global_step=tf.compat.v1.train.get_or_create_global_step(),
        req_time=req_time,
        auxiliary_bundle=self.auxiliary_bundle,
        async_function_mgr=self._async_function_mgr,
        async_push=self._async_push,
        grad_scale=grad_scale,
    )

  def get_layout(self, layout: str) -> Union[tf.Tensor, List[tf.Tensor]]:
    """Returns the embeddings registered under `layout`, which must exist."""
    assert layout in self.layout_embeddings
    return self.layout_embeddings[layout]

  def flattened_layout(self) -> List[tf.Tensor]:
    """Returns the layout embeddings flattened by the hash table."""
    return self.hash_table.flatten_layout(self.layout_embeddings)
def _default_learning_rate_fn():
  """Builds a fresh PolynomialDecay schedule used as a test learning rate."""
  schedule_kwargs = dict(initial_learning_rate=0.01,
                         decay_steps=20,
                         end_learning_rate=0.05)
  return learning_rate_functions.PolynomialDecay(**schedule_kwargs)
entry1 = embedding_hash_table_pb2.EntryConfig.Segment() text_format.Parse( "dim_size: 5 opt_config { adagrad { warmup_steps: 10 } } ", entry1) table.add_feature_slice(deepcopy(entry1)) entry2 = embedding_hash_table_pb2.EntryConfig.Segment() text_format.Parse("dim_size: 2 opt_config { sgd {} }", entry2) table.add_feature_slice(deepcopy(entry2), learning_rate_fn=_default_learning_rate_fn()) entry3 = embedding_hash_table_pb2.EntryConfig.Segment() text_format.Parse("dim_size: 2 opt_config { sgd {} }", entry3) table.add_feature_slice(deepcopy(entry3), learning_rate_fn=_default_learning_rate_fn()) table.add_feature_slice(deepcopy(entry3)) table.set_feature_metadata("feature1", embedding_combiners.ReduceSum()) table.embedding_lookup("feature1", 0, 2) config = table.get_table_config() slices = config.slice_configs self.assertEqual(len(slices), 3) self.assertEqual(slices[0].segment.SerializeToString(), entry1.SerializeToString()) self.assertIsInstance(slices[0].learning_rate_fn, learning_rate_functions.LearningRateFunction) merged_entry = embedding_hash_table_pb2.EntryConfig.Segment() text_format.Parse("dim_size: 4 opt_config { sgd {} }", merged_entry) self.assertEqual(slices[1].segment.SerializeToString(), merged_entry.SerializeToString()) self.assertAllEqual(config.feature_names, ["feature1"]) def test_factory(self): factory = feature.DummyFeatureFactory(5) slot_config = feature.FeatureSlotConfig(name="table_name") slot = factory.create_feature_slot(slot_config) s = slot.add_feature_slice(5) fc1 = feature.FeatureColumnV1(slot, "feature1") fc1.embedding_lookup(s) fc2 = feature.FeatureColumnV1(slot, "feature2") fc2.embedding_lookup(s) table_name_to_config = factory.get_table_name_to_table_config() self.assertTrue("table_name" in table_name_to_config) table_config = table_name_to_config["table_name"] self.assertSetEqual(set(table_config.feature_names), set(["feature1", "feature2"])) self.assertEqual(table_config.slice_configs[0].segment.dim_size, 5) def 
test_factory_with_seq_features(self): factory = feature.DummyFeatureFactory(5) slot_config = feature.FeatureSlotConfig(name="table_name") slot = factory.create_feature_slot(slot_config) s = slot.add_feature_slice(5) fc1 = feature.FeatureColumnV1(slot, "feature1", combiner=embedding_combiners.FirstN(5)) fc1.embedding_lookup(s) fc2 = feature.FeatureColumnV1(slot, "feature2", combiner=embedding_combiners.FirstN(10)) fc2.embedding_lookup(s) table_name_to_config = factory.get_table_name_to_table_config() self.assertTrue("table_name" in table_name_to_config) table_config = table_name_to_config["table_name"] self.assertSetEqual(set(table_config.feature_names), set(["feature1", "feature2"])) self.assertDictEqual(table_config.feature_to_combiners, { "feature1": fc1.combiner, "feature2": fc2.combiner }) self.assertEqual(table_config.slice_configs[0].segment.dim_size, 5) def test_factory_with_slot_occurrence_threshold(self): factory = feature.DummyFeatureFactory(5) slot_config_1 = feature.FeatureSlotConfig(name="table_name_1", slot_id=1, occurrence_threshold=3) slot_1 = factory.create_feature_slot(slot_config_1) s_1 = slot_1.add_feature_slice(5) fc1 = feature.FeatureColumnV1(slot_1, "feature1") fc1.embedding_lookup(s_1) slot_config_2 = feature.FeatureSlotConfig(name="table_name_2", slot_id=2, occurrence_threshold=7) slot_2 = factory.create_feature_slot(slot_config_2) s_2 = slot_2.add_feature_slice(5) fc2 = feature.FeatureColumnV1(slot_2, "feature2") fc2.embedding_lookup(s_2) self.assertEqual(factory.slot_to_occurrence_threshold[1], 3) self.assertEqual(factory.slot_to_occurrence_threshold[2], 7) def test_factory_with_applying_gradients(self): factory = feature.DummyFeatureFactory(5) slot_config = feature.FeatureSlotConfig(name="table") slot = factory.create_feature_slot(slot_config) s = slot.add_feature_slice(1) fc = feature.FeatureColumnV1(slot, "feature1") concat_embedding = fc.get_all_embeddings_concat() factory.apply_gradients([(tf.constant([[0.0] * 2] * 5), 
class EmbeddingTest(tf.test.TestCase):
  """Tests create_embedding_slices together with FeatureFactoryFromEmbeddings.

  Every test uses a single feature "feature1" whose two-column embeddings are
  split into two slices of width 1.
  """

  def _build_column(self, embeddings, embedding_ids, combiner, num_slices=1):
    """Creates the feature column for "feature1" plus `num_slices` slices.

    Returns:
      (feature column, list of feature slices) in creation order.
    """
    slices = feature.create_embedding_slices(embeddings, embedding_ids,
                                             {"feature1": combiner},
                                             {"feature1": [1, 1]})
    factory = feature.FeatureFactoryFromEmbeddings(embeddings, slices)
    slot_config = feature.FeatureSlotConfig(name="table_name")
    slot = factory.create_feature_slot(slot_config)
    feature_slices = [slot.add_feature_slice(1) for _ in range(num_slices)]
    fc = feature.FeatureColumnV1(slot, "feature1")
    return fc, feature_slices

  def _run_lookup(self, fc, feature_slice):
    """Looks up `feature_slice` on `fc` and evaluates it in a session."""
    tensor = fc.embedding_lookup(feature_slice)
    with self.session() as sess:
      return sess.run(tensor)

  def test_factory(self):
    embeddings = {"feature1": tf.constant([[1, 4], [2, 3]], dtype=tf.float32)}
    embedding_ids = {
        "feature1": tf.RaggedTensor.from_row_splits([1, 2], [0, 1, 2])
    }
    fc, (s,) = self._build_column(embeddings, embedding_ids,
                                  embedding_combiners.ReduceSum())
    # One id per row, so the first slice is just the first column.
    self.assertAllEqual(self._run_lookup(fc, s), [[1], [2]])

  def test_factory_with_seq_features(self):
    embeddings = {
        "feature1":
            tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=tf.float32)
    }
    embedding_ids = {
        "feature1": tf.RaggedTensor.from_row_splits([1, 2, 3, 4], [0, 2, 4])
    }
    fc, (s,) = self._build_column(embeddings, embedding_ids,
                                  embedding_combiners.FirstN(2))
    self.assertAllEqual(self._run_lookup(fc, s), [[[1], [3]], [[5], [7]]])

  def test_fused_factory(self):
    embeddings = {
        "feature1": tf.constant([[1, 2], [2, 3], [3, 5]], dtype=tf.float32)
    }
    embedding_ids = {
        "feature1": tf.RaggedTensor.from_row_splits([1, 2, 3], [0, 1, 1, 3])
    }
    fc, (s, s2) = self._build_column(embeddings,
                                     embedding_ids,
                                     embedding_combiners.ReduceSum(),
                                     num_slices=2)
    # The middle row has no ids and is expected to reduce to zero.
    self.assertAllClose(self._run_lookup(fc, s), [[1], [0], [5]])
    self.assertAllClose(self._run_lookup(fc, s2), [[2], [0], [8]])

  def test_fused_factory_with_seq_features_larger_than_max_seq_length(self):
    # For rows with bigger number of embeddings than max_seq_length,
    # discard the extra embedding elements.
    embeddings = {
        "feature1":
            tf.constant([[1, 2], [2, 3], [3, 5], [10, 11]], dtype=tf.float32)
    }
    embedding_ids = {
        "feature1": tf.RaggedTensor.from_row_splits([1, 2, 3, 4], [0, 1, 1, 4])
    }
    fc, (s, s2) = self._build_column(embeddings,
                                     embedding_ids,
                                     embedding_combiners.FirstN(2),
                                     num_slices=2)
    self.assertAllEqual(self._run_lookup(fc, s),
                        [[[1], [0]], [[0], [0]], [[2], [3]]])
    self.assertAllEqual(self._run_lookup(fc, s2),
                        [[[2], [0]], [[0], [0]], [[3], [5]]])
================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum import os from typing import Iterable, Dict from absl import logging import tensorflow as tf from tensorflow.python.training import training_util from monolith.native_training import clip_ops from monolith.native_training.distribution_ops import gen_distribution_ops from monolith.native_training import device_utils from monolith.native_training import feature from monolith.native_training.native_task import NativeContext enable_hvd = os.getenv("MONOLITH_WITH_HOROVOD") enable_bps = int(os.getenv("MONOLITH_WITH_BYTEPS", '0')) enable_bps_allreduce = int(os.getenv("MONOLITH_WITH_BYTEPS_ALLREDUCE", '1')) enable_allreduce_fusion = str( os.getenv("MONOLITH_WITH_ALLREDUCE_FUSION", 'none')) enable_allreduce_fp16 = int(os.getenv("MONOLITH_WITH_ALLREDUCE_FP16", '0')) # for hvd skip_allreduce = int(os.getenv("MONOLITH_SKIP_ALLREDUCE", '0')) # enable (limited) fusion functionality for byteccl where bias tensors are fused into one # tensor before performing allreduce. 
def allreduce_cond(grads, scale = 1):
  """Averages gradients across workers, passing `None` entries through.

  The backend (BytePS or Horovod), the compression and the fusion strategy
  are selected by the module-level `enable_*` settings.

  Args:
    grads: sequence of gradient tensors; entries may be None.
    scale: factor for the gradients. The unfused and "grouped" paths multiply
      each gradient by it; the "one" path hands it to the fused concat op.

  Returns:
    A list aligned with `grads` holding the reduced gradients, with None kept
    where the input was None. If every entry is None, `grads` itself is
    returned.

  Raises:
    RuntimeError: if the fusion strategy is 'multi', which is no longer
      supported.
  """
  global control_ops

  use_bps = enable_bps and enable_bps_allreduce
  if use_bps:
    import byteps.tensorflow as bps
    from byteps.tensorflow.compression import FP16Compressor as BPSFP16Compressor, NoneCompressor as BPSNoneCompressor
    if enable_allreduce_fp16:
      compression = BPSFP16Compressor
    else:
      compression = BPSNoneCompressor
  elif enable_allreduce_fp16:
    compression = FP16Compressor
  else:
    compression = NoneCompressor

  # Positions of the real gradients, used to scatter reduced values back.
  present_idx = [pos for pos, grad in enumerate(grads) if grad is not None]
  if not present_idx:
    return grads
  grads_wo_none = [grads[pos] for pos in present_idx]
  results = [None] * len(grads)

  def map_to_output(reduced):
    """Writes `reduced` back to the slots the non-None gradients came from."""
    for r_idx, pos in enumerate(present_idx):
      results[pos] = reduced[r_idx]
    assert len(present_idx) == len(reduced), "Something is wrong"
    return results

  if enable_allreduce_fusion == 'one':
    # note: one allreduce fusion does not yet support CPU
    # note: concat -> allreduce -> split is noticeably faster than
    # hvd.grouped_allreduce
    grads_fused = gen_distribution_ops.monolith_aligned_flat_concat(
        grads_wo_none, scale)
    control_ops = [grads_fused]
    if use_bps:
      grads_fused_avg = bps.push_pull(grads_fused,
                                      average=True,
                                      compression=compression,
                                      name="bps_ar_fuse_one")
    else:
      grads_fused_avg = hvd.allreduce(grads_fused,
                                      op=hvd.Average,
                                      compression=compression,
                                      name="hvd_ar_fuse_one")
    split_back = gen_distribution_ops.monolith_aligned_flat_split(
        grads_wo_none, grads_fused_avg)
    return map_to_output(split_back)

  if enable_allreduce_fusion == "grouped":
    # Grouped allreduce is only available through Horovod.
    assert not use_bps
    scaled = [grad * scale for grad in grads_wo_none]
    return map_to_output(
        hvd.grouped_allreduce(scaled, op=hvd.Average, compression=compression))

  if enable_allreduce_fusion == 'multi':
    raise RuntimeError(
        "Support for multi is dropped. Please use 'one' as the fusion strategy")

  # No fusion: reduce every gradient on its own.
  logging.info('Enabled allreduce without fusion using Average Op!')
  if use_bps:
    reduce_one = lambda grad: bps.push_pull(
        grad * scale, average=True, compression=compression)
  else:
    reduce_one = lambda grad: hvd.allreduce(
        grad * scale, op=hvd.Average, compression=compression)
  return [grad if grad is None else reduce_one(grad) for grad in grads]
grads_and_vars_summary - when True, will print summary of grads and vars dense_weight_decay - dense weight decay, l2 norm """ with device_utils.maybe_device_if_allowed('/device:GPU:0'): assert isinstance(var_opt, tf.compat.v1.train.Optimizer) feature_columns = list(feature_columns) if is_fused_layout: layout_factory: feature.EmbeddingLayoutFactory = ctx.layout_factory all_embeddings = layout_factory.flattened_layout() else: all_embeddings = [fc.get_all_embeddings_concat() for fc in feature_columns] variables = tf.compat.v1.trainable_variables() grads_and_vars = var_opt.compute_gradients(loss, variables + all_embeddings, colocate_gradients_with_ops=True) # Some variables are created but unused and we need to filter them out. if is_fused_layout: grads_and_vars_tmp = grads_and_vars[:len(variables)] for gv in grads_and_vars[len(variables):]: grads_and_vars_tmp.append((gv[0] if gv[0] is not None else tf.zeros_like(gv[1]), gv[1])) grads_and_vars = grads_and_vars_tmp dense_gvs = [gv for gv in grads_and_vars[:len(variables)] if gv[0] is not None] sparse_gvs = [gv for gv in grads_and_vars[len(variables):] if gv[0] is not None] if is_fused_layout: feature_columns = [] else: feature_columns = [ fc for fc, gv in zip(feature_columns, grads_and_vars[len(variables):]) if gv[0] is not None ] variables = [gv[1] for gv in dense_gvs] all_embeddings = [gv[1] for gv in sparse_gvs] grads_and_vars = dense_gvs + sparse_gvs grads = [grad_and_var[0] for grad_and_var in grads_and_vars] # UE conditional gradient check if ue_gradient_check: grads = [] for grad_and_var in grads_and_vars: found = False for fc_name in ue_fc_names: if fc_name in grad_and_var[1].name or 'uue' in grad_and_var[1].name: grads.append( tf.where( tf.norm(features[fc_name]) > ue_euclidean_norm_threshold, grad_and_var[0], tf.zeros_like(grad_and_var[0]))) logging.info("UE Vars: {}".format(grad_and_var[1].name)) found = True break if not found: grads.append(grad_and_var[0]) # TODO(zouxuan): this is a quick workaround to fix 
the empty grads issue. if len(grads) == 0: return tf.no_op() dense_grads = grads[:len(variables)] sparse_grads = grads[len(variables):] if dense_reduce_mean: dense_grads = [g / batch_size for g in dense_grads] global_dense_norm = None global_sparse_norm = None norm_fn = clip_ops._global_norm if device_utils.within_placement_context_of( "GPU") else tf.linalg.global_norm if clip_type == GradClipType.ClipByGlobalNorm and clip_norm is not None: global_dense_norm = norm_fn(grads) global_sparse_norm = global_dense_norm # use the same norm for sparse and dense sparse_clip_norm = sparse_clip_norm or clip_norm if sparse_norm_warmup_steps is not None: sparse_clip_norm = _gen_norm_warmup(sparse_clip_norm, global_step, sparse_norm_warmup_steps) logging.info('sparse_norm_warmup_steps: %s', sparse_norm_warmup_steps) with tf.device('/device:CPU:0'): tf.compat.v1.summary.scalar("global_gradient_norm", global_dense_norm) elif clip_type == GradClipType.ClipByValue and clip_norm is not None: clipped_grads = [ tf.clip_by_value(g, clip_value_min=-clip_norm, clip_value_max=clip_norm) for g in grads ] elif clip_type == GradClipType.ClipByNorm and clip_norm is not None: clipped_grads = [tf.clip_by_norm(g, clip_norm) for g in grads] elif clip_type == GradClipType.ClipByDenseAndSparse: global_dense_norm = norm_fn(dense_grads) if sparse_clip_norm is not None: global_sparse_norm = norm_fn(sparse_grads) with tf.device('/device:CPU:0'): tf.compat.v1.summary.scalar("global_gradient_norm/dense", global_dense_norm) if global_sparse_norm is not None: tf.compat.v1.summary.scalar("global_gradient_norm/sparse", global_sparse_norm) else: clipped_grads = grads if skip_allreduce: use_allreduce = False # Conditionally perform clip by global norm. # If we're using synchronous (allreduce=True) distributed GPU training, # we defer clip and only calculate a scale factor. 
The scaling is fused # with later concat/gather kernels for better performance def cond_defer_clip(norm, clip_norm, grads): defer_clip = device_utils.within_placement_context_of("GPU") and \ use_allreduce and not grads_and_vars_summary and not is_fused_layout scale = 1 if norm is not None: if not defer_clip: grads, _ = clip_ops.clip_by_global_norm(grads, clip_norm, use_norm=norm) else: scale = tf.minimum(clip_norm / norm, 1) return grads, scale if clip_type in (GradClipType.ClipByGlobalNorm, GradClipType.ClipByDenseAndSparse): dense_clipped_grads, dense_scale = cond_defer_clip(global_dense_norm, clip_norm, dense_grads) sparse_clipped_grads, sparse_scale = cond_defer_clip(global_sparse_norm, sparse_clip_norm, sparse_grads) else: dense_scale = 1 sparse_scale = 1 dense_clipped_grads = clipped_grads[:len(variables)] sparse_clipped_grads = clipped_grads[len(variables):] if grads_and_vars_summary: with tf.device("/device:CPU:0"): if len(dense_clipped_grads) > 0: tf.compat.v1.summary.histogram( "variable_gradient", tf.concat( [tf.reshape(grad, [-1]) for grad in dense_clipped_grads], 0)) dense_grad_sizes = [] for grad, var in zip(dense_clipped_grads, variables): summary_var_name = var.name.replace(":", "_") tf.compat.v1.summary.scalar( "trainable_variable_norm/{}".format(summary_var_name), tf.norm(var)) tf.compat.v1.summary.histogram( "trainable_variable/{}".format(summary_var_name), var) tf.compat.v1.summary.scalar( "gradient_norm/{}".format(summary_var_name), tf.norm(grad)) tf.compat.v1.summary.histogram( "gradient/{}".format(summary_var_name), grad) dense_grad_sizes.append(tf.size(grad)) tf.compat.v1.summary.histogram("dense_grad_sizes", dense_grad_sizes) tf.compat.v1.summary.scalar("dense_grad_total_size", tf.reduce_sum(dense_grad_sizes)) tf.compat.v1.summary.scalar("dense_grad_total_num", len(dense_grad_sizes)) for i, fc in enumerate(feature_columns): tf.compat.v1.summary.histogram("{}_gradient".format(fc.feature_name), sparse_clipped_grads[i]) 
logging.info('use_allreduce: %s', use_allreduce) dense_clipped_grads = allreduce_cond( dense_clipped_grads, dense_scale ) if use_allreduce and enable_hvd else dense_clipped_grads if dense_weight_decay and variables: dense_clipped_grads = [ g + dense_weight_decay * v for g, v in zip(dense_clipped_grads, variables) ] logging.info('dense_weight_decay: %s', dense_weight_decay) train_ops = [] grads_and_vars_without_optimizer = [] if variables: global dense_opt_ops for i, var in enumerate(variables): if hasattr(var, 'optimizer') and var.optimizer: train_ops.append( ctx.add_async_function(var.optimizer.apply_gradients, ([(dense_clipped_grads[i], var)],))) logging.info("var {} uses a custom optimizer: {}".format( var.name, var.optimizer)) else: grads_and_vars_without_optimizer.append((dense_clipped_grads[i], var)) train_ops.append( ctx.add_async_function(var_opt.apply_gradients, (grads_and_vars_without_optimizer,))) dense_opt_ops = train_ops.copy() with tf.device('/device:CPU:0'): train_ops.append( ctx.apply_embedding_gradients( list(zip(sparse_clipped_grads, all_embeddings)), sparse_scale)) if global_step is not None: # The control dependency here ensures that # when the StepCounterHook tries to get the global_step # from the training session at the same time of training, # the read_value should be consistent (before assign_add). 
# Also makes sure that the global step is incremented after the optimize ops, # since embedding optimizer requires this global step as input with tf.control_dependencies( train_ops + [training_util._get_or_create_global_step_read()]): train_ops.append( ctx.add_async_function(tf.compat.v1.assign_add, (global_step, 1))) return tf.group(*train_ops) def apply_gradients(ctx: NativeContext, var_opt: tf.compat.v1.train.Optimizer, loss: tf.Tensor, clip_type: GradClipType = GradClipType.ClipByGlobalNorm, clip_norm: float = None, dense_weight_decay: float = 0.0, global_step=None, use_allreduce: bool = False): layout_factory: feature.EmbeddingLayoutFactory = ctx.layout_factory variables = tf.compat.v1.trainable_variables() layout_embeddings = layout_factory.flattened_layout() grads_and_vars = var_opt.compute_gradients(loss, variables + layout_embeddings, colocate_gradients_with_ops=True) # clip grads flag = False for g, v in grads_and_vars: if g is None: flag = True logging.info(f'grad of {v} is None, maybe it not used in the graph') if flag: grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] variables = [v for (g, v) in grads_and_vars if v in variables] layout_embeddings = [ v for (g, v) in grads_and_vars if v in layout_embeddings ] assert len(grads_and_vars) == len(variables) + len(layout_embeddings) grads = [g for (g, _) in grads_and_vars] if grads and clip_norm is not None and clip_norm > 0: if clip_type == GradClipType.ClipByGlobalNorm: clipped_grads, global_g_norm = clip_ops.clip_by_global_norm( grads, clip_norm, use_norm=tf.linalg.global_norm(grads)) logging.info('clip_by_global_norm: %s', clip_norm) with tf.device('/device:CPU:0'): tf.compat.v1.summary.scalar("global_gradient_norm", global_g_norm) elif clip_type == GradClipType.ClipByNorm: clipped_grads = [tf.clip_by_norm(g, clip_norm) for g in grads] else: raise Exception(f"{clip_type} is not supported yet!") else: clipped_grads = grads train_ops = [] # dense apply_gradients if variables: 
dense_clipped_grads = clipped_grads[:len(variables)] if use_allreduce and enable_hvd: dense_clipped_grads = allreduce_cond( dense_clipped_grads) if dense_weight_decay > 0: grads_and_vars = [(g + dense_weight_decay * v, v) for g, v in zip(dense_clipped_grads, variables)] else: grads_and_vars = list(zip(dense_clipped_grads, variables)) train_ops.append( var_opt.apply_gradients(grads_and_vars, global_step=global_step)) else: with tf.control_dependencies( [training_util._get_or_create_global_step_read()]): train_ops.append(tf.compat.v1.assign_add(global_step, 1)) # sparse apply_gradients if layout_embeddings: sparse_clipped_grads = clipped_grads[len(variables):] grads_and_vars = list(zip(sparse_clipped_grads, layout_embeddings)) train_ops.append(ctx.apply_embedding_gradients(grads_and_vars)) return tf.group(*train_ops) ================================================ FILE: monolith/native_training/feature_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from unittest import mock

import tensorflow as tf
from tensorflow.python.framework import test_util
import os

# Must be set before the monolith modules below are imported.
# NOTE(review): presumably read at import time by feature_utils - confirm.
os.environ['MONOLITH_WITH_ALLREDUCE_FUSION'] = 'one'
from monolith.native_training import embedding_combiners
from monolith.native_training import feature, feature_utils
from monolith.native_training.native_task import NativeContext
from monolith.native_training import prefetch_queue


def _setup_test_embedding(is_async=False):
  """Creates a one-row embedding split into slices of size 3 and 1.

  The feature factory's `apply_gradients` is replaced by a mock that subtracts
  `grad * scale` from each embedding variable, so tests can observe the update
  directly on `emb_var`.

  Args:
    is_async: forwarded to `prefetch_queue.AsyncFunctionMgr`.

  Returns:
    A tuple `(ctx, fc, emb_var, emb)` where `emb` is the lookup of the
    size-3 slice.
  """
  emb_var = tf.Variable([[1.0, 1.0, 1.0, 1.0]], trainable=False)
  emb = {"feature1": emb_var}
  emb_id = tf.RaggedTensor.from_row_splits([111], [0, 1])
  slices = feature.create_embedding_slices(
      emb, {"feature1": emb_id}, {"feature1": embedding_combiners.ReduceSum()},
      {"feature1": [3, 1]})
  feature_factory = feature.FeatureFactoryFromEmbeddings(emb, slices)

  def apply_emb_gradients(grads_and_vars, scale=1):
    # Plain SGD step with learning rate 1, scaled by `scale`.
    return tf.group(
        [var.assign_sub(grad * scale) for grad, var in grads_and_vars])

  feature_factory.apply_gradients = mock.MagicMock(
      side_effect=apply_emb_gradients)
  ctx = NativeContext(
      feature_factory=feature_factory,
      async_function_mgr=prefetch_queue.AsyncFunctionMgr(is_async))
  slot = ctx.create_feature_slot(feature.FeatureSlotConfig(name="Slot"))
  s = slot.add_feature_slice(3)
  fc = feature.FeatureColumnV1(slot, "feature1")
  emb = fc.embedding_lookup(s)
  return ctx, fc, emb_var, emb


class FeatureUtilsTest(tf.test.TestCase):
  """Tests for feature_utils.apply_gradients_with_var_optimizer."""

  def test_apply_gradients_with_dense_optimizer(self):
    """Dense var and embedding are both updated with the clipped gradient."""
    ctx, fc, emb_var, emb = _setup_test_embedding()
    emb_loss = tf.reduce_sum(tf.reduce_sum(emb))
    var = tf.Variable(1.0)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    loss = emb_loss + var
    opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    # norm is 2, will be clipped by 1
    op = feature_utils.apply_gradients_with_var_optimizer(
        ctx, [fc],
        opt,
        loss,
        clip_norm=1.0,
        global_step=global_step,
        grads_and_vars_summary=True)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(op)
      # Every gradient is 1 and gets scaled by 0.5, so each value moves
      # from 1.0 to 0.5; the 4th embedding column is outside the slice.
      self.assertAllEqual(sess.run(var), 0.5)
      self.assertAllEqual(sess.run(emb_var), [[0.5, 0.5, 0.5, 1.0]])
      self.assertAllEqual(sess.run(global_step), 1)

  @test_util.run_gpu_only
  def test_apply_gradients_with_dense_optimizer_gpu(self):
    """Same expectations as the CPU test, built under a GPU placement."""
    # this test tests the fusion of clip_by_global_norm with later kernels
    with test_util.use_gpu():
      ctx, fc, emb_var, emb = _setup_test_embedding()
      emb_loss = tf.reduce_sum(tf.reduce_sum(emb))
      var = tf.Variable(1.0)
      global_step = tf.compat.v1.train.get_or_create_global_step()
      loss = emb_loss + var
      opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
      # norm is 2, will be clipped by 1
      op = feature_utils.apply_gradients_with_var_optimizer(
          ctx, [fc],
          opt,
          loss,
          clip_norm=1.0,
          global_step=global_step,
          grads_and_vars_summary=False,
          use_allreduce=True)
      with self.session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(op)
        self.assertAllEqual(sess.run(var), 0.5)
        self.assertAllEqual(sess.run(emb_var), [[0.5, 0.5, 0.5, 1.0]])
        self.assertAllEqual(sess.run(global_step), 1)

  def test_apply_gradients_with_dense_optimizer_post_push(self):
    """With the async manager, the dense update lags one step behind."""
    ctx, fc, emb_var, emb = _setup_test_embedding(is_async=True)
    emb_loss = tf.reduce_sum(tf.reduce_sum(emb))
    var = tf.Variable(1.0)
    opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    loss = emb_loss + var
    op = feature_utils.apply_gradients_with_var_optimizer(ctx, [fc], opt, loss)
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=ctx.async_function_mgr.hooks) as sess:
      sess.run(op)
      sess.run(op)
      sess.run(op)
      # Since it is async pushed, the push should happen twice.
      var_value, emb_var_value = sess.run([var, emb_var])
    # Run op three times will trigger two optimization
    self.assertAllEqual(var_value, -1.0)
    # But emb is not affected. Optimized by 3 times.
    self.assertAllEqual(emb_var_value, [[-2.0, -2.0, -2.0, 1.0]])

  def test_apply_gradients_without_dense_optimizer(self):
    """Only the embedding is trained; the global step still advances."""
    ctx, fc, emb_var, emb = _setup_test_embedding()
    emb_loss = tf.reduce_sum(tf.reduce_sum(emb))
    global_step = tf.compat.v1.train.get_or_create_global_step()
    loss = emb_loss
    opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    op = feature_utils.apply_gradients_with_var_optimizer(
        ctx, [fc], opt, loss, global_step=global_step)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(op)
      # No clipping here: the unit gradient moves the slice from 1.0 to 0.0.
      self.assertAllEqual(sess.run(emb_var), [[0.0, 0.0, 0.0, 1.0]])
      self.assertAllEqual(sess.run(global_step), 1)


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()


================================================
FILE: monolith/native_training/file_ops.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from monolith.native_training.runtime.ops import gen_monolith_ops

file_ops = gen_monolith_ops


class WritableFile:
  """A gfile wrapper used in the graph execution."""

  def __init__(self, filename):
    """Creates the file handle op for `filename`."""
    self._handle = file_ops.monolith_writable_file(filename)

  def append(self, content):
    """Append the content into the file.

    Args:
      content - a 0-D string tensor.

    Returns:
      The result of the `monolith_writable_file_append` op.
    """
    return file_ops.monolith_writable_file_append(self._handle, content)

  def append_entry_dump(self, item_id, bias, embedding):
    """Appends one entry built from `item_id`, `bias` and `embedding`.

    Returns:
      The result of the `monolith_entry_dump_file_append` op.
    """
    return file_ops.monolith_entry_dump_file_append(self._handle, item_id, bias,
                                                    embedding)

  def close(self):
    """Returns the op that closes the underlying file handle."""
    return file_ops.monolith_writable_file_close(self._handle)


class FileCloseHook(tf.estimator.SessionRunHook):
  """A hook that will close WritableFiles at the end of session."""

  def __init__(self, files):
    """Builds the close ops up front.

    Args:
      files: a list of `WritableFile`; any other container type fails the
        assertion.
    """
    assert isinstance(files, list)
    self._files = files
    # Ops are created at construction time, not inside `end`.
    self._close_ops = [f.close() for f in files]

  def end(self, session):
    """Runs all close ops in `session`."""
    session.run(self._close_ops)


================================================
FILE: monolith/native_training/file_ops_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import tensorflow as tf

from monolith.native_training import file_ops


class WritableFileTest(tf.test.TestCase):
  """Tests for file_ops.WritableFile and file_ops.FileCloseHook."""

  def test_basic(self):
    """Appending "1234" `times` times yields exactly that content on disk."""
    filename = os.environ["TEST_TMPDIR"] + "/test_basic/test_name"
    times = 1000

    @tf.function
    def write():
      f = file_ops.WritableFile(filename)
      for i in tf.range(times):
        f.append("1234")
      f.close()

    self.evaluate(write())
    with tf.io.gfile.GFile(filename) as f:
      self.assertAllEqual(f.read(), "1234" * times)

  def test_hook(self):
    """The file is readable once FileCloseHook has run at session end."""
    filename = os.environ["TEST_TMPDIR"] + "/test_hook/test_name"
    f = file_ops.WritableFile(filename)
    write_op = f.append("1234")
    with tf.compat.v1.train.MonitoredSession(
        hooks=[file_ops.FileCloseHook([f])]) as sess:
      sess.run(write_op)
    # Read only after the session block has exited.
    with tf.io.gfile.GFile(filename) as f:
      self.assertAllEqual(f.read(), "1234")


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()


================================================
FILE: monolith/native_training/fountain/BUILD
================================================
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
load("@rules_cc//cc:defs.bzl", "cc_library")

cc_library(
    name = "fountain_dataset_ops",
    visibility = ["//visibility:public"],
)

py_library(
    name = "fountain_lib",
    visibility = ["//visibility:public"],
)


================================================
FILE: monolith/native_training/fountain/README.md
================================================
Dummy implementation of fountain.


================================================
FILE: monolith/native_training/fused_embedding_to_layout_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from contextlib import nullcontext import logging import string import numpy as np np.random.seed(2) from random import randint import tensorflow as tf # tf.compat.v1.disable_eager_execution() tf.compat.v1.disable_v2_behavior() from collections import defaultdict from tensorflow.python.framework import test_util from monolith.native_training import distribution_ops from idl.matrix.proto.example_pb2 import ExampleBatch, Example, FeatureListType, \ SliceConfig, PoolingType, OutType, OutConfig, FeatureConfig, FeatureConfigs, TensorShape from monolith.native_training.data.parsers import parse_instances, parse_examples, parse_example_batch, \ sharding_sparse_fids, get_default_parser_ctx, ParserCtx SHARD_BIT = 0x80000000 def infer_shape(out_conf: OutConfig, out_type: OutType, max_sequence_length: int = 0): out_conf.out_type = out_type if out_type == OutType.NONE: for sc in out_conf.slice_configs: shape = out_conf.shape.add() if max_sequence_length > 0: shape.dims.extend([-1, max_sequence_length, sc.end - sc.start]) else: shape.dims.extend([-1, sc.end - sc.start]) elif out_type == OutType.CONCAT: shape = out_conf.shape.add() last_dim = 0 for sc in out_conf.slice_configs: last_dim += sc.end - sc.start if max_sequence_length > 0: shape.dims.extend([-1, max_sequence_length, last_dim]) else: shape.dims.extend([-1, last_dim]) elif out_type == OutType.STACK: shape = out_conf.shape.add() last_dim = None for sc in out_conf.slice_configs: if last_dim is None: last_dim = sc.end - sc.start else: assert last_dim == sc.end - sc.start if max_sequence_length > 0: shape.dims.extend( [-1, 
len(out_conf.slice_configs), max_sequence_length, last_dim]) else: shape.dims.extend([-1, len(out_conf.slice_configs), last_dim]) elif out_type == OutType.ADDN: shape = out_conf.shape.add() last_dim = None for sc in out_conf.slice_configs: if last_dim is None: last_dim = sc.end - sc.start else: assert last_dim == sc.end - sc.start if max_sequence_length > 0: shape.dims.extend([-1, max_sequence_length, last_dim]) else: shape.dims.extend([-1, last_dim]) else: raise ValueError('out_type error') def get_key(ln: str, sc: SliceConfig) -> str: return f"{ln}_{sc.feature_name}_{sc.start}_{sc.end}" def pooling(pooling_type, in_data, max_length): if max_length and len(in_data) > max_length: data = in_data[0:max_length] else: data = in_data if pooling_type == PoolingType.SUM: result = np.zeros_like(data[0]) for d in data: result += d return result if pooling_type == PoolingType.MEAN: result = np.zeros_like(data[0]) for d in data: result += d result /= len(data) return result else: last_dim = int(data[0].shape[-1]) result = np.zeros(shape=(max_length, last_dim), dtype=np.float32) for i, d in enumerate(data): result[i, :] = d if i < max_length: result[i, :] = d else: break return result class FusedEmbeddingToLayoutTest(tf.test.TestCase): def get_pre_output_offset(self, shard, f_cfg): return f_cfg["pre_output_index"] + shard * f_cfg[ "table_feature_count"] + f_cfg["feature_in_table_index"] def get_feature_cfg(self, raw_feature_cfgs, ps_num): feature_cfg = defaultdict(dict) table_cfg = defaultdict(dict) for feature_name, cfg in raw_feature_cfgs.feature_configs.items(): dim = 0 for slice_dim in cfg.slice_dims: dim += slice_dim feature_cfg[feature_name] = { "feature_name": feature_name, "feature_index": -1, "table_name": cfg.table, "table_index": -1, "feature_in_table_index": -1, "table_feature_count": 0, "pre_output_index": 0, "dim_sum": dim, } if cfg.table not in table_cfg: table_cfg[cfg.table] = { "table_name": cfg.table, "feature_list": [], "table_index": -1, "feature_count": 0, 
}
    # Tables and features are indexed in sorted-name order.
    table_name_sort = sorted(table_cfg.keys())
    for idx, name in enumerate(table_name_sort):
      table_cfg[name]["table_index"] = idx
    feature_name_sort = sorted(feature_cfg.keys())
    for idx, name in enumerate(feature_name_sort):
      f_cfg = feature_cfg[name]
      t_cfg = table_cfg[f_cfg["table_name"]]
      f_cfg["feature_index"] = idx
      f_cfg["table_index"] = t_cfg["table_index"]
      f_cfg["feature_in_table_index"] = len(t_cfg["feature_list"])
      t_cfg["feature_list"].append(name)
    # Each table reserves `feature_count * ps_num` consecutive slots
    # (at least `ps_num`), starting at `pre_index`.
    pre_index = 0
    for idx, name in enumerate(table_name_sort):
      t_cfg = table_cfg[name]
      t_cfg["feature_count"] = len(t_cfg["feature_list"])
      for feature_name in t_cfg["feature_list"]:
        f_cfg = feature_cfg[feature_name]
        f_cfg["pre_output_index"] = pre_index
        f_cfg["table_feature_count"] = t_cfg["feature_count"]
      pre_index += max(t_cfg["feature_count"], 1) * ps_num
    return feature_cfg, table_cfg, feature_name_sort, table_name_sort

  def test_fused_embedding_to_layout(self,
                                     shard_op_version=None,
                                     op_version=2,
                                     parallel_flag=1,
                                     use_gpu=False):
    """Compares fused_embedding_to_layout against a NumPy reference.

    Args:
      shard_op_version: when truthy, the op inputs are produced by
        `parse_example_batch` + `sharding_sparse_fids` with this version;
        otherwise the hand-built Python lists are used.
      op_version: `version` passed to `fused_embedding_to_layout`.
      parallel_flag: forwarded to `fused_embedding_to_layout`.
      use_gpu: build the op under `test_util.use_gpu()`; requires
        `op_version >= 3`.
    """
    batch_size = 256
    num_ps = 5
    slot_count = 200
    slot_table_split = [50, 100
                       ]  #slot split for [table_one, table_two, table_three]
    max_sequence_length = 3

    feature_cfgs = FeatureConfigs()
    bias = OutConfig()
    vec = OutConfig()
    ffm1 = OutConfig()
    ffm2 = OutConfig()
    firstN = OutConfig()
    for slot in range(1, slot_count):
      feature_name = f"fc_slot_{slot}"
      fconf = FeatureConfig()
      if slot >= slot_table_split[1]:
        table_name = "table_one"  #table_three, but now test for table with different dim
        slice_dims = [1, 4, 16]
        sequence_length = max_sequence_length
        pooling_type = PoolingType.FIRSTN
        slice_config = firstN.slice_configs.add()
        slice_config.feature_name = feature_name
        slice_config.start = 1
        slice_config.end = 21
      else:
        sequence_length = 0
        if slot < slot_table_split[0]:
          table_name = "table_one"
          slice_dims = [1, 4, 8]
          pooling_type = PoolingType.SUM
          slice_config = ffm1.slice_configs.add()
          slice_config.feature_name = feature_name
          slice_config.start = 5
          slice_config.end = 13
        else:
          table_name = "table_two"
          slice_dims = [1, 4, 16]
          pooling_type = PoolingType.MEAN
          slice_config = ffm2.slice_configs.add()
          slice_config.feature_name = feature_name
          slice_config.start = 5
          slice_config.end = 21
        # Non-sequence features also contribute to the bias and vec layouts.
        slice_config = bias.slice_configs.add()
        slice_config.feature_name = feature_name
        slice_config.start = 0
        slice_config.end = 1
        slice_config = vec.slice_configs.add()
        slice_config.feature_name = feature_name
        slice_config.start = 1
        slice_config.end = 5
      fconf.table = table_name
      fconf.slice_dims.extend(slice_dims)
      fconf.max_sequence_length = sequence_length
      fconf.pooling_type = pooling_type
      feature_cfgs.feature_configs[feature_name].CopyFrom(fconf)

    infer_shape(bias, OutType.ADDN)
    feature_cfgs.out_configs['bias'].CopyFrom(bias)
    infer_shape(vec, OutType.CONCAT)
    feature_cfgs.out_configs['vec'].CopyFrom(vec)
    infer_shape(ffm1, OutType.STACK)
    feature_cfgs.out_configs['ffm1'].CopyFrom(ffm1)
    infer_shape(ffm2, OutType.NONE)
    feature_cfgs.out_configs['ffm2'].CopyFrom(ffm2)
    infer_shape(firstN, OutType.NONE, max_sequence_length)
    feature_cfgs.out_configs['firstN'].CopyFrom(firstN)
    logging.info(f"feature_cfgs : {feature_cfgs} ")

    feature_cfg, table_cfg, feature_name_sort, table_name_sort = self.get_feature_cfg(
        feature_cfgs, num_ps)

    fid_offset_list = list()
    feature_offset_list = [0]
    nfl_offset_list = [0]
    nfl_offset_list2 = [0]
    sparse_features = ExampleBatch(batch_size=batch_size)
    std_features = defaultdict(list)
    fids_dict = {}
    fid_row_split_list = [[0] for _ in range(num_ps * len(table_name_sort))]
    for feature_name in feature_name_sort:
      slot = int(feature_name.split("fc_slot_")[-1])
      named_feature_list = sparse_features.named_feature_list.add()
      named_feature_list.id = slot
      named_feature_list.name = feature_name
      # Even slots are SHARED (a single feature entry is added for them),
      # odd slots are INDIVIDUAL (one entry per batch row).
      is_shared = True if slot % 2 == 0 else False
      logging.info(f"show shared {named_feature_list.name} {is_shared}")
      named_feature_list.type = FeatureListType.SHARED if is_shared else FeatureListType.INDIVIDUAL

      f_cfg = feature_cfg[feature_name]
      table_name = f_cfg["table_name"]
      t_cfg = table_cfg[table_name]
      table_index = t_cfg["table_index"]
      dim_sum = f_cfg["dim_sum"]
      if table_name not in fids_dict:
        fids_dict[table_name] = defaultdict(list)
      index2 = [0] * num_ps * len(table_cfg)

      def make_fids(feature):
        # Fills `feature` with random fids and records, for every fid, its
        # packed offset: flat feature/shard index in the high bits, running
        # per-shard counter in the low 32 bits.
        std_features[named_feature_list.name].append(feature)
        '''
        fids = list(
            set([(slot << 48) + randint(100, 1000000)
                 for _ in range(randint(1, 5))]))
        '''
        fids = list(
            set([(slot * 10000) + (i + 1) * 1000 + randint(1, 9) * 100
                 for i in range(randint(1, max_sequence_length * 2))]))
        logging.info(f"show fids {fids}")
        feature.fid_v2_list.value.extend(fids)
        for fid in fids:
          idx = fid % num_ps
          full_index = self.get_pre_output_offset(idx, f_cfg)
          index1 = table_index * num_ps + idx
          fid_offset = full_index << 32 | index2[index1]
          index2[index1] += 1
          fid_offset_list.append(fid_offset)
          fids_dict[table_name][idx].append((dim_sum, fid))
        feature_offset_list.append(len(fid_offset_list))

      if is_shared:
        feature = named_feature_list.feature.add()
        make_fids(feature)
      else:
        for _ in range(batch_size):
          feature = named_feature_list.feature.add()
          make_fids(feature)
      for ps_i in range(num_ps):
        fid_row_split_list[table_index * num_ps + ps_i].append(
            len(fids_dict[table_name][ps_i]))

      nfl_index = len(feature_offset_list) - 1
      if is_shared:
        # Encode the shared flag one position earlier, i.e. on the entry that
        # was appended before this feature list's fids were generated.
        nfl_offset_list[-1] |= SHARD_BIT
      nfl_offset_list.append(nfl_index)
    nfl_size_list = [len(nfl_offset_list)]
    feature_size_list = [len(feature_offset_list)]
    fid_size_list = [len(fid_offset_list)]
    logging.info(f"show fid_row_split_list: {fid_row_split_list}")
    logging.info(f"sparse_features : {sparse_features} ")

    # Deterministic embeddings: component j of a fid's embedding is fid + j.
    fid_to_emb = {}
    embeddings_list = []
    emb_size_list = []
    for table_name, table in fids_dict.items():
      for idx in sorted(table):
        values = table[idx]
        #emb = np.random.uniform(size=size)
        #logging.info(f"show emb {emb}")
        emb = []
        for i, (dim, fid) in enumerate(values):
          fid_emb = []
          for j in range(dim):
            fid_emb.append(fid + j)
          fid_to_emb[fid] = np.array(fid_emb, dtype=float)
          emb.extend(fid_emb)
        emb_size_list.append(len(emb))
        emb = np.array(emb, dtype=float)
        logging.info(f"show emb2 {emb}")
        embeddings_list.append(
            tf.reshape(tf.constant(value=emb, dtype=tf.float32), [-1]))

    #sparse_features_str = tf.constant(value=sparse_features.SerializeToString(),
    #                                  dtype=tf.string)
    if shard_op_version:
      get_default_parser_ctx().enable_fused_layout = True
      parsed_results = parse_example_batch(sparse_features.SerializeToString(),
                                           sparse_features=[],
                                           dense_features=[],
                                           dense_feature_shapes=[],
                                           dense_feature_types=[],
                                           extra_features=[],
                                           extra_feature_shapes=[])
      sparse_varint = parsed_results.pop(
          ParserCtx.sharding_sparse_fids_sparse_features_key)
      fid_list, fid_offset_list_ts, feature_offset_list_ts, nfl_offset_list_ts, batch_size_ts, nfl_size_list_ts, feature_size_list_ts, \
        fid_size_list_ts, emb_size_list_ts, fid_row_split_list_ts, fid_row_split_size_list_ts, fid_list_emb_row_lenth, \
          fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids(
          sparse_varint,
          num_ps,
          feature_cfgs,
          False,
          "examplebatch",
          parallel_flag=0,
          fid_list_ret_list=True,
          version=shard_op_version)
    else:
      fid_row_split_list_ts = fid_row_split_list
      fid_offset_list_ts = tf.constant(fid_offset_list, dtype=tf.uint64)
      feature_offset_list_ts = tf.constant(feature_offset_list, dtype=tf.int32)
      nfl_offset_list_ts = tf.constant(nfl_offset_list, dtype=tf.uint32)
      fid_list_emb_row_lenth = None
      if op_version >= 3:
        raise TypeError('Not imple')
        # NOTE(review): unreachable, the raise above always fires first.
        batch_size_ts = tf.constant([0] * batch_size, dtype=tf.int32)
      else:
        batch_size_ts = tf.constant([batch_size], dtype=tf.int32)
      nfl_size_list_ts = tf.constant(nfl_size_list, dtype=tf.int32)
      feature_size_list_ts = tf.constant(feature_size_list, dtype=tf.int32)
      fid_size_list_ts = tf.constant(fid_size_list, dtype=tf.int32)
      emb_size_list_ts = tf.constant(emb_size_list, dtype=tf.int32)

    variant_type = 'example_batch'
    if use_gpu:
      assert op_version >= 3
    if op_version == 4:
      # Version 4 takes a single tensor: re-order the per-(table, ps)
      # embeddings to ps-major order and concatenate them.
      embeddings_list_new = []
      for ps_i in range(num_ps):
        for table_i in range(len(table_name_sort)):
          embeddings_list_new.append(embeddings_list[table_i * num_ps + ps_i])
      embeddings_list = [tf.concat(embeddings_list_new, axis=-1)]
    with test_util.use_gpu() if use_gpu else tf.device("CPU:0"):
      layouts_op = distribution_ops.fused_embedding_to_layout(
          embeddings_list,
          fid_row_split_list_ts,
          fid_offset_list_ts,
          feature_offset_list_ts,
          nfl_offset_list_ts,
          batch_size_ts,
          variant_type,
          feature_cfgs,
          num_ps,
          fid_list_emb_row_lenth=fid_list_emb_row_lenth,
          nfl_size=nfl_size_list_ts,
          feature_size=feature_size_list_ts,
          fid_size=fid_size_list_ts,
          emb_size=emb_size_list_ts,
          parallel_flag=parallel_flag,
          version=op_version)
    with self.session() as sess:
      layouts = sess.run(layouts_op)
    #logging.info(f"show layouts: {layouts}")

    layout_names = sorted([x for x in feature_cfgs.out_configs.keys()])
    out_tensors = {}
    layout_info = {}
    out_tensor_list = []
    out_tensor_name_list = []
    # get layout configs.
    for ln in layout_names:
      out_config = feature_cfgs.out_configs[ln]
      out_tensors[ln] = []
      info = {}
      if len(out_config.shape) == 1:
        # Single output tensor shared by all slices of this layout.
        for shape in out_config.shape:
          real_shape = list(shape.dims)
          real_shape[0] = batch_size
          ts = np.zeros(shape=real_shape, dtype=np.float32)
          #logging.info(f" {ln} {ts} ")
          out_tensors[ln].append(ts)
          out_tensor_list.append(ts)
          out_tensor_name_list.append(ln + ":" + str(len(out_tensors[ln])))
          offset = 0
          for i, sc in enumerate(out_config.slice_configs):
            key = get_key(ln, sc)
            dim = sc.end - sc.start
            if out_config.out_type == OutType.CONCAT:
              info[key] = (ts, offset)
              offset += dim
            elif out_config.out_type == OutType.STACK:
              info[key] = (ts, i)
            elif out_config.out_type == OutType.ADDN:
              info[key] = (ts, 0)
            else:
              raise Exception("error")
      else:
        # One output tensor per slice.
        for sc, shape in zip(out_config.slice_configs, out_config.shape):
          real_shape = list(shape.dims)
          real_shape[0] = batch_size
          ts = np.zeros(shape=real_shape, dtype=np.float32)
          out_tensors[ln].append(ts)
          out_tensor_list.append(ts)
          out_tensor_name_list.append(ln + ":" + str(len(out_tensors[ln])))
          key = get_key(ln, sc)
          info[key] = (ts, 0)
      layout_info[ln] = info  # {name: (out, offset)}

    # Fill the expected tensors with the NumPy reference pooling.
    for ln in layout_names:
      out_config = feature_cfgs.out_configs[ln]
      out_type = out_config.out_type
      for slice_conf in out_config.slice_configs:
        name = slice_conf.feature_name
        features = std_features[name]
        feature_config = feature_cfgs.feature_configs[name]
        pooling_type = feature_config.pooling_type
        max_length = feature_config.max_sequence_length
        key = get_key(ln, slice_conf)
        dim = slice_conf.end - slice_conf.start
        (ts, offset) = layout_info[ln][key]
        if out_type == OutType.ADDN:
          tmp_addn = np.zeros(ts.shape)  # per slice tmp out
        #logging.info(f" {ln} {ts} ")
        for i in range(batch_size):
          if i < len(features):
            tmp = []
            for fid in features[i].fid_v2_list.value:
              fid_emb = fid_to_emb[fid]
              emb_slice = fid_emb[slice_conf.start:slice_conf.end]
              tmp.append(emb_slice)
            if out_type == OutType.CONCAT:
              ts[i, offset:offset + dim] = pooling(pooling_type, tmp,
                                                   max_length)
            elif out_type == OutType.STACK:
              ts[i, offset, :] = pooling(pooling_type, tmp, max_length)
            elif out_type == OutType.ADDN:
              ret = pooling(pooling_type, tmp, max_length)
              tmp_addn[i, :] = ret
            else:
              ts[i, :] = pooling(pooling_type, tmp, max_length)
          else:
            # shared & copy
            if out_type == OutType.CONCAT:
              ts[i, offset:offset + dim] = ts[i - 1, offset:offset + dim]
            elif out_type == OutType.STACK:
              ts[i, offset, :] = ts[i - 1, offset, :]
            elif out_type == OutType.ADDN:
              tmp_addn[i, :] = tmp_addn[i - 1, :]
            else:
              ts[i, :] = ts[i - 1, :]
        if out_type == OutType.ADDN:
          ts += tmp_addn
        #logging.info(f" {ln} {ts} ")

    #logging.info(f"xxx out_tensor_list: {out_tensor_list}")
    for name, t, p in zip(out_tensor_name_list, out_tensor_list, layouts):
      #logging.info(f"xxx show result: {name} \n ans:{t} \n res:{p}")
      flag = np.allclose(t, p, rtol=1e-04, atol=1e-07, equal_nan=False)
      if not flag:
        logging.error(f"xxx show result: {name} \n ans:{t} \n res:{p}")
      else:
        logging.info(f"show result: {name} \n ans:{t} \n res:{p}")
      assert flag

  def test_fused_embedding_to_layout_use_shard_op(self):
    """Runs the layout test with inputs produced by sharding op version 2."""
    self.test_fused_embedding_to_layout(shard_op_version=2)

  def
test_fused_embedding_to_layout_use_shard_op3(self): self.test_fused_embedding_to_layout(shard_op_version=3, op_version=3) def test_fused_embedding_to_layout_use_shard_op3_gpu(self): self.test_fused_embedding_to_layout(shard_op_version=3, op_version=3, use_gpu=True) def test_fused_embedding_to_layout_use_shard_op4(self): self.test_fused_embedding_to_layout(shard_op_version=4, op_version=4) def test_fused_embedding_to_layout_use_shard_op4_gpu(self): self.test_fused_embedding_to_layout(shard_op_version=4, op_version=4, use_gpu=True) def test_fused_embedding_to_layout_parallel(self): self.test_fused_embedding_to_layout(parallel_flag=0) def test_fused_embedding_to_layout_grad(self, shard_op_version=None, op_version=2, parallel_flag=1, use_gpu=False): batch_size = 256 num_ps = 3 slot_num = 30 slot_table_split = [10, 20] #slot split for [table_one, table_two, table_three] max_sequence_length = 3 alphabet_name = list(string.ascii_lowercase) + ['za', 'zb', 'zc', 'zd'] feature_cfgs = FeatureConfigs() #sparse_features = list() bias = OutConfig() vec = OutConfig() ffm1 = OutConfig() ffm2 = OutConfig() firstN = OutConfig() for slot in range(1, slot_num): feature_name = f"fc_slot_{alphabet_name[slot - 1]}" fconf = FeatureConfig() if slot >= slot_table_split[1]: table_name = "table_one" #table_three, but now test for table with different dim slice_dims = [1, 4, 16] sequence_length = max_sequence_length pooling_type = PoolingType.FIRSTN slice_config = firstN.slice_configs.add() slice_config.feature_name = feature_name slice_config.start = 0 slice_config.end = 21 else: sequence_length = 0 if slot < slot_table_split[0]: table_name = "table_one" slice_dims = [1, 4, 8] pooling_type = PoolingType.SUM slice_config = ffm1.slice_configs.add() slice_config.feature_name = feature_name slice_config.start = 5 slice_config.end = 13 else: table_name = "table_two" slice_dims = [1, 4, 16] pooling_type = PoolingType.MEAN slice_config = ffm2.slice_configs.add() slice_config.feature_name = 
feature_name slice_config.start = 5 slice_config.end = 21 slice_config = bias.slice_configs.add() slice_config.feature_name = feature_name slice_config.start = 0 slice_config.end = 1 slice_config = vec.slice_configs.add() slice_config.feature_name = feature_name slice_config.start = 1 slice_config.end = 5 fconf.table = table_name fconf.slice_dims.extend(slice_dims) fconf.max_sequence_length = sequence_length fconf.pooling_type = pooling_type feature_cfgs.feature_configs[feature_name].CopyFrom(fconf) infer_shape(bias, OutType.ADDN) feature_cfgs.out_configs['bias'].CopyFrom(bias) infer_shape(vec, OutType.CONCAT) feature_cfgs.out_configs['vec'].CopyFrom(vec) infer_shape(ffm1, OutType.STACK) feature_cfgs.out_configs['ffm1'].CopyFrom(ffm1) infer_shape(ffm2, OutType.NONE) feature_cfgs.out_configs['ffm2'].CopyFrom(ffm2) infer_shape(firstN, OutType.NONE, max_sequence_length) feature_cfgs.out_configs['firstN'].CopyFrom(firstN) logging.info(f"feature_cfgs : {feature_cfgs} ") feature_cfg, table_cfg, feature_name_sort, table_name_sort = self.get_feature_cfg( feature_cfgs, num_ps) # gen all fids slot2fid = defaultdict(list) sparse_features = [Example() for i in range(batch_size)] fid_idx_list_batch = defaultdict(lambda: [[] for i in range(batch_size)]) for slot in range(1, slot_num): feature_name = f"fc_slot_{alphabet_name[slot - 1]}" fids = list( set([(slot << 48) + randint(100, 1000000) for _ in range(randint(batch_size + 1, batch_size + 10))])) slot2fid[slot] = fids for bi in range(batch_size): sparse_feature = sparse_features[bi] named_feature = sparse_feature.named_feature.add() named_feature.name = feature_name fid_idx_list = [i for i in range(bi, len(fids) - batch_size + 1 + bi)] fid_idx_list_batch[slot][bi] = fid_idx_list fid_list = [fids[idx] for idx in fid_idx_list] named_feature.feature.fid_v2_list.value.extend(fid_list) # gen offset slot2fid_offset = defaultdict(list) embedding_fid_list = [[] for _ in range(num_ps * len(table_cfg))] emb_size_list = [0 for _ in 
range(num_ps * len(table_cfg))] fid_row_split_list = [[0] for _ in range(num_ps * len(table_cfg))] # record the truth truth = defaultdict(lambda: defaultdict(list)) for slot in range(1, slot_num): feature_name = f"fc_slot_{alphabet_name[slot - 1]}" f_cfg = feature_cfg[feature_name] dim_sum = f_cfg["dim_sum"] table_idx = f_cfg["table_index"] fids = slot2fid[slot] embedding_fid_list_tmp = [[] for _ in range(num_ps)] for fid in fids: ps_index = fid % num_ps index1 = table_idx * num_ps + ps_index embedding_fid_list[index1].append((fid, dim_sum)) emb_size_list[index1] += dim_sum index2 = len(embedding_fid_list[index1]) - 1 embedding_fid_list_tmp[ps_index].append(fid) feature_index = len(embedding_fid_list_tmp[ps_index]) - 1 full_index = self.get_pre_output_offset(ps_index, f_cfg) #[index1(table_index), index2(fid in table index) # , full_index(all_feature_index), feature_index(fid in all_feature index)] slot2fid_offset[slot].append( [index1, index2, full_index, feature_index]) truth[index1][index2] = [0, dim_sum] for ps_i in range(num_ps): index1 = table_idx * num_ps + ps_i fid_row_split_list[index1].append(len(embedding_fid_list[index1])) # gen offset fid_offset_list = list() feature_offset_list = [0] nfl_offset_list = [0] for slot in range(1, slot_num): feature_name = f"fc_slot_{alphabet_name[slot - 1]}" feature_config = feature_cfgs.feature_configs[feature_name] pooling_type = feature_config.pooling_type max_length = feature_config.max_sequence_length for bi in range(batch_size): fid_idx_list = fid_idx_list_batch[slot][bi] for i, idx in enumerate(fid_idx_list): index1, index2, full_index, feature_index = slot2fid_offset[slot][idx] fid_offset = full_index << 32 | feature_index fid_offset_list.append(fid_offset) if pooling_type == PoolingType.FIRSTN and i >= max_length: pass elif pooling_type == PoolingType.MEAN: truth[index1][index2][0] += 1 / len(fid_idx_list) else: truth[index1][index2][0] += 1 feature_offset_list.append(len(fid_offset_list)) nfl_index = 
len(feature_offset_list) - 1 nfl_offset_list.append(nfl_index) nfl_size_list = [len(nfl_offset_list)] feature_size_list = [len(feature_offset_list)] fid_size_list = [len(fid_offset_list)] # gen emb embeddings_list = list() for idx, embedding_fid in enumerate(embedding_fid_list): dim_sum = 0 for fid, dim in embedding_fid: dim_sum += dim size = (dim_sum, 1) emb = np.random.uniform(size=size) embeddings_list.append( tf.reshape(tf.constant(value=emb, dtype=tf.float32), [-1])) with self.session() as sess: if shard_op_version: get_default_parser_ctx().enable_fused_layout = True parsed_results = parse_examples( [sparse.SerializeToString() for sparse in sparse_features], sparse_features=[], dense_features=[], dense_feature_shapes=[], dense_feature_types=[], extra_features=[], extra_feature_shapes=[]) sparse_varint = parsed_results.pop( ParserCtx.sharding_sparse_fids_sparse_features_key) fid_list, fid_offset_list_ts, feature_offset_list_ts, nfl_offset_list_ts, batch_size_ts, nfl_size_list_ts, feature_size_list_ts, \ fid_size_list_ts, emb_size_list_ts, fid_row_split_list, fid_row_split_size_list, fid_list_emb_row_lenth, \ fid_list_table_row_length, fid_list_shard_row_lenth = sharding_sparse_fids( sparse_varint, num_ps, feature_cfgs, True, "example", parallel_flag=0, fid_list_ret_list=True, version=shard_op_version) assert op_version == shard_op_version else: fid_offset_list_ts = tf.constant(fid_offset_list, dtype=tf.uint64) feature_offset_list_ts = tf.constant(feature_offset_list, dtype=tf.int32) nfl_offset_list_ts = tf.constant(nfl_offset_list, dtype=tf.uint32) batch_size_ts = tf.constant([batch_size], dtype=tf.int32) nfl_size_list_ts = tf.constant(nfl_size_list, dtype=tf.int32) feature_size_list_ts = tf.constant(feature_size_list, dtype=tf.int32) fid_size_list_ts = tf.constant(fid_size_list, dtype=tf.int32) emb_size_list_ts = tf.constant(emb_size_list, dtype=tf.int32) fid_list_emb_row_lenth = None if op_version >= 3: raise TypeError('Not imple') if use_gpu: assert 
op_version >= 3 if op_version == 4: embeddings_list_new = [] for ps_i in range(num_ps): for table_i in range(len(table_name_sort)): embeddings_list_new.append(embeddings_list[table_i * num_ps + ps_i]) embeddings_list = [tf.concat(embeddings_list_new, axis=-1)] variant_type = 'example' with test_util.use_gpu() if use_gpu else tf.device("CPU:0"): layouts = distribution_ops.fused_embedding_to_layout( embeddings_list, fid_row_split_list, fid_offset_list_ts, feature_offset_list_ts, nfl_offset_list_ts, batch_size_ts, variant_type, feature_cfgs, num_ps, fid_list_emb_row_lenth=fid_list_emb_row_lenth, nfl_size=nfl_size_list_ts, feature_size=feature_size_list_ts, fid_size=fid_size_list_ts, emb_size=emb_size_list_ts, parallel_flag=parallel_flag, version=op_version) #layouts_ret = sess.run(layouts) #logging.info(f"show result: {layouts_ret}") test_grads = tf.gradients(layouts, embeddings_list) if op_version == 4: recv_embeddings_split = tf.split(test_grads[0], fid_list_emb_row_lenth) test_grads = [None] * (num_ps * len(table_name_sort)) recv_embeddings_split_index = 0 for ps_index in range(num_ps): for table_idx in range(len(table_name_sort)): test_grads[ table_idx * num_ps + ps_index] = recv_embeddings_split[recv_embeddings_split_index] recv_embeddings_split_index += 1 ''' TODO test_grads = distribution_ops.fused_embedding_to_layout_grad( nfl_offset_list_ts, feature_offset_list_ts, fid_offset_list_ts, batch_size_ts, embeddings_list, fid_row_split_list, layouts, variant_type, feature_cfgs, num_ps, parallel_flag=parallel_flag, version=2, ) ''' grads = sess.run(test_grads) logging.info(f"show result: {grads}") logging.info(f"show truth: {truth}") assert len(grads) == len(truth) for i in range(len(truth)): part_truth = truth[i] grad = grads[i] offset = 0 for j in range(len(part_truth)): t, dim = part_truth[j] # There is no slice use twice in the UT data, so the grads of one fid embedding should be the same assert len(np.unique(grad[offset: offset + dim])) == 1, \ f"Alert All The 
Same! [{i}, {j}] [{(t, dim)}, {grad[offset: offset + dim]}]" # The gound truth should be the fid used times assert np.allclose(t, grad[offset], rtol=1e-04, atol=1e-07, equal_nan=False), \ f"Alert Equal! [{i}, {j}] [{t} {grad[offset]}]" offset += dim def test_fused_embedding_to_layout_grad_no_parallel(self): self.test_fused_embedding_to_layout_grad(parallel_flag=0) def test_fused_embedding_to_layout_grad_use_shard_op(self): self.test_fused_embedding_to_layout_grad(shard_op_version=2, op_version=2) def test_fused_embedding_to_layout_grad_use_shard_op3(self): self.test_fused_embedding_to_layout_grad(shard_op_version=3, op_version=3) def test_fused_embedding_to_layout_grad_use_shard_op3_gpu(self): self.test_fused_embedding_to_layout_grad(shard_op_version=3, op_version=3, use_gpu=True) def test_fused_embedding_to_layout_grad_use_shard_op4(self): self.test_fused_embedding_to_layout_grad(shard_op_version=4, op_version=4) def test_fused_embedding_to_layout_grad_use_shard_op4_gpu(self): self.test_fused_embedding_to_layout_grad(shard_op_version=4, op_version=4, use_gpu=True) class FusedEmbeddingToLayoutFitPreTest(tf.test.TestCase): def test_fused_embedding_to_layout(self): batch_size = 10 feature_cfgs = FeatureConfigs() sparse_features = ExampleBatch(batch_size=batch_size) std_features = defaultdict(list) fids_dict = {} bias = OutConfig() vec = OutConfig() ffm1 = OutConfig() ffm2 = OutConfig() index1 = 0 index2 = [0] * 10 fid_offset_list = list() feature_offset_list = [0] nfl_offset_list = [0] feature_names_list = list() slot_to_nfl_map = dict() for slot in range(1, 50): named_feature_list = sparse_features.named_feature_list.add() named_feature_list.id = slot named_feature_list.name = f"fc_slot_{slot}" feature_names_list.append(named_feature_list.name) is_shared = True if slot % 2 == 0 else False named_feature_list.type = FeatureListType.SHARED if is_shared else FeatureListType.INDIVIDUAL slot_to_nfl_map[slot] = named_feature_list # set all offset_list by sorted 
sorted_feature_names_list sorted_feature_names_list = sorted(feature_names_list) for feature_name in sorted_feature_names_list: slot = int(feature_name.split("fc_slot_")[-1]) named_feature_list = slot_to_nfl_map[slot] fconf = FeatureConfig() if slot < 25: table_name = "table_one" slice_dims = [1, 4, 8] pooling_type = PoolingType.SUM slice_config = ffm1.slice_configs.add() slice_config.feature_name = f"fc_slot_{slot}" slice_config.start = 5 slice_config.end = 13 else: table_name = "table_two" slice_dims = [1, 4, 16] pooling_type = PoolingType.MEAN slice_config = ffm2.slice_configs.add() slice_config.feature_name = f"fc_slot_{slot}" slice_config.start = 5 slice_config.end = 21 slice_config = bias.slice_configs.add() slice_config.feature_name = f"fc_slot_{slot}" slice_config.start = 0 slice_config.end = 1 slice_config = vec.slice_configs.add() slice_config.feature_name = f"fc_slot_{slot}" slice_config.start = 1 slice_config.end = 5 fconf.table = table_name fconf.slice_dims.extend(slice_dims) fconf.pooling_type = pooling_type feature_cfgs.feature_configs[f"fc_slot_{slot}"].CopyFrom(fconf) table_name = "table_one" if slot < 25 else "table_two" if table_name not in fids_dict: fids_dict[table_name] = defaultdict(list) if named_feature_list.type == FeatureListType.SHARED: feature = named_feature_list.feature.add() std_features[named_feature_list.name].append(feature) fids = list( set([(slot << 48) + randint(100, 1000000) for _ in range(randint(1, 5))])) feature.fid_v2_list.value.extend(fids) for fid in fids: idx = fid % 5 if table_name == "table_one": index1 = 0 + idx fid_offset = index1 << 32 | index2[index1] index2[index1] += 1 else: index1 = 5 + idx fid_offset = index1 << 32 | index2[index1] index2[index1] += 1 fid_offset_list.append(fid_offset) fids_dict[table_name][idx].append(fid) feature_offset_list.append(len(fid_offset_list)) else: for _ in range(batch_size): feature = named_feature_list.feature.add() fids = list( set([(slot << 48) + randint(100, 1000000) for _ in 
range(randint(1, 5))])) feature.fid_v2_list.value.extend(fids) for fid in fids: idx = fid % 5 if table_name == "table_one": index1 = 0 + idx fid_offset = index1 << 32 | index2[index1] index2[index1] += 1 else: index1 = 5 + idx fid_offset = index1 << 32 | index2[index1] index2[index1] += 1 fid_offset_list.append(fid_offset) fids_dict[table_name][idx].append(fid) feature_offset_list.append(len(fid_offset_list)) std_features[named_feature_list.name].append(feature) nfl_index = len(feature_offset_list) - 1 nfl_offset_list.append(nfl_index) nfl_size_list = [len(nfl_offset_list)] feature_size_list = [len(feature_offset_list)] fid_size_list = [len(fid_offset_list)] # add shared encode nfl_idx = 0 for feature_name in sorted_feature_names_list: slot = int(feature_name.split("fc_slot_")[-1]) is_shared = True if slot % 2 == 0 else False head_bit = SHARD_BIT if is_shared else 0 nfl_offset_list[nfl_idx] |= head_bit nfl_idx += 1 infer_shape(bias, OutType.ADDN) feature_cfgs.out_configs['bias'].CopyFrom(bias) infer_shape(vec, OutType.CONCAT) feature_cfgs.out_configs['vec'].CopyFrom(vec) infer_shape(ffm1, OutType.STACK) feature_cfgs.out_configs['ffm1'].CopyFrom(ffm1) infer_shape(ffm2, OutType.NONE) feature_cfgs.out_configs['ffm2'].CopyFrom(ffm2) embeddings_dict, fid_to_emb = {}, {} emb_size_list = [] fids_list, embeddings_list, embeddings_np_list = [], [], [] for table_name, table in fids_dict.items(): embeddings_dict[table_name] = {} for idx in sorted(table): values = table[idx] fids_list.append(tf.constant(value=values, dtype=tf.int64)) if table_name == "table_one": length = 13 else: length = 21 size = (len(values), length) emb = np.random.uniform(size=size) emb_size_list.append(len(values) * length) embeddings_dict[table_name][idx] = emb embeddings_list.append( tf.constant(value=emb, shape=size, dtype=tf.float32)) embeddings_np_list.append(emb) for i, fid in enumerate(values): fid_to_emb[fid] = (len(embeddings_np_list) - 1, i) sparse_features_str = 
tf.constant(value=sparse_features.SerializeToString(), dtype=tf.string) variant_type = 'example_batch' # layouts = distribution_ops.fused_embedding_to_layout(sparse_features_str, fids_list, embeddings_list, variant_type, feature_cfgs) fid_offset_list_ts = tf.constant(fid_offset_list, dtype=tf.uint64) feature_offset_list_ts = tf.constant(feature_offset_list, dtype=tf.int32) nfl_offset_list_ts = tf.constant(nfl_offset_list, dtype=tf.uint32) batch_size_ts = tf.constant([batch_size], dtype=tf.int32) nfl_size_list_ts = tf.constant(nfl_size_list, dtype=tf.int32) feature_size_list_ts = tf.constant(feature_size_list, dtype=tf.int32) fid_size_list_ts = tf.constant(fid_size_list, dtype=tf.int32) emb_size_list_ts = tf.constant(emb_size_list, dtype=tf.int32) layouts_op = distribution_ops.fused_embedding_to_layout( embeddings_list, None, fid_offset_list_ts, feature_offset_list_ts, nfl_offset_list_ts, batch_size_ts, variant_type, feature_cfgs, -1, nfl_size=nfl_size_list_ts, feature_size=feature_size_list_ts, fid_size=fid_size_list_ts, emb_size=emb_size_list_ts, version=1) with self.session() as sess: layouts = sess.run(layouts_op) layout_names = sorted(['bias', 'vec', 'ffm1', 'ffm2']) out_tensors = {} layout_info = {} out_tensor_list = [] # get layout configs. 
for ln in layout_names: out_config = feature_cfgs.out_configs[ln] out_tensors[ln] = [] info = {} if len(out_config.shape) == 1: for shape in out_config.shape: real_shape = list(shape.dims) real_shape[0] = batch_size ts = np.zeros(shape=real_shape, dtype=np.float32) out_tensors[ln].append(ts) out_tensor_list.append(ts) offset = 0 for i, sc in enumerate(out_config.slice_configs): key = get_key(ln, sc) dim = sc.end - sc.start if out_config.out_type == OutType.CONCAT: info[key] = (ts, offset) offset += dim elif out_config.out_type == OutType.STACK: info[key] = (ts, i) elif out_config.out_type == OutType.ADDN: info[key] = (ts, 0) else: raise Exception("error") else: for sc, shape in zip(out_config.slice_configs, out_config.shape): real_shape = list(shape.dims) real_shape[0] = batch_size ts = np.zeros(shape=real_shape, dtype=np.float32) out_tensors[ln].append(ts) out_tensor_list.append(ts) key = get_key(ln, sc) info[key] = (ts, 0) layout_info[ln] = info # {name: (out, offset)} for ln in layout_names: out_config = feature_cfgs.out_configs[ln] out_type = out_config.out_type for slice_conf in out_config.slice_configs: name = slice_conf.feature_name features = std_features[name] feature_config = feature_cfgs.feature_configs[name] pooling_type = feature_config.pooling_type max_length = feature_config.max_sequence_length key = get_key(ln, slice_conf) dim = slice_conf.end - slice_conf.start (ts, offset) = layout_info[ln][key] if out_type == OutType.ADDN: tmp_addn = np.zeros(ts.shape) # per slice tmp out for i in range(batch_size): if i < len(features): tmp = [] for fid in features[i].fid_v2_list.value: (idx, row) = fid_to_emb[fid] emb_slice = embeddings_np_list[idx][ row, slice_conf.start:slice_conf.end] tmp.append(emb_slice) if out_type == OutType.CONCAT: ts[i, offset:offset + dim] = pooling(pooling_type, tmp, max_length) elif out_type == OutType.STACK: ts[i, offset, :] = pooling(pooling_type, tmp, max_length) elif out_type == OutType.ADDN: ret = pooling(pooling_type, tmp, 
max_length) tmp_addn[i, :] = ret else: ts[i, :] = pooling(pooling_type, tmp, max_length) else: # shared & copy if out_type == OutType.CONCAT: ts[i, offset:offset + dim] = ts[i - 1, offset:offset + dim] elif out_type == OutType.STACK: ts[i, offset, :] = ts[i - 1, offset, :] elif out_type == OutType.ADDN: tmp_addn[i, :] = tmp_addn[i - 1, :] else: ts[i, :] = ts[i - 1, :] if out_type == OutType.ADDN: ts += tmp_addn for t, p in zip(out_tensor_list, layouts): logging.info(f"fused_embedding_to_layout show {t} {p}") assert np.allclose(t, p, rtol=1e-04, atol=1e-07, equal_nan=False) def test_fused_embedding_to_layout_grad(self): batch_size = 4 slot_num = 30 alphabet_name = list(string.ascii_lowercase) + ['za', 'zb', 'zc', 'zd'] feature_cfgs = FeatureConfigs() sparse_features = list() # 13, 13, 21, 21 embedding_fid_list = [[], [], [], []] bias = OutConfig() vec = OutConfig() ffm1 = OutConfig() ffm2 = OutConfig() slot2fid = dict() slot2fid_offset = defaultdict(list) for slot in range(1, slot_num): table_idx = 0 if slot < slot_num / 2 else 2 fids = list( set([(slot << 48) + randint(100, 1000000) for _ in range(randint(2, 10))])) slot2fid[slot] = fids for fid in fids: if fid % 2: index1 = table_idx else: index1 = table_idx + 1 embedding_fid_list[index1].append(fid) index2 = len(embedding_fid_list[index1]) - 1 slot2fid_offset[slot].append([index1, index2]) fconf = FeatureConfig() table_name = "table_one" if slot < slot_num / 2 else "table_two" if slot < slot_num / 2: slice_dims = [1, 4, 8] pooling_type = PoolingType.SUM slice_config = ffm1.slice_configs.add() slice_config.feature_name = f"fc_slot_{alphabet_name[slot - 1]}" slice_config.start = 5 slice_config.end = 13 else: slice_dims = [1, 4, 16] pooling_type = PoolingType.MEAN slice_config = ffm2.slice_configs.add() slice_config.feature_name = f"fc_slot_{alphabet_name[slot - 1]}" slice_config.start = 5 slice_config.end = 21 slice_config = bias.slice_configs.add() slice_config.feature_name = f"fc_slot_{alphabet_name[slot - 1]}" 
slice_config.start = 0 slice_config.end = 1 slice_config = vec.slice_configs.add() slice_config.feature_name = f"fc_slot_{alphabet_name[slot - 1]}" slice_config.start = 1 slice_config.end = 5 fconf.table = table_name fconf.slice_dims.extend(slice_dims) fconf.pooling_type = pooling_type feature_cfgs.feature_configs[ f"fc_slot_{alphabet_name[slot - 1]}"].CopyFrom(fconf) infer_shape(bias, OutType.ADDN) feature_cfgs.out_configs['bias'].CopyFrom(bias) infer_shape(vec, OutType.CONCAT) feature_cfgs.out_configs['vec'].CopyFrom(vec) infer_shape(ffm1, OutType.STACK) feature_cfgs.out_configs['ffm1'].CopyFrom(ffm1) infer_shape(ffm2, OutType.NONE) feature_cfgs.out_configs['ffm2'].CopyFrom(ffm2) # record the truth truth = defaultdict(lambda: defaultdict(int)) fid_offset_list = list() feature_offset_list = [0] nfl_offset_list = [0] embeddings_list = list() sparse_features = [Example() for i in range(batch_size)] for slot in range(1, slot_num): feature_name = f"fc_slot_{alphabet_name[slot - 1]}" feature_config = feature_cfgs.feature_configs[feature_name] pooling_type = feature_config.pooling_type max_length = feature_config.max_sequence_length for bi in range(batch_size): sparse_feature = sparse_features[bi] named_feature = sparse_feature.named_feature.add() named_feature.name = f"fc_slot_{alphabet_name[slot - 1]}" all_fids = slot2fid[slot] fid_num = randint(0, len(all_fids)) fid_idx_list = [randint(1, len(all_fids) - 1) for i in range(fid_num)] fid_list = [all_fids[idx] for idx in fid_idx_list] named_feature.feature.fid_v2_list.value.extend(fid_list) for idx in fid_idx_list: index1, index2 = slot2fid_offset[slot][idx] fid_offset = index1 << 32 | index2 fid_offset_list.append(fid_offset) if pooling_type == PoolingType.FIRSTN and i >= max_length: pass elif pooling_type == PoolingType.MEAN: truth[index1][index2] += 1 / len(fid_idx_list) else: truth[index1][index2] += 1 feature_offset_list.append(len(fid_offset_list)) nfl_index = len(feature_offset_list) - 1 
nfl_offset_list.append(nfl_index) nfl_size_list = [len(nfl_offset_list)] feature_size_list = [len(feature_offset_list)] fid_size_list = [len(fid_offset_list)] idx = 0 emb_size_list = [] for embedding_fid in embedding_fid_list: if idx < 2: embeddings_dim = [13 for i in embedding_fid] else: embeddings_dim = [21 for i in embedding_fid] size = (len(embeddings_dim), embeddings_dim[0]) emb = np.random.uniform(size=size) emb_size_list.append(size[0] * size[1]) embeddings_list.append( tf.constant(value=emb, shape=size, dtype=tf.float32)) idx += 1 variant_type = 'example' with self.session() as sess: fid_offset_list_ts = tf.constant(fid_offset_list, dtype=tf.uint64) feature_offset_list_ts = tf.constant(feature_offset_list, dtype=tf.int32) nfl_offset_list_ts = tf.constant(nfl_offset_list, dtype=tf.uint32) batch_size_ts = tf.constant([batch_size], dtype=tf.int32) nfl_size_list_ts = tf.constant(nfl_size_list, dtype=tf.int32) feature_size_list_ts = tf.constant(feature_size_list, dtype=tf.int32) fid_size_list_ts = tf.constant(fid_size_list, dtype=tf.int32) emb_size_list_ts = tf.constant(emb_size_list, dtype=tf.int32) layouts = distribution_ops.fused_embedding_to_layout( embeddings_list, None, fid_offset_list_ts, feature_offset_list_ts, nfl_offset_list_ts, batch_size_ts, variant_type, feature_cfgs, -1, nfl_size=nfl_size_list_ts, feature_size=feature_size_list_ts, fid_size=fid_size_list_ts, emb_size=emb_size_list_ts, version=1) test_grads = tf.gradients(layouts, embeddings_list) grads = sess.run(test_grads) for i in range(len(grads)): grad = grads[i] for j in range(grad.shape[0]): # There is no slice use twice in the UT data, so the grads of one fid embedding should be the same assert len(np.unique( grad[j, :])) == 1, "Alert All The Same! 
[{}, {}]".format(i, j) # The gound truth should be the fid used times logging.info( f"fused_embedding_to_layout grad show {truth[i][j]} {grad[j, 0]}") assert np.allclose(truth[i][j],grad[j, 0], rtol=1e-04, atol=1e-07, equal_nan=False), \ "Alert Equal! [{}, {}]".format(i, j) if __name__ == "__main__": # tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/gen_seq_mask.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf from typing import Union from monolith.native_training.runtime.ops import gen_monolith_ops ops = gen_monolith_ops def gen_seq_mask(splits: Union[tf.Tensor, tf.RaggedTensor], max_seq_length: int) -> tf.Tensor: if isinstance(splits, tf.RaggedTensor): splits = splits.row_splits() return ops.gen_seq_mask(splits=splits, max_seq_length=max_seq_length) ================================================ FILE: monolith/native_training/gen_seq_mask_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import tensorflow as tf from monolith.native_training.gen_seq_mask import gen_seq_mask class GenSeqMaskTest(tf.test.TestCase): def test_gen_seq_mask_int32(self): split = tf.constant([0, 5, 7, 9, 13], dtype=tf.int32) mask = gen_seq_mask(split, 6) result = tf.constant([[1, 1, 1, 1, 1, 0], [1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0]], dtype=tf.int32) self.assertAllEqual(mask, result) def test_gen_seq_mask_int64(self): split = tf.constant([0, 5, 7, 9, 13], dtype=tf.int64) mask = gen_seq_mask(split, 6) result = tf.constant([[1, 1, 1, 1, 1, 0], [1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0]], dtype=tf.int64) self.assertAllEqual(mask, result) if __name__ == "__main__": tf.test.main() ================================================ FILE: monolith/native_training/gflags_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl.flags import FlagValues from absl import logging, flags import dataclasses from dataclasses import Field from enum import Enum import inspect import re import sys from typing import get_type_hints, Iterable, Tuple, Dict FLAGS = flags.FLAGS _SPACE = re.compile(r"\s+") _PARAM = re.compile(r"^:param\s+([a-zA-Z0-9._-]+)\s*:\s*(.*)") class Status(Enum): Init = 1 Open = 2 Extend = 3 Closed = 4 def _extract_help_info(cls, help_info, is_nested): if is_nested: for base in cls.__bases__: assert _extract_help_info(base, help_info, is_nested) == Status.Closed doc = [ " ".join(re.split(_SPACE, line.strip())) for line in cls.__doc__.split('\n') if len(line.strip()) > 0 ] key_stack = [] status = Status.Init for i, line in enumerate(doc): matched = _PARAM.match(line) if matched: new_key, info = matched.groups() if status == Status.Init: help_info[new_key] = [info] key_stack.append(new_key) elif status == Status.Open or status == Status.Extend: old_key = key_stack.pop() assert old_key != new_key help_info[new_key] = [info] key_stack.append(new_key) else: assert status == Status.Closed break # trans status status = Status.Open else: if status == Status.Init: pass elif status == Status.Open: key = key_stack[-1] help_info[key].append(line) status = Status.Extend elif status == Status.Extend: key = key_stack[-1] help_info[key].append(line) else: assert status == Status.Closed break if i + 1 == len(doc): status = Status.Closed return status def extract_help_info(cls, is_nested=True): help_info = {} status = _extract_help_info(cls, help_info, is_nested) assert status == Status.Closed return {key: " ".join(value) for key, value in help_info.items()} def extract_flags_decorator(remove_flags=None, is_nested=True): def decorator(cls): extract_flags(flags, cls, is_nested, remove_flags) return cls return decorator def extract_flags(gflags, dcls, is_nested=True, skip_flags=None) -> FlagValues: FLAGS = gflags.FLAGS help_info = extract_help_info(dcls, is_nested) skip_flags = set() if 
skip_flags is None else set(skip_flags) for key, dtype in get_type_hints(dcls).items(): if key not in help_info.keys() or key in skip_flags: continue default = getattr(dcls, key) help_str = "default={}, {}".format(default, help_info.get(key, "")) try: if dtype == int: gflags.DEFINE_integer(key, default, "{}, {}".format('int', help_str)) elif dtype == bool: gflags.DEFINE_bool(key, default, "{}, {}".format('bool', help_str)) elif dtype == str: gflags.DEFINE_string(key, default, "{}, {}".format('string', help_str)) elif dtype == float: gflags.DEFINE_float(key, default, "{}, {}".format('float', help_str)) elif issubclass(dtype, Enum): default_value = default.name.lower() enum_values = [name.lower() for name in dtype._member_names_] gflags.DEFINE_enum(key, default_value, enum_values, "{}, {}".format('enum', help_str)) else: raise ValueError("only is support!") except: pass return FLAGS def get_flags_parser(flags, FLAGS): def flags_parser(args): try: return FLAGS(args) except flags.Error as error: logging.error('FATAL Flags parsing error: {}\n{}'.format( error, FLAGS.get_help(include_special_flags=False))) logging.error('Pass --helpshort or --helpfull to see help on flags.\n') sys.exit(1) return flags_parser def update(config): """ update config's attr value using flags.FLAGS if config's attr value is default value and FLAGS' attr value is not default config: any type of Config like CpuTraingingConfig, DistributedCpuTrainingConfig example: see gflags_utils_test.py test_update() """ FLAGS = flags.FLAGS cls = config.__class__ for key, dtype in get_type_hints(cls).items(): tmp = getattr(cls, key) if isinstance(tmp, Field): field = tmp default = field.default if field.default is not None else field.default_factory( ) else: default = tmp from_code = config.__dict__.get(key, default) try: if not hasattr(FLAGS, key): continue except: continue if issubclass(dtype, Enum): from_cmd = dtype[getattr(FLAGS, key).upper()] else: from_cmd = getattr(FLAGS, key) if from_code == default 
and from_cmd != default: # user has not set this field, it should not overwrite by cmd config.__dict__[key] = from_cmd else: continue return config @dataclasses.dataclass class _MonolithGflagMeta: # Link a field `name` to `flag` linked_map: Dict[str, str] = dataclasses.field(default_factory=dict) def _get_flag_obj(cls, set_if_not_exists=False) -> _MonolithGflagMeta: attr = "_monolith_gflag_meta" if not hasattr(cls, attr): if set_if_not_exists: setattr(cls, attr, _MonolithGflagMeta()) else: return _MonolithGflagMeta() return getattr(cls, attr) class LinkDataclassToFlags: """Links a field's default value to a flag. Example: flags.DEFINE_int("c_value", 0, "") @LinkDataclassToFlags(linked_map={"v": "v_value"}) @dataclass.dataclasses class C: v: int = 0 When we instantiate C as c, if v_value is not 0, and c.v is the default value (which is 0 in this case), c.v = FLAGS.v_value """ def __init__(self, linked_list=None, linked_map=None): """Elements `e` in linked_list is equivalent to `e:e` in the linekd_map """ self._m = {} self._m.update(linked_map or {}) linked_list = linked_list or [] for name in linked_list: self._m[name] = name def __call__(self, cls): assert dataclasses.is_dataclass( cls), "LinkDataclassToFlag should be used on dataclasses" fields = dataclasses.fields(cls) named_fields = {field.name: field for field in fields} for name, flag in self._m.items(): if name not in named_fields: raise ValueError(f"{name} is not a valid attribute of {type(cls)}") if flag not in FLAGS: raise ValueError(f"{flag} is not defined in gflags") obj = _get_flag_obj(cls, set_if_not_exists=True) for name, flag in self._m.items(): obj.linked_map[name] = flag return cls def _get_merged_meta(cls): meta = _MonolithGflagMeta() classes = inspect.getmro(cls) for c in classes: obj = _get_flag_obj(c) meta.linked_map.update(obj.linked_map) return meta def update_by_flags(cls): assert dataclasses.is_dataclass( cls), "update_by_flag should be used on dataclasses" orig_init = cls.__init__ fields = 
dataclasses.fields(cls) named_fields = {field.name: field for field in fields} del fields meta = _get_merged_meta(cls) def create_init(orig_init, named_fields, meta): def __init__(self, *args, **kwargs): orig_init(self, *args, **kwargs) for name, flag in meta.linked_map.items(): if getattr(self, name) == named_fields[name].default and getattr( FLAGS, flag) != FLAGS[flag].default: setattr(self, name, getattr(FLAGS, flag)) return __init__ cls.__init__ = create_init(orig_init, named_fields, meta) return cls ================================================ FILE: monolith/native_training/gflags_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import flags
from absl.testing import absltest
from typing import get_type_hints
from enum import Enum
import dataclasses

from monolith.native_training import gflags_utils as utils

FLAGS = flags.FLAGS


# NOTE: the `:param` docstrings of the TestConfig* classes below are read by
# utils.extract_help_info (see test_extract_help_info), so their text is left
# exactly as it is.
@dataclasses.dataclass
class TestConfig:
  """ :param test_int1: integer 1 for test :param test_int2: integer 2 for test :param test_str: string for test and test another line """
  test_int1: int = 0
  test_int2: int = 0
  test_str: str = None


@utils.extract_flags_decorator()
@dataclasses.dataclass
class TestConfig2:
  """ :param testconfig2_int1: integer 1 for TestConfig2 :param testconfig2_str1: str 1 for TestConfig2 """
  testconfig2_int1: int = 1
  testconfig2_str1: str = "str1"


# The set passed to the decorator names fields that must NOT become flags
# (test_extract_gflags_decorator checks that testconfig3_int1 is absent).
@utils.extract_flags_decorator({"testconfig3_int1"})
@dataclasses.dataclass
class TestConfig3:
  """ :param testconfig3_int1: testconfig3_int1 :param testconfig3_int2: testconfig3_int2 :param testconfig3_str1: testconfig3_str1 """
  testconfig3_int1: int = 1
  testconfig3_int2: int = 2
  testconfig3_str1: str = "str1"


@dataclasses.dataclass
class TestConfig4(TestConfig):
  """ :param testconfig4_int1: testconfig4_int1 :param testconfig4_str1: testconfig4_str1 """
  testconfig4_int1: int = 4
  testconfig4_str1: str = "testconfig4_str1"


@dataclasses.dataclass
class TestConfig5Base:
  """ :param testconfig5base_int1: testconfig5base_int1 :param testconfig5base_int2: testconfig5base_int2 :param testconfig5base_str: testconfig5base_str """
  testconfig5base_int1: int = 0
  testconfig5base_int2: int = 0
  testconfig5base_str: str = None


# is_nested=False: only the fields declared on TestConfig5 itself are expected
# to become flags, not those inherited from TestConfig5Base.
@utils.extract_flags_decorator(is_nested=False)
@dataclasses.dataclass
class TestConfig5(TestConfig5Base):
  """ :param testconfig5_int1: testconfig5_int1 :param testconfig5_str1: testconfig5_str1 """
  testconfig5_int1: int = 5
  testconfig5_str1: str = "testconfig5_str1"


class GflagUtilsTest(absltest.TestCase):
  """Tests for help extraction, flag extraction and flag linking helpers."""

  def _check_help_info(self, cls, skip_flags=set()):
    # Asserts that every annotated field of `cls` (except `skip_flags`) has a
    # help entry in its docstring. The default set is never mutated here.
    help_info = utils.extract_help_info(cls)
    for key, _ in get_type_hints(cls).items():
      if key in skip_flags:
        continue
      self.assertIn(key, help_info,
                    '{} is not in {}, please add a help info'.format(key, cls))

  def test_extract_help_info(self):
    res = utils.extract_help_info(TestConfig)
    self.assertIn("test_int1", res)
    self.assertIn("test_int2", res)
    self.assertIn("test_str", res)
    self.assertEqual("integer 1 for test", res["test_int1"])
    self.assertEqual("integer 2 for test", res["test_int2"])
    self.assertEqual("string for test and test another line", res["test_str"])

    # With is_nested=False, help entries of the base class are not included.
    res2 = utils.extract_help_info(TestConfig4, is_nested=False)
    self.assertIn("testconfig4_int1", res2)
    self.assertIn("testconfig4_str1", res2)
    self.assertNotIn("test_int1", res2)
    self.assertNotIn("test_int2", res2)
    self.assertNotIn("test_str", res2)

  def test_update(self):
    FLAGS = flags.FLAGS
    flags.DEFINE_integer("test_int1", 2, "test int 1")
    flags.DEFINE_integer("test_int2", 3, "test int 2")
    config = TestConfig(
        test_int1=1,
        test_int2=0,
    )
    utils.update(config)
    # will not update test_int1 because test_int1 in config is not default value.
    # will update test_int2 because test_int2 in config is default value
    # and FLAGS.test_int2 is not default value.
    self.assertEqual(config.test_int1, 1)  # not updated
    self.assertEqual(config.test_int2, 3)  # updated
    self.assertEqual(
        config.test_str,
        None)  # for test_str, no flag is defined, so nothing happens

  def test_extract_gflags_decorator(self):
    FLAGS = flags.FLAGS
    # Constructing with explicit values must not change the flag defaults.
    conf = TestConfig2(testconfig2_int1=2, testconfig2_str1="newstr1")
    self.assertEqual(FLAGS.testconfig2_int1, 1)
    self.assertEqual(FLAGS.testconfig2_str1, "str1")
    self.assertEqual(conf.testconfig2_int1, 2)
    self.assertEqual(conf.testconfig2_str1, "newstr1")

    conf3 = TestConfig3()
    self.assertEqual(hasattr(FLAGS, "testconfig3_int1"), False)
    self.assertEqual(hasattr(FLAGS, "testconfig3_int2"), True)
    self.assertEqual(hasattr(FLAGS, "testconfig3_str1"), True)
    self.assertEqual(hasattr(conf3, "testconfig3_int1"), True)
    self.assertEqual(hasattr(conf3, "testconfig3_int2"), True)
    self.assertEqual(hasattr(conf3, "testconfig3_str1"), True)

    conf5 = TestConfig5()
    self.assertEqual(hasattr(FLAGS, "testconfig5_int1"), True)
    self.assertEqual(hasattr(FLAGS, "testconfig5_str1"), True)
    self.assertEqual(hasattr(FLAGS, "testconfig5base_int1"), False)
    self.assertEqual(hasattr(FLAGS, "testconfig5base_int2"), False)
    self.assertEqual(hasattr(FLAGS, "testconfig5base_str"), False)
    self.assertEqual(hasattr(conf5, "testconfig5_int1"), True)
    self.assertEqual(hasattr(conf5, "testconfig5_str1"), True)
    self.assertEqual(hasattr(conf5, "testconfig5base_int1"), True)
    self.assertEqual(hasattr(conf5, "testconfig5base_int2"), True)
    self.assertEqual(hasattr(conf5, "testconfig5base_str"), True)

  # The flag must exist before LinkDataclassToFlags validates the link, so it
  # is defined in the class body right before the decorated class.
  flags.DEFINE_string("testflag6", "", "")

  @utils.update_by_flags
  @utils.LinkDataclassToFlags(linked_list=["testflag6"],
                              linked_map={"v": "testflag6"})
  @dataclasses.dataclass
  class TestClass6:
    v: str = None
    testflag6: str = None

  def test_link_flag(self):
    # Flag at its default: fields keep their own defaults.
    FLAGS.testflag6 = ""
    c = self.TestClass6()
    self.assertEqual(c.v, None)
    self.assertEqual(c.testflag6, None)

    # Flag overridden: both linked fields pick up the flag value.
    FLAGS.testflag6 = "a"
    c = self.TestClass6()
    self.assertEqual(c.v, "a")
    self.assertEqual(c.testflag6, "a")

    # An explicitly passed field value wins over the flag.
    FLAGS.testflag6 = "b"
    c = self.TestClass6(v="v")
    self.assertEqual(c.v, "v")
    self.assertEqual(c.testflag6, "b")

  flags.DEFINE_string("testflag7", "", "")

  @utils.LinkDataclassToFlags(linked_map={"v": "testflag7"})
  @dataclasses.dataclass
  class TestClass7Base:
    v: str = None

  @utils.update_by_flags
  @dataclasses.dataclass
  class TestClass7Inherit(TestClass7Base):
    v2: str = "v2"

  def test_link_flag_inheritance(self):
    # A link declared on the base class applies to the subclass as well.
    FLAGS.testflag7 = "a"
    c = self.TestClass7Inherit()
    self.assertEqual(c.v, "a")
    self.assertEqual(c.v2, "v2")


if __name__ == "__main__":
  absltest.main()

================================================
FILE: monolith/native_training/graph_meta.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

import tensorflow as tf

_MONOLITH_GRAPH_META = "monolith_graph_meta"


def get_meta(key: str, MetaFactory: Callable):
  """Returns the per-graph metadata object stored under `key`.

  The default graph's `_MONOLITH_GRAPH_META` collection holds a single dict.
  The dict, and the entry for `key`, are created on first use; the entry is
  built by calling `MetaFactory()` with no arguments.
  """
  g = tf.compat.v1.get_default_graph()
  # get_collection_ref returns the live list, so appending mutates the graph.
  l = g.get_collection_ref(_MONOLITH_GRAPH_META)
  if not l:
    l.append({})
  meta_dict = l[0]
  if key not in meta_dict:
    meta_dict[key] = MetaFactory()
  return meta_dict[key]

================================================
FILE: monolith/native_training/graph_utils.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf def add_batch_norm_into_update_ops(): ops = tf.compat.v1.get_default_graph().get_operations() update_ops = [ op for op in ops if 'AssignMovingAvg' in op.name and op.type == "AssignSubVariableOp" ] for update_op in update_ops: tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, update_op) ================================================ FILE: monolith/native_training/hash_filter_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from absl import logging
from contextlib import nullcontext
import os
from typing import List

import tensorflow as tf
from tensorflow.python.framework import ops

from monolith.native_training import basic_restore_hook
from monolith.native_training import save_utils
from monolith.native_training import utils
from monolith.native_training.runtime.ops import gen_monolith_ops
from monolith.utils import get_libops_path
from monolith.native_training.model_export.export_context import is_exporting_standalone

HASH_FILTER_CAPACITY = 300000000
HASH_FILTER_SPLIT_NUM = 7

# All three aliases point at the same generated op module.
filter_ops = gen_monolith_ops
filter_save_op = gen_monolith_ops
filter_restore_op = gen_monolith_ops

# Session timeout for the save/restore runs: 30 minutes.
_TIMEOUT_IN_MS = 30 * 60 * 1000


class FilterType(object):
  """String constants naming the supported filter implementations."""
  SLIDING_HASH_FILTER = 'sliding_hash_filter'
  PROBABILISTIC_FILTER = 'probabilistic_filter'
  NO_FILTER = 'no_filter'


def create_hash_filter(capacity: int,
                       split_num: int,
                       config: bytes = b"",
                       name_suffix: str = "") -> tf.Tensor:
  """Creates a hash filter.

  The resource is shared under the name "MonolithHashFilter" + name_suffix.
  """
  return filter_ops.MonolithHashFilter(capacity=capacity,
                                       split_num=split_num,
                                       config=config,
                                       shared_name="MonolithHashFilter" +
                                       name_suffix)


def create_probabilistic_filter(equal_probability,
                                config: bytes = b"",
                                name_suffix: str = "") -> tf.Tensor:
  """Creates a probabilistic filter.

  The resource is shared under the name
  "MonolithProbabilisticFilter" + name_suffix.
  """
  return filter_ops.MonolithProbabilisticFilter(
      equal_probability=equal_probability,
      config=config,
      shared_name="MonolithProbabilisticFilter" + name_suffix)


def create_dummy_hash_filter(name_suffix: str = "0") -> tf.Tensor:
  """Creates a dummy hash filter.

  The resource is shared under the name "DummyHashFilter" + name_suffix.
  """
  return filter_ops.MonolithDummyHashFilter(shared_name="DummyHashFilter" +
                                            name_suffix)


def _create_hash_filter(
    enable_hash_filter: bool,
    config: bytes = b"",
    name_suffix: str = "",
    filter_capacity: int = HASH_FILTER_CAPACITY,
    filter_split_num: int = HASH_FILTER_SPLIT_NUM,
    filter_equal_probability: bool = False,
    filter_type: str = FilterType.SLIDING_HASH_FILTER) -> tf.Tensor:
  """Creates one filter of the requested type.

  Args:
    enable_hash_filter: a real filter is only created when this is exactly
      `True`; any other value yields a dummy filter.
    config: serialized config forwarded to the filter op.
    name_suffix: suffix appended to the filter's shared name.
    filter_capacity: capacity of the sliding hash filter.
    filter_split_num: split number of the sliding hash filter.
    filter_equal_probability: forwarded to the probabilistic filter.
    filter_type: one of the `FilterType` string constants.

  Raises:
    ValueError: if filtering is enabled and `filter_type` is unknown.
  """
  if enable_hash_filter is not True:
    return create_dummy_hash_filter(name_suffix)
  if filter_type == FilterType.SLIDING_HASH_FILTER:
    return create_hash_filter(filter_capacity, filter_split_num, config,
                              name_suffix)
  if filter_type == FilterType.PROBABILISTIC_FILTER:
    return create_probabilistic_filter(filter_equal_probability, config,
                                       name_suffix)
  if filter_type == FilterType.NO_FILTER:
    return create_dummy_hash_filter(name_suffix)
  raise ValueError("Invalid filter type, please investigate and retry!")


def create_hash_filters(
    ps_num: int,
    enable_hash_filter: bool,
    config: bytes = b"",
    filter_capacity: int = HASH_FILTER_CAPACITY,
    filter_split_num: int = HASH_FILTER_SPLIT_NUM,
    filter_equal_probability: bool = False,
    filter_type: str = FilterType.SLIDING_HASH_FILTER
) -> List[tf.Tensor]:
  """Creates the list of filters for a job with `ps_num` parameter servers.

  With `ps_num == 0` a single filter with an empty name suffix is created.
  Otherwise one filter per PS is created with suffix "_<index>"; each is
  placed on `utils.ps_device(index)` unless a standalone export is in
  progress, in which case no device scope is applied.
  """
  logging.info(
      "Create hash fitlers, enable_hash_filter:{}.".format(enable_hash_filter))
  if ps_num == 0:
    return [
        _create_hash_filter(enable_hash_filter,
                            config,
                            "",
                            filter_capacity,
                            filter_split_num,
                            filter_equal_probability=filter_equal_probability,
                            filter_type=filter_type)
    ]

  hash_filters = []
  for i in range(ps_num):
    ps_device_name = utils.ps_device(i)
    with nullcontext() if is_exporting_standalone() else tf.device(
        ps_device_name):
      hash_filters.append(
          _create_hash_filter(
              enable_hash_filter,
              config,
              "_" + str(i),
              filter_capacity,
              filter_split_num,
              filter_equal_probability=filter_equal_probability,
              filter_type=filter_type))
  return hash_filters


def save_hash_filter(hash_filter: tf.Tensor,
                     hash_filter_basename: tf.Tensor,
                     enable_hash_filter: bool = False) -> tf.Operation:
  """Returns the op saving `hash_filter`, or a no-op unless enabled is True."""
  if enable_hash_filter is True:
    return filter_save_op.monolith_hash_filter_save(hash_filter,
                                                    hash_filter_basename)
  return tf.no_op()


def restore_hash_filter(hash_filter: tf.Tensor,
                        hash_filter_base_name: tf.Tensor,
                        enable_hash_filter: bool = False) -> tf.Operation:
  """Returns the op restoring `hash_filter`, or a no-op unless enabled is True."""
  if enable_hash_filter is True:
    return filter_restore_op.monolith_hash_filter_restore(
        hash_filter, hash_filter_base_name)
  return tf.no_op()


def intercept_gradient(filter_tensor: tf.Tensor, ids: tf.Tensor,
                       embeddings: tf.Tensor):
  """
  If id is supposed to be filtered, the gradient will be intercepted.
  Outputs the same embeddings.

  Args:
    filter_tensor - the filter handle passed to the op.
    ids - a 1-D int64 tensor.
    embeddings - a N-d embedding tensor whose first dimension corresponds
      to ids.
  """
  return filter_ops.MonolithHashFilterInterceptGradient(
      filter_handle=filter_tensor, ids=ids, embeddings=embeddings)


class HashFilterCheckpointSaverListener(tf.estimator.CheckpointSaverListener):
  """
  Saves the hash filters when saver is run.
  """

  def __init__(self,
               basename: str,
               hash_filters: List[tf.Tensor],
               enable_hash_filter: bool = False,
               enable_save_restore: bool = True):
    """
    |basename| should be a file name which is same as what is passed to saver.
    |hash_filters| hash filters to save in checkpoint.
    |enable_hash_filter| whether use real hash filters. If true, will save hash
    filters in checkpoint. If false, will skip save logic internally.
    |enable_save_restore| TODO(zouxuan) Whether to use save and restore on the
    hash filter. Hash filter is broken for save restore during sync training
    right now.
    """
    super().__init__()
    self._helper = save_utils.SaveHelper(basename)
    self._hash_filters = hash_filters
    self._enable_hash_filter = enable_hash_filter
    self._enable_save_restore = enable_save_restore
    # Maps id(hash_filter) to the string placeholder that receives its path.
    self._hash_filter_id_to_placeholder = {}
    self._save_op = self._build_save_graph()

  def _is_disabled(self) -> bool:
    """True when either switch was passed as exactly `False`."""
    return (self._enable_hash_filter is False or
            self._enable_save_restore is False)

  def before_save(self, sess, global_step_value):
    """
    We use before save so the checkpoint file is updated after we successfully
    save the hash filter.
    """
    if self._is_disabled():
      return
    asset_dir = self._helper.get_ckpt_asset_dir(
        self._helper.get_ckpt_prefix(global_step_value))
    tf.io.gfile.makedirs(asset_dir)
    # File names are built by plain concatenation, so `asset_dir` is used
    # exactly as returned by the helper.
    feed_dict = {
        self._hash_filter_id_to_placeholder[id(hash_filter)]:
            asset_dir + "hash_filter_{}".format(ps_idx)
        for ps_idx, hash_filter in enumerate(self._hash_filters)
    }
    sess.run(self._save_op,
             feed_dict=feed_dict,
             options=tf.compat.v1.RunOptions(timeout_in_ms=_TIMEOUT_IN_MS))
    logging.info("Finished saving hash filters.")

  def _build_save_graph(self) -> tf.Operation:
    """Builds one save op per filter, chained so they run one after another."""
    if self._is_disabled():
      return tf.no_op()
    last_op = tf.no_op()
    for hash_filter in self._hash_filters:
      hash_filter_basename = tf.compat.v1.placeholder(tf.string, shape=[])
      self._hash_filter_id_to_placeholder[id(
          hash_filter)] = hash_filter_basename
      # Each save waits for the previous one via the control dependency.
      with tf.control_dependencies([last_op]):
        last_op = save_hash_filter(hash_filter, hash_filter_basename, True)
    return last_op


class HashFilterCheckpointRestorerListener(
    basic_restore_hook.CheckpointRestorerListener):
  """Restores the hash filters from basename"""

  def __init__(self,
               basename: str,
               hash_filters: List[tf.Tensor],
               enable_hash_filter: bool = False,
               enable_save_restore: bool = True):
    """
    |basename| should be a file name which is same as what is passed to saver.
    |hash_filters| hash filters to restore from checkpoint.
    |enable_hash_filter| whether use real hash filters. If true, will restore
    hash filters from checkpoint. If false, will skip restore logic internally.
    |enable_save_restore| TODO(zouxuan) Whether to use save and restore on the
    hash filter. Hash filter is broken for save restore during sync training
    right now.
    """
    super().__init__()
    self._basename = basename
    self._helper = save_utils.SaveHelper(self._basename)
    self._hash_filters = hash_filters
    self._enable_hash_filter = enable_hash_filter
    self._enable_save_restore = enable_save_restore
    # Maps id(hash_filter) to the string placeholder that receives its path.
    self._hash_filter_id_to_placeholder = {}
    self._restore_op = self._build_restore_graph()

  def _is_disabled(self) -> bool:
    """True when either switch was passed as exactly `False`."""
    return (self._enable_hash_filter is False or
            self._enable_save_restore is False)

  def before_restore(self, session):
    """
    We use before restore so as to strictly control the order of restorer
    listeners.
    """
    ckpt_prefix = tf.train.latest_checkpoint(os.path.dirname(self._basename))
    if not ckpt_prefix:
      logging.info("No checkpoint found in %s. Skip the hash filters restore.",
                   self._basename)
      return
    logging.info("Restore hash filter from %s", ckpt_prefix)
    asset_dir = self._helper.get_ckpt_asset_dir(ckpt_prefix)
    if tf.io.gfile.exists(asset_dir):
      self._restore_from_path_prefix(session, asset_dir)
    else:
      # This is the legacy behavior and should be removed soon.
      self._restore_from_path_prefix(session, ckpt_prefix)

  def _restore_from_path_prefix(self, sess, path_prefix):
    """Runs the restore op, reading files named <path_prefix>hash_filter_<i>."""
    if self._is_disabled():
      return
    feed_dict = {
        self._hash_filter_id_to_placeholder[id(hash_filter)]:
            path_prefix + "hash_filter_{}".format(ps_idx)
        for ps_idx, hash_filter in enumerate(self._hash_filters)
    }
    sess.run(self._restore_op,
             feed_dict=feed_dict,
             options=tf.compat.v1.RunOptions(timeout_in_ms=_TIMEOUT_IN_MS))

  def _build_restore_graph(self) -> tf.Operation:
    """Builds one restore op per filter and groups them into a single op."""
    if self._is_disabled():
      return tf.no_op()
    restore_ops = []
    for hash_filter in self._hash_filters:
      hash_filter_basename = tf.compat.v1.placeholder(tf.string, shape=[])
      self._hash_filter_id_to_placeholder[id(
          hash_filter)] = hash_filter_basename
      restore_ops.append(
          restore_hash_filter(hash_filter, hash_filter_basename, True))
    return tf.group(restore_ops)
@ops.RegisterGradient("MonolithHashFilterInterceptGradient") def _intercept_gradient_gradient(op: tf.Operation, grad: tf.Tensor): filter_tensor = op.inputs[0] ids = op.inputs[1] filtered_grad = filter_ops.MonolithHashFilterInterceptGradientGradient( filter_handle=filter_tensor, ids=ids, grad=grad) return None, None, filtered_grad ================================================ FILE: monolith/native_training/hash_filter_ops_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os from typing import List import tensorflow as tf from tensorflow.python.lib.io import tf_record import monolith.native_training.hash_filter_ops as ops from monolith.native_training.runtime.hash_table import \ embedding_hash_table_pb2 def get_config_str(occurrence_threshold=0): config = embedding_hash_table_pb2.SlotOccurrenceThresholdConfig() config.default_occurrence_threshold = occurrence_threshold return config.SerializeToString() class HashFilterOpsTest(tf.test.TestCase): def _count_files(self, basename: str): return len(tf.io.gfile.glob(basename + "*")) def _GetHashFilterSplitMetaDump(self, ckpt_file: str): for record in tf_record.tf_record_iterator(ckpt_file): return embedding_hash_table_pb2.HashFilterSplitMetaDump.FromString(record) return None def test_hash_filter_basic(self): config = get_config_str(3) hash_filter = ops.create_hash_filter(100, 7, config) # we choose a key that is unique enough so they won't collide with each other. ids = tf.constant([1, 3 << 17, 1], dtype=tf.int64) embedding = tf.zeros([3, 2]) loss = ops.intercept_gradient(hash_filter, ids, embedding) grad = tf.gradients(loss, embedding)[0] with self.session() as sess: grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[0, 0], [0, 0], [0, 0]]) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[0, 0], [0, 0], [1, 1]]) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [0, 0], [1, 1]]) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1]]) def test_hash_filter_save_restore(self): config = get_config_str(3) hash_filter = ops.create_hash_filter(100, 7, config) # we choose a key that is unique enough so they won't collide with each other. 
ids = tf.constant([1, 3 << 17, 1], dtype=tf.int64) embedding = tf.zeros([3, 2]) loss = ops.intercept_gradient(hash_filter, ids, embedding) grad = tf.gradients(loss, embedding)[0] base_folder = os.path.join(os.environ["TEST_TMPDIR"], "test_hash_filter_save_restore") with self.session() as sess: # save checkpoint 0 ckpt_basename_0 = os.path.join(base_folder, "hash_filter_test_0") hash_filter_save_op = ops.save_hash_filter(hash_filter, ckpt_basename_0, True) sess.run(hash_filter_save_op) self.assertEqual(self._count_files(ckpt_basename_0), 7) # restore checkpoint 0 hash_filter_restore_op = ops.restore_hash_filter(hash_filter, ckpt_basename_0, True) sess.run(hash_filter_restore_op) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[0, 0], [0, 0], [0, 0]]) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[0, 0], [0, 0], [1, 1]]) # save checkpoint 1 ckpt_basename_1 = os.path.join(base_folder, "hash_filter_test_1") hash_filter_save_op = ops.save_hash_filter(hash_filter, ckpt_basename_1, True) sess.run(hash_filter_save_op) files = sorted(tf.io.gfile.glob(ckpt_basename_1 + "*")) self.assertEqual(self._count_files(ckpt_basename_1), 7) # restore checkpoint 1 hash_filter_restore_op = ops.restore_hash_filter(hash_filter, ckpt_basename_1, True) sess.run(hash_filter_restore_op) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [0, 0], [1, 1]]) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1]]) ckpt_basename = os.path.join(base_folder, "hash_filter_test") def test_hash_filter_save_restore_across_multiple_filters(self): config = get_config_str(2) # Each hash filter contains up to 2 elements. hash_filter = ops.create_hash_filter(300, 100, config) # we choose a key that is unique enough so they won't collide with each other. 
ids = tf.constant([1, 1 << 17, 2 << 17, 1], dtype=tf.int64) embedding = tf.zeros([4, 2]) loss = ops.intercept_gradient(hash_filter, ids, embedding) grad = tf.gradients(loss, embedding)[0] base_folder = os.path.join( os.environ["TEST_TMPDIR"], "test_hash_filter_save_restore_across_multiple_filters") with self.session() as sess: # save checkpoint 0 ckpt_basename_0 = os.path.join(base_folder, "hash_filter_test_0") hash_filter_save_op = ops.save_hash_filter(hash_filter, ckpt_basename_0, True) sess.run(hash_filter_save_op) # Verify checkpoint content ckpt_0_files = sorted(tf.io.gfile.glob(ckpt_basename_0 + "*")) self.assertEqual(len(ckpt_0_files), 100) for file in ckpt_0_files: dump = self._GetHashFilterSplitMetaDump(file) self.assertEqual(dump.total_size, 3) self.assertEqual(dump.num_elements, 0) self.assertEqual(dump.sliding_hash_filter_meta.split_num, 100) self.assertEqual(dump.sliding_hash_filter_meta.head, 0) self.assertEqual(dump.sliding_hash_filter_meta.head_increment, 0) # restore checkpoint 0 hash_filter_restore_op = ops.restore_hash_filter(hash_filter, ckpt_basename_0, True) sess.run(hash_filter_restore_op) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[0, 0], [0, 0], [0, 0], [0, 0]]) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [0, 0], [0, 0], [1, 1]]) # save checkpoint 1 ckpt_basename_1 = os.path.join(base_folder, "hash_filter_test_1") hash_filter_save_op = ops.save_hash_filter(hash_filter, ckpt_basename_1, True) sess.run(hash_filter_save_op) # verify checkpoint 1 ckpt_1_files = sorted(tf.io.gfile.glob(ckpt_basename_1 + "*")) self.assertEqual(len(ckpt_1_files), 100) for file in ckpt_1_files[:4]: dump = self._GetHashFilterSplitMetaDump(file) self.assertEqual(dump.total_size, 3) self.assertEqual(dump.num_elements, 2) self.assertEqual(dump.sliding_hash_filter_meta.split_num, 100) self.assertEqual(dump.sliding_hash_filter_meta.head, 4) self.assertEqual(dump.sliding_hash_filter_meta.head_increment, 4) for file in 
ckpt_1_files[4:]: dump = self._GetHashFilterSplitMetaDump(file) self.assertEqual(dump.total_size, 3) self.assertEqual(dump.num_elements, 0) self.assertEqual(dump.sliding_hash_filter_meta.split_num, 100) self.assertEqual(dump.sliding_hash_filter_meta.head, 4) self.assertEqual(dump.sliding_hash_filter_meta.head_increment, 4) # restore checkpoint 1 hash_filter_restore_op = ops.restore_hash_filter(hash_filter, ckpt_basename_1, True) sess.run(hash_filter_restore_op) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1], [1, 1]]) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1], [1, 1]]) def test_dummy_hash_filter_basic(self): hash_filter = ops.create_dummy_hash_filter() # we choose a key that is unique enough so they won't collide with each other. ids = tf.constant([1, 3 << 17, 1], dtype=tf.int64) embedding = tf.zeros([3, 2]) loss = ops.intercept_gradient(hash_filter, ids, embedding) grad = tf.gradients(loss, embedding)[0] with self.session() as sess: grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1]]) def test_dummy_hash_filter_save_restore(self): basename = "dummy_hash_filter" hash_filter = ops.create_dummy_hash_filter() # we choose a key that is unique enough so they won't collide with each other. 
ids = tf.constant([1, 3 << 17, 1], dtype=tf.int64) embedding = tf.zeros([3, 2]) loss = ops.intercept_gradient(hash_filter, ids, embedding) grad = tf.gradients(loss, embedding)[0] with self.session() as sess: grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1]]) hash_filter_save_op = ops.save_hash_filter(hash_filter, basename, False) sess.run(hash_filter_save_op) self.assertEqual(self._count_files(basename), 0) grad_value = sess.run(grad) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1]]) hash_filter_restore_op = ops.restore_hash_filter(hash_filter, basename, False) hash_filter_restore_op = ops.restore_hash_filter(hash_filter, basename, False) self.assertAllEqual(grad_value, [[1, 1], [1, 1], [1, 1]]) self.assertEqual(self._count_files(basename), 0) def test_restore_not_found(self): with self.session() as sess: non_existent_files = os.path.join(os.environ["TEST_TMPDIR"], "test_restore_not_found", "hash_filters") config = get_config_str(2) hash_filter = ops.create_hash_filter(300, 7, config) restore_op = ops.restore_hash_filter(hash_filter, non_existent_files, True) with self.assertRaises(Exception): sess.run(restore_op) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/hash_table_ops.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

syntax="proto2";

package monolith;

// Serialized description of a hash table.
// NOTE(review): the field comments below are inferred from the field names
// only; confirm against the code that reads and writes this message.
message HashTableProto {
  // Tri-state boolean: unset, false or true. hash_table_ops.py maps Python
  // None/False/True to kBoolNone/kFalse/kTrue through _BOOL_MAP.
  enum OptionalBool {
    kBoolNone = -1;
    kFalse = 0;
    kTrue = 1;
  };

  // Presumably the name of the tensor holding the table handle.
  optional string table_tensor = 1;
  // Presumably the embedding dimension of the table.
  optional int32 dim_size = 2;
  // Presumably the shared name of the underlying table resource.
  optional string shared_name = 3;
  // Opaque bytes; presumably a serialized slot expire time config.
  optional bytes slot_expire_time_config = 4;
  // Presumably the name of the learning rate tensor.
  optional string learning_rate_tensor = 5;
  optional int32 saver_parallel = 6;
  repeated string extra_restore_names = 7;
  optional OptionalBool export_share_embedding = 8;
}

================================================
FILE: monolith/native_training/hash_table_ops.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class BaseHashTable(abc.ABC):
  """Abstract interface shared by the hash table implementations.

  Write operations do not mutate the receiver; they return a new hash table
  object. That makes it easy to chain operations which must observe the
  updated table, and lets the caller balance parallelism against data
  freshness.
  """

  @abc.abstractmethod
  def assign(self,
             ids: tf.Tensor,
             values: tf.Tensor,
             req_time: tf.Tensor = None) -> "BaseHashTable":
    """Assigns |values| to the entries named by |ids|.

    ids - a 1D tensor with the ids of the entries to assign.
    values - a 2D tensor. Its first dim should equal the length of ids and
      its second dim should equal the hash table's dim size.
    req_time - optional tensor; its meaning is left to the implementation.

    Returns the updated hash table.
    """

  @abc.abstractmethod
  def assign_add(self,
                 ids: tf.Tensor,
                 values: tf.Tensor,
                 req_time: tf.Tensor = None) -> "BaseHashTable":
    """Adds |values| onto the entries named by |ids|.

    ids - a 1D tensor with the ids of the entries that receive the addition.
    values - a 2D tensor. Its first dim should equal the length of ids and
      its second dim should equal the hash table's dim size.
    req_time - optional tensor; its meaning is left to the implementation.

    Returns the updated hash table.
    """

  @abc.abstractmethod
  def lookup(self,
             ids: tf.Tensor,
             use_multi_threads=False,
             enable_dedup=False) -> tf.Tensor:
    """Looks up the embeddings stored for |ids|.

    The embedding will be summed up in the same batch.

    ids - a 1D int64 tensor.
    use_multi_threads - True if the caller wants a multi-threaded lookup.
    enable_dedup - True if the caller wants to look up without duplicate ids.

    Returns a 2-D tensor which maps id to embeddings.
    """

  @property
  @abc.abstractmethod
  def dim_size(self):
    """Dim size of the hash table."""

  @abc.abstractmethod
  def apply_gradients(self,
                      ids: tf.Tensor,
                      grads: tf.Tensor,
                      global_step: tf.Tensor,
                      use_multi_threads=False,
                      enable_dedup=False,
                      req_time: tf.Tensor = None) -> "BaseHashTable":
    """Applies the gradients |grads| with respect to |ids|."""

  @abc.abstractmethod
  def as_op(self) -> Union[tf.Tensor, tf.Operation]:
    """Returns the hash table as an op or tensor.

    Useful when the table has to take part in dependency control.
    """
def _init_from_proto(self,
                     proto: hash_table_ops_pb2.HashTableProto = None,
                     import_scope: str = None):
  """Initializes this object from a HashTableProto.

  Graph elements referenced by name in |proto| are resolved in the current
  default graph, after prefixing them with |import_scope|.

  Unlike the regular constructor path, this method neither reserves the
  shared name via _check_and_insert_name nor appends the object to the hash
  table graph collection.

  Args:
    proto - the HashTableProto to read the fields from.
    import_scope - optional name scope to prepend to the tensor names.
  """
  g = tf.compat.v1.get_default_graph()
  self._table = g.as_graph_element(
      ops.prepend_name_scope(proto.table_tensor, import_scope))
  self._dim_size = proto.dim_size
  self._init_table_name = proto.shared_name
  self._slot_expire_time_config = proto.slot_expire_time_config
  self._learning_rate_tensor = g.as_graph_element(
      ops.prepend_name_scope(proto.learning_rate_tensor, import_scope))
  self._saver_parallel = proto.saver_parallel
  # Keep this a list, exactly like __init__ does. Code in this module
  # concatenates it to a list ([table.name] + table._extra_restore_names),
  # which raises TypeError when the attribute is a tuple.
  self._extra_restore_names = list(proto.extra_restore_names)
  self.export_share_embedding = _BOOL_REVERSE_MAP[
      proto.export_share_embedding]
def lookup(self,
           ids: tf.Tensor,
           use_multi_threads=False,
           enable_dedup=False) -> tf.Tensor:
  """Looks up the embeddings of |ids| in this table.

  ids - tensor with the ids to look up.
  use_multi_threads - forwarded to the lookup op.
  enable_dedup - accepted to match BaseHashTable.lookup; it is not passed on
    to the lookup op.

  Returns the tensor produced by the lookup op.
  """
  return hash_table_ops.monolith_hash_table_lookup(
      self._table,
      ids,
      self._dim_size,
      use_multi_threads=use_multi_threads)
def apply_gradients(self,
                    ids: tf.Tensor,
                    grads: tf.Tensor,
                    global_step: tf.Tensor,
                    use_multi_threads=False,
                    enable_dedup=False,
                    req_time: tf.Tensor = None) -> "HashTable":
  """Applies |grads| for |ids| using the table's learning rate tensor.

  ids - ids whose entries are optimized.
  grads - gradients matching |ids|.
  global_step - tensor forwarded to the optimize op.
  use_multi_threads / enable_dedup - forwarded to the optimize op.
  req_time - optional tensor; an int64 zero constant is used when None.

  Returns a copy of this HashTable whose table tensor is only available
  after the optimize op has run.
  """
  timestamp = req_time
  if timestamp is None:
    timestamp = tf.constant(0, dtype=tf.int64)
  optimize_op = hash_table_ops.monolith_hash_table_optimize(
      self._table,
      ids,
      grads,
      self._learning_rate_tensor,
      timestamp,
      global_step,
      use_multi_threads=use_multi_threads,
      enable_dedup=enable_dedup)
  # The identity created under the control dependency is what ties users of
  # the returned table to the optimize op.
  with tf.control_dependencies([optimize_op]):
    return self._copy_with_new_table(tf.identity(self._table))
def fused_lookup(tables: tf.Tensor,
                 ids: tf.Tensor,
                 fused_slot_size: tf.Tensor,
                 num_of_shards: int,
                 req_time: tf.Tensor = None) -> Tuple[tf.Tensor]:
  """Fused lookup over a list of tables.

  The op takes fused ids together with the fused slot sizes, performs the
  lookup in every table and returns one concatenated embedding tensor. A few
  auxiliary results are returned as well to simplify later processing.

  Example:
    tables = [{1: [1], 2: [2]}, {3: [3, 3], 4: [4, 4]}]
    ids = [1, 3, 2, 4]
    fused_slot_size = [1, 1, 1, 1]
    num_of_shards = 2
  After the op, the outputs are:
    embeddings = [1, 3, 3, 2, 4, 4]
    recv_splits = [3, 3]
    id_offsets = [0, 1, 2, 3]
    emb_offsets = [0, 1, 3, 4]

  For a setup of K tables, N shards:

  Args:
    tables: A list of tables with shape [K], ordered by the tables'
      hashed_keys.
    ids: Flattened ids with shape [M], M=sum(fused_slot_size[i]).
    fused_slot_size: A list with shape [K*N].
    num_of_shards: an integer N.
    req_time: optional tensor; an int64 zero constant is used when None.

  Returns:
    embeddings: A 1-D flattened embeddings tensor with shape [L],
      L=sum(embedding_sizes[i]).
    recv_splits: A 1-D flattened tensor with shape [N].
    id_offsets: A 1-D flattened tensor with shape [K*N]; an artifact used by
      apply_gradients.
    emb_offsets: A 1-D flattened tensor with shape [K*N]; an artifact used by
      apply_gradients.
  """
  timestamp = tf.constant(0, tf.int64) if req_time is None else req_time
  return hash_table_ops.monolith_hash_table_fused_lookup(
      tables, ids, fused_slot_size, timestamp, num_of_shards)
def hash_table_from_config(config: entry.HashTableConfigInstance,
                           hash_filter: tf.Tensor = None,
                           name_suffix="",
                           sync_client: tf.Tensor = None,
                           saver_parallel: int = -1) -> HashTable:
  """Creates the hash table op described by |config| and wraps it.

  Args:
    config - table config plus learning rate functions; there must be exactly
      one learning rate function per entry segment.
    hash_filter - filter handle passed to the table op. A dummy hash filter is
      created when None.
    name_suffix - appended to "MonolithHashTable_" to form both the op name
      and the shared name of the table.
    sync_client - sync client handle passed to the table op. A dummy sync
      client is created when None, and always for the "gpucuco" table type.
    saver_parallel - stored on the returned HashTable.

  Returns:
    A HashTable wrapping the newly created table op.

  Raises:
    ValueError - if the number of learning rate functions differs from the
      number of segments in the entry config.
  """
  # Work on a private copy so that the caller's config proto is not modified.
  table_config = embedding_hash_table_pb2.EmbeddingHashTableConfig()
  table_config.CopyFrom(config.table_config)
  assert table_config.HasField("type")
  table_type = table_config.WhichOneof("type")
  logging.info("Hash table type: {}".format(table_type))
  # Only the "gpucuco" table type is placed on GPU; everything else on CPU.
  use_gpu = table_type == "gpucuco"
  d = "/device:GPU:0" if use_gpu else "/device:CPU:0"
  if is_exporting():
    # While exporting, the copied config is switched to the SERVING entry type.
    table_config.entry_config.entry_type = embedding_hash_table_pb2.EntryConfig.EntryType.SERVING
  dim_size = infer_dim_size(config.table_config)
  table_config_str = table_config.SerializeToString()
  slot_expire_time_config = config.table_config.slot_expire_time_config.SerializeToString(
  )
  hash_table_name = "MonolithHashTable_" + name_suffix
  if len(config.learning_rate_fns) != len(
      config.table_config.entry_config.segments):
    raise ValueError(
        "Size of learning_rate_fns and size of segments must be equal.")
  if hash_filter is None:
    with tf.device(d):
      hash_filter = hash_filter_ops.create_dummy_hash_filter(
          name_suffix=name_suffix)
  if sync_client is None or use_gpu:
    # We don't have gpu sync for now, get rid of or use_gpu if added one
    with tf.device(d):
      sync_client = distributed_serving_ops.create_dummy_sync_client()
  with tf.device(
      d
  ):  # Merged Device is essential here to avoid affecting job task placement
    table_op = hash_table_ops.monolith_hash_table(
        name=hash_table_name,
        filter_handle=hash_filter,
        sync_client_handle=sync_client,
        config=table_config_str,
        shared_name=hash_table_name)
  return HashTable(table_op,
                   shared_name=hash_table_name,
                   dim_size=dim_size,
                   slot_expire_time_config=slot_expire_time_config,
                   learning_rate_tensor=config.call_learning_rate_fns(),
                   saver_parallel=saver_parallel,
                   extra_restore_names=config.extra_restore_names)
table which essentially is a [vocab_size, dim_size] float table with sgd optimizer. """ # Here we use a hash table which is more powerful than vocab table. return test_hash_table(dim_size, enable_hash_filter, learning_rate=learning_rate) def _all_table_tensor_prefix(table: HashTable) -> List[str]: all_names = [table.name] + table._extra_restore_names return [name.replace(":", "-").replace("/", "-") for name in all_names] def _table_tensor_prefix(table: HashTable) -> str: return _all_table_tensor_prefix(table)[0] class HashTableCheckpointSaverListener(tf.estimator.CheckpointSaverListener): """Saves the hash tables when saver is run.""" def __init__(self, basename: str): """|basename| should be a file name which is same as what is passed to saver.""" super().__init__() self._helper = save_utils.SaveHelper(basename) self._table_id_to_placeholder = {} self._save_op = self._build_save_graph() def before_save(self, sess, global_step_value): """ We use before save so the checkpoint file is updated after we successfully save the hash table. """ logging.info("Starting saving hash tables.") feed_dict = {} base_dir = self._helper.get_ckpt_asset_dir( self._helper.get_ckpt_prefix(global_step_value)) tf.io.gfile.makedirs(base_dir) for table in ops.get_collection(_HASH_TABLE_GRAPH_KEY): table_basename = base_dir + _table_tensor_prefix(table) feed_dict.update( {self._table_id_to_placeholder[table.name]: table_basename}) sess.run(self._save_op, feed_dict=feed_dict, options=tf.compat.v1.RunOptions(timeout_in_ms=_TIMEOUT_IN_MS)) logging.info("Finished saving hash tables.") def _build_save_graph(self) -> tf.Operation: save_tensors = [] # This reduces disk metadata modification pressure. 
def _restore_from_path_prefix(self, sess, path_prefix):
  """Restores every hash table in the graph collection from |path_prefix|.

  For each table the first candidate prefix (current name first, then the
  extra restore names) that matches at least one file under |path_prefix| is
  used as the checkpoint basename.

  Args:
    sess - session used to run the restore ops.
    path_prefix - directory prefix (including the trailing separator) that
      the per-table file prefixes are appended to.

  Raises:
    ValueError - if no candidate prefix of some table matches any file.
  """

  def get_restore_prefix(prefixes: List[str]):
    for prefix in prefixes:
      if tf.io.gfile.glob(path_prefix + prefix + "*"):
        return prefix
    # Format the message explicitly: ValueError does not apply %-formatting
    # to its arguments, so passing them separately yields an unreadable
    # tuple instead of the intended text.
    raise ValueError(
        ("Unable to find table checkpoint in '%s' for table: %s. "
         "Maybe the model structure has been changed.") %
        (path_prefix, repr(prefixes)))

  tables = tf.compat.v1.get_collection(_HASH_TABLE_GRAPH_KEY)
  # The globs are I/O bound, so they are issued concurrently; the futures are
  # resolved inside the executor context, which re-raises lookup failures.
  with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
    table_to_prefix = {
        table.name:
            executor.submit(get_restore_prefix,
                            _all_table_tensor_prefix(table))
        for table in tables
    }
    for table in tables:
      table_to_prefix[table.name] = table_to_prefix[table.name].result()

  feed_dict = {}
  for table in tables:
    table_basename = path_prefix + table_to_prefix[table.name]
    feed_dict.update(
        {self._table_id_to_placeholder[table.name]: table_basename})

  # Without a ps monitor every device is restored; with one, only the devices
  # it reports as uninitialized.
  restore_ops_all = []
  for device, restore_ops in self._restore_ops_per_device.items():
    if not self._ps_monitor or self._ps_monitor.is_ps_uninitialized(
        sess, device):
      restore_ops_all.extend(restore_ops)
  sess.run(restore_ops_all,
           feed_dict=feed_dict,
           options=tf.compat.v1.RunOptions(timeout_in_ms=_TIMEOUT_IN_MS))
def test_lookup(self):
  """Measures the wall time of a single-threaded lookup of 10000 ids."""
  with tf.compat.v1.Session() as sess:
    # Do not call this `len`: that would shadow the builtin in this scope.
    num_ids, dim_size = (10000, 32)
    id_tensor = _get_id_tensor(list(range(num_ids)))
    hash_table = ops.test_hash_table(dim_size)
    # The first dim of the values must equal the number of ids passed in.
    hash_table = hash_table.assign_add(id_tensor[:-5],
                                       tf.ones([num_ids - 5, dim_size]))
    hash_table = hash_table.assign_add(id_tensor[-5:],
                                       tf.zeros([5, dim_size]))
    iters = 100
    # The lookup depends on the assign_add ops, so every run adds one more to
    # the first ids; after `iters` runs they hold `iters`.
    embedding_one = [float(1) * iters for _ in range(dim_size)]
    embedding_zero = [0 for _ in range(dim_size)]
    start = time.time()
    _embeddings = hash_table.lookup(id_tensor)
    for _ in range(iters):
      embeddings = sess.run(_embeddings)
    total_wall_time = time.time() - start
    print('wall time: {}'.format(total_wall_time / iters))
    self.assertAllClose(embeddings[:-5],
                        [embedding_one for _ in range(num_ids - 5)])
    self.assertAllClose(embeddings[-5:], [embedding_zero for _ in range(5)])
def test_basic_optimize(self):
  """Measures the wall time of lookup + apply_gradients + lookup."""
  with tf.compat.v1.Session() as sess:
    # Do not call this `len`: that would shadow the builtin in this scope.
    num_ids, dim_size = (1000000, 32)
    # We assume each ID is appeared 4 times.
    id_tensor = _get_id_tensor([x // 4 for x in range(num_ids)])
    hash_table = ops.test_hash_table(dim_size,
                                     learning_rate=0.001,
                                     use_adagrad=True)
    # The first dim of the values must equal the number of ids passed in.
    hash_table = hash_table.assign_add(id_tensor[:-5],
                                       tf.ones([num_ids - 5, dim_size]))
    hash_table = hash_table.assign_add(id_tensor[-5:],
                                       tf.zeros([5, dim_size]))
    start = time.time()
    embeddings = hash_table.lookup(id_tensor)
    loss = -embeddings
    grads = tf.gradients(loss, embeddings)
    # HashTable.apply_gradients takes (ids, grads, global_step); passing a
    # zip of (grad, var) pairs as in tf optimizers raises TypeError.
    global_step = _get_id_tensor(0)
    hash_table = hash_table.apply_gradients(id_tensor,
                                            grads[0],
                                            global_step=global_step)
    embeddings = hash_table.lookup(id_tensor)
    embeddings = sess.run(embeddings)
    total_wall_time = time.time() - start
    print('wall time: {}'.format(total_wall_time))
def test_multi_threads_optimize_with_dedup(self):
  """Measures multi-threaded apply_gradients with id dedup enabled."""
  with tf.compat.v1.Session() as sess:
    # Do not call this `len`: that would shadow the builtin in this scope.
    num_ids, dim_size = (1000000, 32)
    # We assume each ID is appeared 4 times.
    id_tensor = _get_id_tensor([x // 4 for x in range(num_ids)])
    hash_table = ops.test_hash_table(dim_size,
                                     learning_rate=0.001,
                                     use_adagrad=True)
    # The first dim of the values must equal the number of ids passed in.
    hash_table = hash_table.assign_add(id_tensor[:-5],
                                       tf.ones([num_ids - 5, dim_size]))
    hash_table = hash_table.assign_add(id_tensor[-5:],
                                       tf.zeros([5, dim_size]))
    start = time.time()
    embeddings = hash_table.lookup(id_tensor)
    loss = -embeddings
    grads = tf.gradients(loss, embeddings)
    # HashTable.apply_gradients takes (ids, grads, global_step); passing a
    # zip of (grad, var) pairs as in tf optimizers raises TypeError.
    global_step = _get_id_tensor(0)
    hash_table = hash_table.apply_gradients(id_tensor,
                                            grads[0],
                                            global_step=global_step,
                                            use_multi_threads=True,
                                            enable_dedup=True)
    embeddings = hash_table.lookup(id_tensor)
    embeddings = sess.run(embeddings)
    total_wall_time = time.time() - start
    print('wall time: {}'.format(total_wall_time))
def test_hash_table_with_hash_filters(dim_size,
                                      hash_filters,
                                      name_suffix="0",
                                      learning_rate=1.0) -> ops.HashTable:
  """Builds a |dim_size| float cuckoo table with an sgd optimizer.

  The table has a single zero-initialized segment and uses the first entry of
  |hash_filters| as its hash filter.
  """
  cfg = embedding_hash_table_pb2.EmbeddingHashTableConfig()
  cfg.cuckoo.SetInParent()
  seg = cfg.entry_config.segments.add()
  seg.dim_size = dim_size
  seg.opt_config.sgd.SetInParent()
  seg.init_config.zeros.SetInParent()
  return ops.hash_table_from_config(
      config=entry.HashTableConfigInstance(cfg, [learning_rate]),
      hash_filter=hash_filters[0],
      name_suffix=name_suffix)
def test_save_as_tensor(self):
  """Every string dumped by save_as_tensor parses as an EntryDump."""
  table = test_hash_table(1)
  ids = _get_id_tensor([0, 1, 2])
  assigned = table.assign(ids, [[0.1], [0.2], [0.3]])
  self.evaluate(assigned.as_op())
  # Shard 0 of 1, at most 1000 entries, starting from offset 0.
  _, dump_tensor = table.save_as_tensor(0, 1, 1000, 0)
  serialized_dumps = self.evaluate(dump_tensor)
  for serialized in serialized_dumps:
    # OK to parse
    parsed = embedding_hash_table_pb2.EntryDump()
    parsed.ParseFromString(serialized)
global_step=global_step) new_embeddings = hash_table.lookup(_get_id_tensor([0, 1])) new_embeddings = sess.run(new_embeddings) self.assertAllClose(new_embeddings, [[0.2], [0.1]]) def test_gradients_with_learning_rate_decay(self): with tf.compat.v1.Session() as sess: global_step = tf.compat.v1.train.get_or_create_global_step() self.evaluate(tf.compat.v1.global_variables_initializer()) self.evaluate(tf.compat.v1.assign_add(global_step, 1)) hash_table = test_hash_table( 1, learning_rate=learning_rate_functions.PolynomialDecay( initial_learning_rate=0.01, decay_steps=10, end_learning_rate=0.11)) id_tensor = _get_id_tensor([0, 0, 1]) embeddings = hash_table.lookup(id_tensor) loss = -embeddings grads = tf.gradients(loss, embeddings) hash_table = hash_table.apply_gradients(id_tensor, grads[0], global_step=global_step) new_embeddings = hash_table.lookup(_get_id_tensor([0, 1])) new_embeddings = sess.run(new_embeddings) self.assertAllClose(new_embeddings, [[0.04], [0.02]]) def test_gradients_with_dedup(self): vec_dim = 10 with tf.compat.v1.Session() as sess: hash_table = test_hash_table(vec_dim, learning_rate=0.1) id_tensor = _get_id_tensor([0, 1, 0, 1, 0]) embeddings = hash_table.lookup(id_tensor) loss = -embeddings grads = tf.gradients(loss, embeddings) global_step = _get_id_tensor(0) hash_table = hash_table.apply_gradients(id_tensor, grads[0], global_step=global_step, enable_dedup=True) new_embeddings = hash_table.lookup(_get_id_tensor([0, 1])) new_embeddings = sess.run(new_embeddings) expected_output = [[0.3 for _ in range(vec_dim)], [0.2 for _ in range(vec_dim)]] self.assertAllClose(new_embeddings, expected_output) def test_gradients_with_different_ids(self): with tf.compat.v1.Session() as sess: hash_table = test_hash_table(1, learning_rate=0.1) embeddings = hash_table.lookup(_get_id_tensor([0, 0, 1])) loss = -embeddings grads = tf.gradients(loss, embeddings) global_step = _get_id_tensor(0) hash_table = hash_table.apply_gradients(_get_id_tensor([1, 0, 1]), grads[0], 
global_step=global_step) new_embeddings = hash_table.lookup(_get_id_tensor([0, 1])) new_embeddings = sess.run(new_embeddings) self.assertAllClose(new_embeddings, [[0.1], [0.2]]) def test_gradients_with_hash_filter(self): with tf.compat.v1.Session() as sess: hash_table = test_hash_table(1, enable_hash_filter=True, learning_rate=0.1, occurrence_threshold=3) id_tensor = _get_id_tensor([0, 0, 1]) embeddings = hash_table.lookup(id_tensor) loss = -embeddings grads = tf.gradients(loss, embeddings) global_step = _get_id_tensor(0) hash_table = hash_table.apply_gradients(id_tensor, grads[0], global_step=global_step) expected_results = [ # occurrence_threshold=3 # id 0, first apply gradient changes count=1, first apply gradient changes count=2 # both <=3, no real update. # id 1, first apply gradient changes count=1 <= 3, no real update [[0], [0]], # id 0, first apply gradient changes count=3, first apply gradient changes count=4 # first update <= 3, second update > 3, update once # id 1, first apply gradient changes count=2 <= 3, no real update [[0.1], [0]], # id 0, first apply gradient changes count=5, first apply gradient changes count=6 # both update count > 3, update twice # id 1, first apply gradient changes count=3 <= 3, no real update [[0.3], [0.0]], # id 0, first apply gradient changes count=7, first apply gradient changes count=8 # both update count > 3, update twice # id 1, first apply gradient changes count=4 > 3, update once [[0.5], [0.1]] ] for i in range(0, 4): new_embeddings = hash_table.lookup(_get_id_tensor([0, 1])) new_embeddings = sess.run(new_embeddings) self.assertAllClose(new_embeddings, expected_results[i]) def test_save_restore(self): with self.session() as sess: hash_table = test_hash_table(1) hash_table = hash_table.assign_add( _get_id_tensor([-1, 1]), tf.constant([[1], [2]], dtype=tf.float32)) base_name = os.path.join(os.environ["TEST_TMPDIR"], "test_save_restore", "table") hash_table = hash_table.save(base_name) sess.run(hash_table.as_op()) with 
self.session() as sess: hash_table2 = test_hash_table(1, False) hash_table2 = hash_table2.restore(base_name) embedding = hash_table2.lookup(_get_id_tensor([-1, 1])) embedding = sess.run(embedding) self.assertAllEqual(embedding, [[1], [2]]) def test_restore_from_another_table(self): with self.session() as sess: hash_table1 = test_hash_table(1) hash_table1 = hash_table1.assign(_get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32)) base_name = os.path.join(os.environ["TEST_TMPDIR"], "test_restore_from_another_table", "table") hash_table1 = hash_table1.save(base_name) sess.run(hash_table1.as_op()) hash_table2 = test_hash_table(1, extra_restore_names=[hash_table1.name]) hash_table2 = hash_table2.restore(base_name) embedding = hash_table2.lookup(_get_id_tensor([1])) embedding = sess.run(embedding) self.assertAllEqual(embedding, [[1]]) def test_save_restore_with_feature_eviction_assign_add(self): with self.session() as sess: # Default feature eviction time is expire_time. # Feature with ts older than expire_time will be evicted. expire_time = 1 hash_table = test_hash_table(dim_size=1, expire_time=expire_time) max_ts = 10000000 expire_time_in_sec = expire_time * 24 * 3600 evict_ts = max_ts - expire_time_in_sec - 1 hash_table = hash_table.assign_add(_get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32), tf.constant(evict_ts, dtype=tf.int64)) # Feature with keep_ts which is newer than expire_time. # It will not be evicted after save. keep_ts = max_ts - expire_time_in_sec + 1 hash_table = hash_table.assign_add(_get_id_tensor([2]), tf.constant([[2]], dtype=tf.float32), tf.constant(keep_ts, dtype=tf.int64)) # Feature with max_ts date will be kept and also it will update the internal max_req_time. 
hash_table = hash_table.assign_add(_get_id_tensor([3]), tf.constant([[3]], dtype=tf.float32), tf.constant(max_ts, dtype=tf.int64)) base_name = os.path.join( os.environ["TEST_TMPDIR"], "test_save_restore_with_feature_eviction_assign_add", "table") hash_table = hash_table.save(base_name) sess.run(hash_table.as_op()) with self.session() as sess: hash_table2 = test_hash_table(1, False) hash_table2 = hash_table2.restore(base_name) embedding = hash_table2.lookup(_get_id_tensor([1, 2, 3])) embedding = sess.run(embedding) self.assertAllEqual(embedding, [[0], [2], [3]]) def test_save_restore_with_feature_eviction_apply_gradients(self): with self.session() as sess: # Default feature eviction time is expire_time. # Feature with evic_ts older than expire_time will be evicted. expire_time = 1 hash_table = test_hash_table(dim_size=1, expire_time=expire_time) max_ts = 10000000 expire_time_in_sec = expire_time * 24 * 3600 evict_ts = max_ts - expire_time_in_sec - 1 global_step = _get_id_tensor(0) hash_table = hash_table.apply_gradients(_get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32), global_step, req_time=tf.constant( evict_ts, dtype=tf.int64)) # Feature with keep_ts which is newer than expire_time. # It will not be evicted after save. keep_ts = max_ts - expire_time_in_sec + 1 global_step = _get_id_tensor(0) hash_table = hash_table.apply_gradients(_get_id_tensor([2]), tf.constant([[2]], dtype=tf.float32), global_step, req_time=tf.constant( keep_ts, dtype=tf.int64)) # Feature with max_ts will be kept and also it will update the internal max_req_time. 
global_step = _get_id_tensor(0) hash_table = hash_table.apply_gradients(_get_id_tensor([3]), tf.constant([[3]], dtype=tf.float32), global_step, req_time=tf.constant( max_ts, dtype=tf.int64)) base_name = os.path.join( os.environ["TEST_TMPDIR"], "test_save_restore_with_feature_eviction_apply_gradients", "table") hash_table = hash_table.save(base_name) sess.run(hash_table.as_op()) with self.session() as sess: hash_table2 = test_hash_table(1, False) hash_table2 = hash_table2.restore(base_name) embedding = hash_table2.lookup(_get_id_tensor([1, 2, 3])) embedding = sess.run(embedding) self.assertAllEqual(embedding, [[0], [-2], [-3]]) def test_entry_ttl_zero(self): basename = os.path.join(os.environ["TEST_TMPDIR"], "test_entry_ttl", "table") with self.session() as sess: hash_table = test_hash_table(1, expire_time=0) hash_table = hash_table.assign_add( _get_id_tensor([-1, 1]), tf.constant([[1], [2]], dtype=tf.float32)) hash_table = hash_table.save(basename) sess.run(hash_table.as_op()) with self.session() as sess: hash_table2 = test_hash_table(1) hash_table2 = hash_table2.restore(basename) embedding = hash_table2.lookup(_get_id_tensor([-1, 1])) embedding = sess.run(embedding) self.assertAllEqual(embedding, [[0], [0]]) def test_entry_ttl_not_zero(self): basename = os.path.join(os.environ["TEST_TMPDIR"], "test_entry_ttl_not_zero", "table") with self.session() as sess: hash_table = test_hash_table(1, expire_time=60 * 60) hash_table = hash_table.assign_add( _get_id_tensor([-1, 1]), tf.constant([[1], [2]], dtype=tf.float32)) hash_table = hash_table.save(basename) sess.run(hash_table.as_op()) with self.session() as sess: hash_table2 = test_hash_table(1) hash_table2 = hash_table2.restore(basename) embedding = hash_table2.lookup(_get_id_tensor([-1, 1])) embedding = sess.run(embedding) self.assertAllEqual(embedding, [[1], [2]]) def test_entry_ttl_by_slots(self): basename = os.path.join(os.environ["TEST_TMPDIR"], "test_entry_ttl_by_slots", "table") table_config = 
embedding_hash_table_pb2.EmbeddingHashTableConfig() table_config.cuckoo.SetInParent() segment = table_config.entry_config.segments.add() segment.dim_size = 1 segment.opt_config.sgd.SetInParent() segment.init_config.zeros.SetInParent() table_config.slot_expire_time_config.default_expire_time = 60 * 60 slot_expire_time_1 = table_config.slot_expire_time_config.slot_expire_times.add( ) slot_expire_time_1.slot = 1 slot_expire_time_1.expire_time = 0 slot_expire_time_2 = table_config.slot_expire_time_config.slot_expire_times.add( ) slot_expire_time_2.slot = 2 slot_expire_time_2.expire_time = 1 hash_filters = hash_filter_ops.create_hash_filters(0, False) config = entry.HashTableConfigInstance(table_config, [1.0]) with self.session() as sess: id_1 = (1 << 48) id_2 = (2 << 48) name_suffix = tf.compat.v1.get_default_graph().unique_name("") hash_table = ops.hash_table_from_config(config, hash_filter=hash_filters[0], name_suffix=name_suffix) hash_table = hash_table.assign_add( _get_id_tensor([id_1, id_2]), tf.constant([[1], [2]], dtype=tf.float32), tf.constant(100, dtype=tf.int64)) hash_table = hash_table.save(basename) sess.run(hash_table.as_op()) basename_new = os.path.join(os.environ["TEST_TMPDIR"], "test_entry_ttl_by_slots", "table_new") with self.session() as sess: name_suffix = tf.compat.v1.get_default_graph().unique_name("") hash_table2 = ops.hash_table_from_config(config, hash_filter=hash_filters[0], name_suffix=name_suffix) hash_table2 = hash_table2.restore(basename) embedding_2 = hash_table2.lookup(_get_id_tensor([id_1, id_2])) embedding_2 = sess.run(embedding_2) hash_table2 = hash_table2.save(basename_new) sess.run(hash_table2.as_op()) self.assertAllEqual(embedding_2, [[0], [2]]) with self.session() as sess: hash_table3 = test_hash_table(1) hash_table3 = hash_table3.restore(basename_new) embedding_3 = hash_table3.lookup(_get_id_tensor([id_1, id_2])) self.assertAllEqual(embedding_3, [[0], [2]]) def test_restore_not_found(self): with self.session() as sess: 
non_existent_files = os.path.join(os.environ["TEST_TMPDIR"], "test_restore_not_found", "table") hash_table2 = test_hash_table(1) hash_table2 = hash_table2.restore(non_existent_files) with self.assertRaises(Exception): sess.run(hash_table2.as_op()) def test_save_restore_hook(self): basename = os.path.join(os.environ["TEST_TMPDIR"], "test_save_restore_hook", "model.ckpt") hash_filter = hash_filter_ops.create_dummy_hash_filter() hash_table = test_hash_table(1) add_op = hash_table.assign_add(_get_id_tensor([0]), tf.constant([[1]], dtype=tf.float32)).as_op() sub_op = hash_table.assign_add(_get_id_tensor([0]), tf.constant([[-1]], dtype=tf.float32)).as_op() embedding = hash_table.lookup(_get_id_tensor([0])) saver_listener = ops.HashTableCheckpointSaverListener(basename) # We need to create some variables to make saver happy. tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[saver_listener]) restorer_listener = ops.HashTableCheckpointRestorerListener(basename) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() sess.run(add_op) saver_hook.after_create_session(sess, None) sess.run(sub_op) # restore will override sub_op restore_hook.after_create_session(sess, None) embedding = sess.run(embedding) self.assertAllEqual(embedding, [[1]]) def test_restore_after_save(self): ckpt_prefix = os.path.join(os.environ["TEST_TMPDIR"], "test_restore_after_save", "model.ckpt") hash_table = test_hash_table(1) assign_1_op = hash_table.assign(_get_id_tensor([0]), tf.constant([[1]], dtype=tf.float32)).as_op() 
assign_2_op = hash_table.assign(_get_id_tensor([0]), tf.constant([[2]], dtype=tf.float32)).as_op() emb = hash_table.lookup(_get_id_tensor([0])) class AssignSaverListener(tf.estimator.CheckpointSaverListener): def after_save(self, session, global_step_value): session.run(assign_2_op) # We need to create some variables to make saver happy. tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver() saver_hook = tf.estimator.CheckpointSaverHook( os.path.dirname(ckpt_prefix), save_steps=100, saver=saver, listeners=[ ops.HashTableCheckpointSaverListener(ckpt_prefix), AssignSaverListener(), ops.HashTableRestorerSaverListener(ckpt_prefix) ]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) sess.run(assign_1_op) saver_hook.after_create_session(sess, None) self.assertAllEqual([[1]], sess.run(emb)) def test_save_restore_hook_with_feature_eviction_assign_add(self): basename = os.path.join( os.environ["TEST_TMPDIR"], "test_save_restore_hook_with_feature_eviction_assign_add", "model.ckpt") hash_filter = hash_filter_ops.create_dummy_hash_filter() # Default feature eviction time is expire_time. # Feature with ts older than expire_time will be evicted. 
expire_time = 1 hash_table = test_hash_table(dim_size=1, expire_time=expire_time) max_ts = 10000000 expire_time_in_sec = expire_time * 24 * 3600 evict_ts = max_ts - expire_time_in_sec - 1 assign_op_1 = hash_table.assign_add(_get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32), tf.constant(evict_ts, dtype=tf.int64)).as_op() keep_ts = max_ts - expire_time_in_sec + 1 assign_op_2 = hash_table.assign_add(_get_id_tensor([2]), tf.constant([[2]], dtype=tf.float32), tf.constant(keep_ts, dtype=tf.int64)).as_op() assign_op_3 = hash_table.assign_add(_get_id_tensor([3]), tf.constant([[3]], dtype=tf.float32), tf.constant(max_ts, dtype=tf.int64)).as_op() embedding = hash_table.lookup(_get_id_tensor([1, 2, 3])) saver_listener = ops.HashTableCheckpointSaverListener(basename) # We need to create some variables to make saver happy. tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[saver_listener]) restorer_listener = ops.HashTableCheckpointRestorerListener(basename) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() sess.run(assign_op_1) sess.run(assign_op_2) sess.run(assign_op_3) embedding_values = sess.run(embedding) self.assertAllEqual(embedding_values, [[1], [2], [3]]) saver_hook.after_create_session(sess, None) restore_hook.after_create_session(sess, None) embedding_values = sess.run(embedding) self.assertAllEqual(embedding, [[0], [2], [3]]) def test_save_restore_hook_with_feature_eviction_apply_gradients(self): basename = os.path.join( os.environ["TEST_TMPDIR"], 
"test_save_restore_hook_with_feature_eviction_apply_gradients", "model.ckpt") hash_filter = hash_filter_ops.create_dummy_hash_filter() # Default feature eviction time is expire_time. # Feature with ts older than expire_time will be evicted. expire_time = 1 hash_table = test_hash_table(dim_size=1, expire_time=expire_time) max_ts = 10000000 expire_time_in_sec = expire_time * 24 * 3600 evict_ts = max_ts - expire_time_in_sec - 1 global_step = _get_id_tensor(0) assign_op_1 = hash_table.apply_gradients( _get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32), global_step, req_time=tf.constant(evict_ts, dtype=tf.int64)).as_op() ts_to_keep = max_ts - expire_time_in_sec + 1 global_step = _get_id_tensor(0) assign_op_2 = hash_table.apply_gradients( _get_id_tensor([2]), tf.constant([[2]], dtype=tf.float32), global_step, req_time=tf.constant(ts_to_keep, dtype=tf.int64)).as_op() global_step = _get_id_tensor(0) assign_op_3 = hash_table.apply_gradients( _get_id_tensor([3]), tf.constant([[3]], dtype=tf.float32), global_step, req_time=tf.constant(max_ts, dtype=tf.int64)).as_op() embedding = hash_table.lookup(_get_id_tensor([1, 2, 3])) saver_listener = ops.HashTableCheckpointSaverListener(basename) # We need to create some variables to make saver happy. 
tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[saver_listener]) restorer_listener = ops.HashTableCheckpointRestorerListener(basename) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() sess.run(assign_op_1) sess.run(assign_op_2) sess.run(assign_op_3) embedding_values = sess.run(embedding) self.assertAllEqual(embedding_values, [[-1], [-2], [-3]]) saver_hook.after_create_session(sess, None) restore_hook.after_create_session(sess, None) embedding_values = sess.run(embedding) self.assertAllEqual(embedding, [[0], [-2], [-3]]) def test_save_restore_hook_with_no_req_time_feature_eviction_apply_gradients( self): basename = os.path.join( os.environ["TEST_TMPDIR"], "test_save_restore_hook_with_no_req_time_feature_eviction_apply_gradients", "model.ckpt") hash_filter = hash_filter_ops.create_dummy_hash_filter() hash_table = test_hash_table(dim_size=1, expire_time=1) global_step = _get_id_tensor(0) assign_op_1 = hash_table.apply_gradients( _get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32), global_step).as_op() assign_op_2 = hash_table.apply_gradients( _get_id_tensor([2]), tf.constant([[2]], dtype=tf.float32), global_step).as_op() embedding = hash_table.lookup(_get_id_tensor([1, 2])) saver_listener = ops.HashTableCheckpointSaverListener(basename) # We need to create some variables to make saver happy. 
tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[saver_listener]) restorer_listener = ops.HashTableCheckpointRestorerListener(basename) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() sess.run(assign_op_1) sess.run(assign_op_2) embedding_values = sess.run(embedding) self.assertAllEqual(embedding_values, [[-1], [-2]]) saver_hook.after_create_session(sess, None) restore_hook.after_create_session(sess, None) embedding_values = sess.run(embedding) self.assertAllEqual(embedding, [[-1], [-2]]) def test_save_restore_hook_with_zero_req_time_feature_eviction_apply_gradients( self): basename = os.path.join( os.environ["TEST_TMPDIR"], "test_save_restore_hook_with_zero_req_time_feature_eviction_apply_gradients", "model.ckpt") hash_filter = hash_filter_ops.create_dummy_hash_filter() hash_table = test_hash_table(dim_size=1, expire_time=1) global_step = _get_id_tensor(0) assign_op_1 = hash_table.apply_gradients(_get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32), global_step, req_time=tf.constant( 0, dtype=tf.int64)).as_op() global_step = _get_id_tensor(0) assign_op_2 = hash_table.apply_gradients(_get_id_tensor([2]), tf.constant([[2]], dtype=tf.float32), global_step, req_time=tf.constant( 0, dtype=tf.int64)).as_op() embedding = hash_table.lookup(_get_id_tensor([1, 2])) saver_listener = ops.HashTableCheckpointSaverListener(basename) # We need to create some variables to make saver happy. 
tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[saver_listener]) restorer_listener = ops.HashTableCheckpointRestorerListener(basename) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() sess.run(assign_op_1) sess.run(assign_op_2) embedding_values = sess.run(embedding) self.assertAllEqual(embedding_values, [[-1], [-2]]) saver_hook.after_create_session(sess, None) restore_hook.after_create_session(sess, None) embedding_values = sess.run(embedding) self.assertAllEqual(embedding, [[-1], [-2]]) def test_save_restore_hook_with_same_req_time_feature_eviction_apply_gradients( self): basename = os.path.join( os.environ["TEST_TMPDIR"], "test_save_restore_hook_with_same_req_time_feature_eviction_apply_gradients", "model.ckpt") hash_filter = hash_filter_ops.create_dummy_hash_filter() hash_table = test_hash_table(dim_size=1, expire_time=1) global_step = _get_id_tensor(0) assign_op_1 = hash_table.apply_gradients(_get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32), global_step, req_time=tf.constant( 100, dtype=tf.int64)).as_op() global_step = _get_id_tensor(0) assign_op_2 = hash_table.apply_gradients(_get_id_tensor([2]), tf.constant([[2]], dtype=tf.float32), global_step, req_time=tf.constant( 100, dtype=tf.int64)).as_op() embedding = hash_table.lookup(_get_id_tensor([1, 2])) saver_listener = ops.HashTableCheckpointSaverListener(basename) # We need to create some variables to make saver happy. 
tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[saver_listener]) restorer_listener = ops.HashTableCheckpointRestorerListener(basename) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() sess.run(assign_op_1) sess.run(assign_op_2) embedding_values = sess.run(embedding) self.assertAllEqual(embedding_values, [[-1], [-2]]) saver_hook.after_create_session(sess, None) restore_hook.after_create_session(sess, None) embedding_values = sess.run(embedding) self.assertAllEqual(embedding, [[-1], [-2]]) def test_delete_save_path(self): basename = os.path.join(os.environ["TEST_TMPDIR"], "test_delete_save_path", "model.ckpt") helper = save_utils.SaveHelper(basename) class HashTableCheckpointRestore(ops.HashTableCheckpointRestorerListener): def restore_checkpoint(self, sess, global_step_value): path_prefix = helper.get_ckpt_asset_dir( helper.get_ckpt_prefix(global_step_value)) self._restore_from_path_prefix(sess, path_prefix) class HashFilterCheckpointRestore( hash_filter_ops.HashFilterCheckpointRestorerListener): def restore_checkpoint(self, sess, global_step_value): path_prefix = helper.get_ckpt_asset_dir( helper.get_ckpt_prefix(global_step_value)) self._restore_from_path_prefix(sess, path_prefix) config = embedding_hash_table_pb2.SlotOccurrenceThresholdConfig() config.default_occurrence_threshold = 0 enable_hash_filter = True hash_filters = hash_filter_ops.create_hash_filters( 0, enable_hash_filter, config.SerializeToString()) hash_table = test_hash_table_with_hash_filters(dim_size=1, 
hash_filters=hash_filters) add_op = hash_table.assign_add(_get_id_tensor([0]), tf.constant([[1]], dtype=tf.float32)).as_op() sub_op = hash_table.assign_add(_get_id_tensor([0]), tf.constant([[-1]], dtype=tf.float32)).as_op() lookup_op = hash_table.lookup(_get_id_tensor([0])) global_step = tf.compat.v1.train.get_or_create_global_step() train_op = tf.compat.v1.assign_add(global_step, 1) hash_table_saver_listener = ops.HashTableCheckpointSaverListener(basename) hash_filter_saver_listener = hash_filter_ops.HashFilterCheckpointSaverListener( basename, hash_filters, True) saver = save_utils.PartialRecoverySaver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=1, keep_checkpoint_every_n_hours=2) saver_hook = save_utils.NoFirstSaveCheckpointSaverHook( os.path.dirname(basename), save_steps=1, saver=saver, listeners=[hash_table_saver_listener, hash_filter_saver_listener]) hash_table_restorer_listener = HashTableCheckpointRestore(basename) hash_filter_restorer_listener = HashFilterCheckpointRestore( basename, hash_filters, True) with tf.compat.v1.train.SingularMonitoredSession( hooks=[saver_hook], checkpoint_dir=os.path.dirname(basename)) as mon_sess: sess = mon_sess.raw_session() sess.run(add_op) # let saving happen in step 1 and step 10. mon_sess.run(train_op) for _ in range(8): sess.run(train_op) mon_sess.run(train_op) # hash table checkpoint 1 is deleted. with self.assertRaises(Exception): hash_table_restorer_listener.restore_checkpoint(sess, 1) # hash filter checkpoint 1 is deleted. with self.assertRaises(Exception): hash_filter_restorer_listener.restore_checkpoint(sess, 1) sess.run(sub_op) # checkpoint 10 is OK. 
hash_table_restorer_listener.restore_checkpoint(sess, 10) hash_filter_restorer_listener.restore_checkpoint(sess, 10) embedding = sess.run(lookup_op) self.assertAllEqual(embedding, [[1]]) def test_save_restore_with_hash_table_clear_logic(self): basename = os.path.join(os.environ["TEST_TMPDIR"], "test_save_restore_with_hash_table_clear_logic", "model.ckpt") hash_filter = hash_filter_ops.create_dummy_hash_filter() hash_table = test_hash_table(1) add_op_0 = hash_table.assign_add(_get_id_tensor([0]), tf.constant([[1]], dtype=tf.float32)).as_op() add_op_1 = hash_table.assign_add(_get_id_tensor([1]), tf.constant([[1]], dtype=tf.float32)).as_op() embedding_0 = hash_table.lookup(_get_id_tensor([0])) embedding_1 = hash_table.lookup(_get_id_tensor([1])) saver_listener = ops.HashTableCheckpointSaverListener(basename) # We need to create some variables to make saver happy. tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[saver_listener]) restorer_listener = ops.HashTableCheckpointRestorerListener(basename) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() sess.run(add_op_0) saver_hook.after_create_session(sess, None) sess.run(add_op_1) embedding_value = sess.run(embedding_1) self.assertAllEqual(embedding_value, [[1]]) restore_hook.after_create_session(sess, None) # update before save will be restored from checkpoint. embedding_value = sess.run(embedding_0) self.assertAllEqual(embedding_value, [[1]]) # update after save will not be restored from checkpoint. 
embedding_value = sess.run(embedding_1) self.assertAllEqual(embedding_value, [[0]]) def test_hash_table_and_hash_filter_save_restore_hook_together(self): basename = os.path.join( os.environ["TEST_TMPDIR"], "test_hash_table_and_hash_filter_save_restore_hook_together", "model.ckpt") config = embedding_hash_table_pb2.SlotOccurrenceThresholdConfig() config.default_occurrence_threshold = 2 enable_hash_filter = True hash_filters = hash_filter_ops.create_hash_filters( 0, enable_hash_filter, config.SerializeToString()) hash_table = test_hash_table_with_hash_filters(dim_size=1, hash_filters=hash_filters) add_op = hash_table.assign_add(_get_id_tensor([0]), tf.constant([[1]], dtype=tf.float32)).as_op() embedding = hash_table.lookup(_get_id_tensor([0])) hash_table_saver_listener = ops.HashTableCheckpointSaverListener(basename) hash_filter_saver_listener = hash_filter_ops.HashFilterCheckpointSaverListener( basename, hash_filters, enable_hash_filter) # We need to create some variables to make saver happy. tf.compat.v1.train.create_global_step() saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2) saver_hook = tf.estimator.CheckpointSaverHook( os.path.dirname(basename), save_steps=1000, saver=saver, listeners=[hash_table_saver_listener, hash_filter_saver_listener]) hash_table_restorer_listener = ops.HashTableCheckpointRestorerListener( basename) hash_filter_restorer_listener = hash_filter_ops.HashFilterCheckpointRestorerListener( basename, hash_filters, True) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[hash_table_restorer_listener, hash_filter_restorer_listener]) with self.session() as sess: saver_hook.begin() sess.run(tf.compat.v1.global_variables_initializer()) # In the estimator API, graph will be finalized before calling hook g = tf.compat.v1.get_default_graph() g.finalize() # add_op not actually works as count after adding in hash filter is 1. 
sess.run(add_op) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[0]]) # save hash filter ckpt with count is 1. saver_hook.after_create_session(sess, None) embedding_value = sess.run(embedding) # add_op not actually works as count after adding in hash filter is 2. sess.run(add_op) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[0]]) # add_op works as count after adding in hash filter is 3. sess.run(add_op) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[1]]) # restore hash table ckpt (embedding value is 0) # and hash filter ckpt (count is 1) restore_hook.after_create_session(sess, None) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[0]]) #add_op not works as count in hash filter is 2 after it restored from ckpt. sess.run(add_op) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[0]]) # add_op works as count after adding in hash filter is 3. sess.run(add_op) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[1]]) # restore again to test everything is good. # restore hash table ckpt (embedding value is 0) # and hash filter ckpt (count is 1) restore_hook.after_create_session(sess, None) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[0]]) #add_op not works as count in hash filter is 2 after it restored from ckpt. sess.run(add_op) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[0]]) # add_op works as count after adding in hash filter is 3. 
sess.run(add_op) embedding_value = sess.run(embedding) self.assertAllEqual(embedding_value, [[1]]) def test_two_hash_table_whose_name_is_prefix(self): with tf.compat.v1.Session() as sess: dim_size = 1 hash_table1 = test_hash_table(dim_size) hash_table2 = test_hash_table(dim_size) basename = os.path.join(os.environ["TEST_TMPDIR"], "test_two_hash_table_whose_name_is_prefix") hash_table1 = hash_table1.save(basename + "/table1") hash_table2 = hash_table2.save(basename + "/table10") sess.run([hash_table1.as_op(), hash_table2.as_op()]) hash_table1 = hash_table1.restore(basename + "/table1") hash_table2 = hash_table2.restore(basename + "/table10") sess.run([hash_table1.as_op(), hash_table2.as_op()]) def test_fused_lookup(self): with tf.compat.v1.Session() as sess: hash_tables = [] dim_sizes = [1, 1, 2] for x in range(len(dim_sizes)): dim_size = dim_sizes[x] hash_table = ops.vocab_hash_table(9, dim_size) hash_table = hash_table.assign( _get_id_tensor([0 + 3 * x, 1 + 3 * x]), tf.ones([2, dim_size]) if x % 2 == 0 else tf.zeros([2, dim_size])) hash_tables.append(hash_table) embeddings = ops.fused_lookup( [hash_table.table for hash_table in hash_tables], _get_id_tensor([0, 4, 6, 1, 3, 7]), fused_slot_size=tf.constant([1, 1, 1, 1, 1, 1]), num_of_shards=2) embeddings, recv_splits, id_offsets, emb_offsets, indices = sess.run( embeddings) self.assertAllEqual(embeddings, [1, 0, 1, 1, 1, 0, 1, 1]) self.assertAllEqual(recv_splits, [4, 4]) self.assertAllEqual(id_offsets, [0, 1, 2, 3, 4, 5, 6]) self.assertAllEqual(emb_offsets, [0, 1, 2, 4, 5, 6, 8]) def test_fused_optimize(self): with tf.compat.v1.Session() as sess: hash_tables = [] dim_sizes = [1, 2] fused_slot_size = tf.constant([1, 1, 1, 1]) ids = _get_id_tensor([0, 4, 1, 3]) for x in range(len(dim_sizes)): dim_size = dim_sizes[x] hash_table = ops.vocab_hash_table(6, dim_size) hash_table = hash_table.assign( _get_id_tensor([0 + 3 * x, 1 + 3 * x]), tf.ones([2, dim_size]) if x == 0 else tf.zeros([2, dim_size])) 
hash_tables.append(hash_table) hash_table_resource = [hash_table.table for hash_table in hash_tables] #embeddings=[1, 0, 0, 1, 0, 0] embeddings, recv_splits, id_offsets, emb_offsets, indices = ops.fused_lookup( hash_table_resource, ids, fused_slot_size, num_of_shards=2) new_tables = ops.fused_apply_gradient(hash_table_resource, ids, indices, fused_slot_size, tf.constant( [-1, -2, -2, -1, -2, -2], dtype=tf.float32), id_offsets, emb_offsets, tf.constant([0.1, 0.1], dtype=tf.float32), tf.constant(0, dtype=tf.int64), tf.constant(0, dtype=tf.int64), num_of_shards=2) with tf.control_dependencies(new_tables): lookup_op = ops.fused_lookup(hash_table_resource, ids, fused_slot_size, num_of_shards=2) embeddings, recv_splits, id_offsets, emb_offsets, indices = sess.run( lookup_op) self.assertAllClose(embeddings, [1.1, 0.2, 0.2, 1.1, 0.2, 0.2]) self.assertAllEqual(recv_splits, [3, 3]) self.assertAllEqual(id_offsets, [0, 1, 2, 3, 4]) self.assertAllEqual(emb_offsets, [0, 1, 3, 4, 6]) def test_batch_softmax_optimizer(self): table_config = embedding_hash_table_pb2.EmbeddingHashTableConfig() table_config.cuckoo.SetInParent() segment = table_config.entry_config.segments.add() segment.dim_size = 1 segment.opt_config.batch_softmax.SetInParent() segment.init_config.zeros.SetInParent() segment.comp_config.fp32.SetInParent() learning_rate = 0.1 config = entry.HashTableConfigInstance(table_config, [learning_rate]) with self.session() as sess: hash_table = ops.hash_table_from_config(config=config, name_suffix='batch_softmax') for global_step in range(1000): fids = list() if global_step % 5 == 0: fids.append(0) if global_step % 10 == 0: fids.append(1) if len(fids) == 0: continue id_tensor = _get_id_tensor(fids) global_step = _get_id_tensor(global_step) hash_table = hash_table.apply_gradients(id_tensor, tf.constant([0.1 for _ in fids], dtype=tf.float32), global_step=global_step) item_step_interval = hash_table.lookup(_get_id_tensor([0, 1])) item_step_interval = 
tf.math.maximum(item_step_interval, tf.constant([1.0], dtype=tf.float32)) item_step_interval = sess.run(item_step_interval) self.assertAllClose([1 / val for val in item_step_interval], [[0.2], [0.1]], atol=0.01) def test_extract_fid(self): entry = embedding_hash_table_pb2.EntryDump() entry.id = 1 << 48 slot_tensor = ops.extract_slot_from_entry([entry.SerializeToString()]) self.assertAllEqual(self.evaluate(slot_tensor), [1]) def test_meta_graph_export(self): table = test_hash_table(2) meta = tf.compat.v1.train.export_meta_graph() self.assertIn(ops._HASH_TABLE_GRAPH_KEY, meta.collection_def) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/hash_table_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable import tensorflow as tf from monolith.native_training.runtime.hash_table import embedding_hash_table_pb2 @tf.function def iterate_table_and_apply(table: "HashTable", apply_fn: Callable[[tf.Tensor], None], limit=1000, nshards=4, name="IterateTable"): """Iterate the hash table, and call apply_fn for each slice. Args: apply_fn - a fn that accepts a 1-D tf string which is serialized EntryDump. limit - the maximum number of strings that will be fed into apply_fn (to save the memory usage). nshards - the parallelism of calling apply_fn. 
""" for i in tf.range(nshards): offset = tf.constant(0, dtype=tf.int64) dump = tf.constant([], dtype=tf.string) while tf.math.equal(tf.size(dump), limit) or tf.math.equal(offset, 0): tf.autograph.experimental.set_loop_options( parallel_iterations=1, shape_invariants=[(dump, tf.TensorShape([None])), (offset, tf.TensorShape([]))]) offset, dump = table.save_as_tensor(i, nshards, limit, offset) apply_fn(dump) def infer_dim_size( config: embedding_hash_table_pb2.EmbeddingHashTableConfig) -> int: dim_size = 0 for segment in config.entry_config.segments: dim_size += segment.dim_size return dim_size ================================================ FILE: monolith/native_training/hash_table_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import tensorflow as tf from monolith.native_training import hash_table_utils from monolith.native_training import hash_table_ops class HashTableUtilsTest(tf.test.TestCase): def test_iterate_table_and_apply(self): with self.session() as sess: table = hash_table_ops.test_hash_table(1) sess.run( table.assign(tf.range(100, dtype=tf.int64), [[0.0]] * 100).as_op()) count_var = tf.Variable(0) sess.run(count_var.initializer) def count_fn(dump: tf.Tensor): return count_var.assign_add(tf.size(dump), use_locking=True) sess.run( hash_table_utils.iterate_table_and_apply(table, count_fn, limit=2, nshards=10)) count = sess.run(count_var) self.assertEqual(count, 100) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/hooks/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_library", "py_test") load("@rules_proto//proto:defs.bzl", "proto_library") load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library") package( default_visibility = ["//visibility:public"], ) py_library( name = "session_hooks", srcs = ["session_hooks.py"], ) py_test( name = "session_hooks_test", srcs = ["session_hooks_test.py"], deps = [ ":session_hooks", ], ) proto_library( name = "ckpt_hooks_proto", srcs = ["ckpt_hooks.proto"], ) py_proto_library( name = "ckpt_hooks_py_proto", deps = [ ":ckpt_hooks_proto", ], ) py_library( name = "ckpt_hooks", srcs = ["ckpt_hooks.py"], deps = [ ":ckpt_hooks_py_proto", "//monolith/native_training:barrier_ops", "//monolith/native_training:graph_meta", ], ) py_test( name = "ckpt_hooks_test", srcs = ["ckpt_hooks_test.py"], deps = [ ":ckpt_hooks", "//monolith/native_training:save_utils", ], ) proto_library( name = "controller_hooks_proto", srcs = ["controller_hooks.proto"], ) py_proto_library( name = "controller_hooks_py_proto", deps = [ ":controller_hooks_proto", ], ) py_library( name = "controller_hooks", srcs 
= ["controller_hooks.py"], deps = [ ":controller_hooks_py_proto", "//monolith/native_training:barrier_ops", "//monolith/native_training:utils", ], ) py_test( name = "controller_hooks_test", srcs = ["controller_hooks_test.py"], deps = [ ":controller_hooks", ], ) py_library( name = "ckpt_info", srcs = ["ckpt_info.py"], deps = [ "//monolith/native_training:hash_table_ops", "//monolith/native_training:hash_table_utils", "//monolith/native_training:multi_hash_table_ops", "//monolith/native_training/proto:ckpt_info_py_proto", ], ) py_test( name = "ckpt_info_test", srcs = ["ckpt_info_test.py"], deps = [ ":ckpt_info", ], ) py_library( name = "hook_utils", srcs = ["hook_utils.py"], ) py_test( name = "hook_utils_test", srcs = ["hook_utils_test.py"], deps = [ ":hook_utils", ], ) py_library( name = "ps_check_hooks", srcs = ["ps_check_hooks.py"], deps = [ "//monolith/native_training:barrier_ops", "//monolith/native_training:logging_ops", "//monolith/native_training:utils", "//monolith/native_training/runtime/ops:logging_ops_py_proto", ], ) py_test( name = "ps_check_hooks_test", srcs = ["ps_check_hooks_test.py"], deps = [ ":ps_check_hooks", ], ) py_library( name = "feature_engineering_hooks", srcs = ["feature_engineering_hooks.py"], deps = [ "//idl:example_py_proto", ], ) ================================================ FILE: monolith/native_training/hooks/ckpt_hooks.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. syntax="proto2"; package monolith.hooks; message WorkerCkptInfo { optional int64 global_step = 1; } ================================================ FILE: monolith/native_training/hooks/ckpt_hooks.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import os import time from absl import logging import tensorflow as tf from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.experimental.ops import distribute_options # Should be removed after tf2.5 from monolith.native_training.hooks import ckpt_hooks_pb2 from monolith.native_training import basic_restore_hook from monolith.native_training import barrier_ops from monolith.native_training import graph_meta @dataclasses.dataclass class Meta: info_var: tf.Variable info_var_placeholder: tf.compat.v1.placeholder info_var_assign_op: tf.Operation enable_iter_save_restore: bool = True SAVE_ACTION = "Save" def _get_meta() -> Meta: def factory(): info_var = tf.compat.v1.get_local_variable("WorkerCkptMetaInfo", dtype=tf.string, initializer="") info_var_placeholder = tf.compat.v1.placeholder(tf.string, []) info_var_assign_op = info_var.assign(info_var_placeholder) return Meta(info_var=info_var, info_var_placeholder=info_var_placeholder, info_var_assign_op=info_var_assign_op) return 
graph_meta.get_meta("worker_ckpt_meta", factory) def assign_ckpt_info(session: tf.compat.v1.Session, info: ckpt_hooks_pb2.WorkerCkptInfo): meta = _get_meta() session.run(meta.info_var_assign_op, feed_dict={meta.info_var_placeholder: info.SerializeToString()}) def get_ckpt_info( session: tf.compat.v1.Session) -> ckpt_hooks_pb2.WorkerCkptInfo: ckpt_info = ckpt_hooks_pb2.WorkerCkptInfo() ckpt_info.ParseFromString(session.run(_get_meta().info_var)) return ckpt_info class BarrierSaverListener(tf.estimator.CheckpointSaverListener): """During saving, set up barrier condition to block worker for chief.""" def __init__(self, barrier_op: barrier_ops.BarrierOp, wait_seconds=1, max_pending_seconds=30): self._barrier_op = barrier_op self._wait_seconds = wait_seconds self._max_pending_seconds = max_pending_seconds # Make sure meta is created. self._meta = _get_meta() self._release_barrier = False def before_save(self, session, global_step_value): assign_ckpt_info( session, ckpt_hooks_pb2.WorkerCkptInfo(global_step=global_step_value)) logging.info("Place barrier for saving.") start_time = time.time() try: self._barrier_op.place_barrier(session, action=SAVE_ACTION) self._release_barrier = True except barrier_ops.BarrierAlreadyPlacedError: logging.info("Barrier is placed by someone else already.") while not self._barrier_op.is_all_blocked(session): time.sleep(self._wait_seconds) if time.time() - start_time > self._max_pending_seconds: break unblocked_indices = self._barrier_op.get_unblocked_indices(session) if unblocked_indices: logging.info("Unblocked worker indices: {}.".format( str(unblocked_indices))) else: logging.info("All workers have been blocked.") def after_save(self, session, global_step_value): start_time = time.time() if self._release_barrier: logging.info("Remove barrier for saving.") self._barrier_op.remove_barrier(session) self._release_barrier = False class _WorkerCkptRestorerHook(tf.estimator.SessionRunHook): def __init__(self, saver: tf.compat.v1.train.Saver, 
model_dir: str, latest_filename: str): self._saver = saver self._model_dir = model_dir self._latest_filename = latest_filename def after_create_session(self, session, coord): latest_ckpt = tf.train.latest_checkpoint(self._model_dir, self._latest_filename) if latest_ckpt is not None and self._saver: self._saver.restore(session, latest_ckpt) else: logging.info("Skipped worker ckpt restore.") class WorkerCkptHelper: def __init__(self, model_dir: str, index: int): # Here we try to keep them as similar as tf.data.experimental.CheckpointInputPipelineHook self._model_dir = model_dir self._index = index checkpoint_prefix = "input_worker_{}".format(index) self._checkpoint_basename = checkpoint_prefix + ".ckpt" self._latest_filename = "checkpoint_" + checkpoint_prefix iterators = tf.compat.v1.get_collection(iterator_ops.GLOBAL_ITERATORS) saveables = [] if _get_meta().enable_iter_save_restore: saveables.extend([ iterator_ops._IteratorSaveable( i, i.name, external_state_policy=distribute_options.ExternalStatePolicy. IGNORE) for i in iterators ]) else: logging.info("The iterator save is disabled.") if saveables: self._saver = tf.compat.v1.train.Saver(var_list=saveables, sharded=True) else: # Saver will throw error if we try to saveables is an empty list. self._saver = None def create_save_iterator_callback(self): def callback(action: str, sess: tf.compat.v1.Session): if not action == SAVE_ACTION: return ckpt_info = get_ckpt_info(sess) try: if self._saver: self._saver.save(sess, os.path.join(self._model_dir, self._checkpoint_basename), global_step=ckpt_info.global_step, latest_filename=self._latest_filename, write_meta_graph=False) except tf.errors.UnimplementedError as e: logging.warning( "Current dataset iterators don't support save. This might be expected. 
%s", str(e)) return callback def create_restorer_hook(self): return _WorkerCkptRestorerHook(self._saver, self._model_dir, self._latest_filename) def disable_iterator_save_restore(): """In some situations (like in ByteDance we feed data via stdin), the input progress is not trackable by tensorflow. In this case, we should disable iterator restore since its state is inaccurate. NOTICE: this function should be called before any creation of classes in this module. """ _get_meta().enable_iter_save_restore = False ================================================ FILE: monolith/native_training/hooks/ckpt_hooks_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os
import time
import threading

from absl import logging
import tensorflow as tf

from monolith.native_training.hooks import ckpt_hooks
from monolith.native_training.hooks import ckpt_hooks_pb2
from monolith.native_training import barrier_ops
from monolith.native_training import save_utils


class CountCheckpointSaverListener(tf.estimator.CheckpointSaverListener):
  """Saver listener that only counts how often each callback was invoked."""

  def __init__(self):
    self.begin_count = 0
    self.before_save_count = 0
    self.after_save_count = 0

  def begin(self):
    self.begin_count += 1

  def before_save(self, session, global_step):
    self.before_save_count += 1

  def after_save(self, session, global_step):
    self.after_save_count += 1

  def get_counts(self):
    # Snapshot of the three counters, keyed by callback name.
    return {
        'begin': self.begin_count,
        'before_save': self.before_save_count,
        'after_save': self.after_save_count
    }


class FixedSessionCreator(tf.compat.v1.train.SessionCreator):
  """Session creator that always hands back the session it was given."""

  def __init__(self, fixed_sess):
    self._sess = fixed_sess

  def create_session(self):
    return self._sess


class WorkerCkptHooksTest(tf.test.TestCase):

  def testIteratorSaveRestore(self):
    # Consume one element, save the iterator, consume another one, then check
    # that a new session with the restorer hook resumes from the saved position.
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "iterator_save")
    tf.io.gfile.makedirs(model_dir)
    ds = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3])
    it = tf.compat.v1.data.make_one_shot_iterator(ds)
    next_ele = it.get_next()
    helper = ckpt_hooks.WorkerCkptHelper(model_dir, 0)
    with self.session() as sess:
      sess.run(next_ele)
      ckpt_hooks.assign_ckpt_info(sess,
                                  ckpt_hooks_pb2.WorkerCkptInfo(global_step=10))
      save_callback = helper.create_save_iterator_callback()
      save_callback(ckpt_hooks.SAVE_ACTION, sess)
      self.assertAllEqual(sess.run(next_ele), 1)

    with tf.compat.v1.train.MonitoredSession(
        hooks=[helper.create_restorer_hook()]) as sess:
      # Restore happens
      self.assertAllEqual(sess.run(next_ele), 1)

  def testNoCkpt(self):
    # The restorer hook must not fail when no checkpoint was ever written.
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "no_ckpt")
    helper = ckpt_hooks.WorkerCkptHelper(model_dir, 0)
    with tf.compat.v1.train.MonitoredSession(
        hooks=[helper.create_restorer_hook()]) as sess:
      pass

  def testNoSaveables(self):
    # No iterator is created in this test, so the helper has nothing to save;
    # the save callback must still run without raising.
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "no_saveables")
    tf.io.gfile.makedirs(model_dir)
    helper = ckpt_hooks.WorkerCkptHelper(model_dir, 0)
    with self.session() as sess:
      ckpt_hooks.assign_ckpt_info(sess,
                                  ckpt_hooks_pb2.WorkerCkptInfo(global_step=10))
      save_callback = helper.create_save_iterator_callback()
      save_callback(ckpt_hooks.SAVE_ACTION, sess)

  def testCkptDisabled(self):
    # Same flow as testIteratorSaveRestore, but with iterator save/restore
    # disabled before the helper is built: the new session starts from 0.
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "ckpt_disabled")
    tf.io.gfile.makedirs(model_dir)
    ds = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3])
    it = tf.compat.v1.data.make_one_shot_iterator(ds)
    next_ele = it.get_next()
    ckpt_hooks.disable_iterator_save_restore()
    helper = ckpt_hooks.WorkerCkptHelper(model_dir, 0)
    with self.session() as sess:
      sess.run(next_ele)
      ckpt_hooks.assign_ckpt_info(sess,
                                  ckpt_hooks_pb2.WorkerCkptInfo(global_step=10))
      save_callback = helper.create_save_iterator_callback()
      save_callback(ckpt_hooks.SAVE_ACTION, sess)
      self.assertAllEqual(sess.run(next_ele), 1)

    with tf.compat.v1.train.MonitoredSession(
        hooks=[helper.create_restorer_hook()]) as sess:
      # Restore should not happen
      self.assertAllEqual(sess.run(next_ele), 0)

  def test_saver_with_barrier(self):
    # A worker thread runs one training step with a saver hook whose first
    # listener places a barrier. The main thread observes the barrier, checks
    # that the second listener's before_save has not run yet, then waits on the
    # barrier and checks that the save completed.
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "saver_with_barrier")
    global_step = tf.compat.v1.train.get_or_create_global_step()
    train_op = tf.compat.v1.assign_add(global_step, 1)
    barrier_op = barrier_ops.BarrierOp(2, False)
    # Listener order matters: the barrier listener runs before the counter.
    listener1 = ckpt_hooks.BarrierSaverListener(barrier_op)
    listener2 = CountCheckpointSaverListener()
    hook = save_utils.NoFirstSaveCheckpointSaverHook(
        model_dir,
        save_steps=1,
        listeners=[listener1, listener2],
        saver=tf.compat.v1.train.Saver())

    class WaitAllWorkersHook(tf.estimator.SessionRunHook):
      # Keeps the monitored session open until no worker is blocked any more.

      def end(self, session):
        nonlocal barrier_op
        while not barrier_op.is_none_blocked(session):
          time.sleep(0.1)

    with tf.compat.v1.Session() as sess:
      g = tf.compat.v1.get_default_graph()
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())

      def run():
        # Runs in the worker thread; reuses `sess` through FixedSessionCreator
        # and re-enters graph `g` explicitly inside the thread.
        with g.as_default(), tf.compat.v1.train.MonitoredSession(
            session_creator=FixedSessionCreator(sess),
            hooks=[hook, WaitAllWorkersHook()]) as mon_sess:
          mon_sess.run(train_op)

      worker = threading.Thread(target=run)
      worker.daemon = True
      worker.start()
      # Spin until the worker thread has reached the barrier placement.
      while not barrier_op.is_barrier_placed(sess):
        time.sleep(0.1)
      # Barrier is placed by save listener.
      self.assertEqual(1, sess.run(global_step))
      self.assertEqual({
          'begin': 1,
          'before_save': 0,
          'after_save': 0,
      }, listener2.get_counts())
      print("Start to wait")
      # NOTE(review): presumably this acts as the second barrier participant
      # (index 1) and returns once the listener removes the barrier - confirm
      # against barrier_ops.
      barrier_op.wait_until_barrier_removed(sess, 1)
      worker.join()
      self.assertEqual({
          'begin': 1,
          'before_save': 1,
          'after_save': 1,
      }, listener2.get_counts())


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

================================================
FILE: monolith/native_training/hooks/ckpt_info.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os from typing import DefaultDict import numpy as np import tensorflow as tf from monolith.native_training import hash_table_ops from monolith.native_training import hash_table_utils from monolith.native_training import multi_hash_table_ops from monolith.native_training.proto import ckpt_info_pb2 _MAX_SLOT = 102400 class FidSlotCountSaverListener(tf.estimator.CheckpointSaverListener): def __init__(self, model_dir: str): self._model_dir = model_dir all_tables = tf.compat.v1.get_collection( hash_table_ops._HASH_TABLE_GRAPH_KEY) self.all_multi_hash_tables = tf.compat.v1.get_collection( multi_hash_table_ops._MULTI_HASH_TABLE_GRAPH_KEY) if not all_tables and not self.all_multi_hash_tables: # MultiHashTable info is collected in a different way. # This usually means the listener is created before hash table is created # Throws an error here raise ValueError( ("Unable to find hash tables. " "It may be caused by creating the listener before calling model_fn")) device_to_tables = DefaultDict(list) for table in all_tables: device_to_tables[table.table.device].append(table) self._count_vars = {} count_ops = [] for device, tables in device_to_tables.items(): with tf.device(device): device_unique_str = str(device).replace(":", "_") count_var = tf.compat.v1.get_variable( f"monolith_fid_slot_count/{device_unique_str}", shape=[_MAX_SLOT], dtype=tf.int64, initializer=tf.compat.v1.zeros_initializer(tf.int64), collections=[]) self._count_vars[device] = count_var def apply_fn(entry): slot = hash_table_ops.extract_slot_from_entry(entry) slot = tf.math.minimum(slot, _MAX_SLOT - 1) update = tf.ones_like(slot, dtype=tf.int64) index = tf.reshape(slot, [-1, 1]) scattered = tf.scatter_nd(index, update, [_MAX_SLOT]) count_var.assign_add(scattered, use_locking=True) for table in tables: count_ops.append( hash_table_utils.iterate_table_and_apply(table, apply_fn)) self._count_op = tf.group(count_ops) init_ops = [] for count_var in self._count_vars.values(): 
init_ops.append(count_var.initializer) self._init_op = tf.group(init_ops) def before_save(self, session, global_step_value): if self.all_multi_hash_tables: return session.run(self._init_op) session.run(self._count_op) counts = session.run(list(self._count_vars.values())) counts = np.sum(counts, axis=0) info = ckpt_info_pb2.CkptInfo() for slot, count in enumerate(counts): if count: info.slot_counts[slot] = count tf.io.gfile.makedirs(self._model_dir) with tf.io.gfile.GFile( os.path.join(self._model_dir, f"ckpt.info-{global_step_value}"), "w") as f: f.write(str(info)) ================================================ FILE: monolith/native_training/hooks/ckpt_info_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class FidCountListener(tf.test.TestCase):
  """End-to-end check of ckpt_info.FidSlotCountSaverListener."""

  def test_basic(self):
    # One FID (value 1) is inserted; it is expected to be counted in slot 0.
    table = hash_table_ops.test_hash_table(1)
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "basic")
    table = table.assign(tf.constant([1], dtype=tf.int64), [[0.0]])
    listener = ckpt_info.FidSlotCountSaverListener(model_dir)

    with self.session() as sess:
      sess.run(table.as_op())
      listener.before_save(sess, 0)

    info_path = os.path.join(model_dir, "ckpt.info-0")
    with tf.io.gfile.GFile(info_path) as info_file:
      serialized = info_file.read()

    parsed = ckpt_info_pb2.CkptInfo()
    text_format.Parse(serialized, parsed)
    self.assertEqual(parsed.slot_counts[0], 1)
STOP_ACTION = "Stop"


class ControllerHook(tf.estimator.SessionRunHook):
  """Session hook that reacts to stop / save requests stored in a variable.

  A two-element boolean local variable is created at construction time:
  element 0 is the "stop" flag and element 1 is the "trigger save" flag.
  `stop_op` and `trigger_save_op` set the respective flag; the variable is
  fetched on every step and handled in `after_run`.

  Args:
    num_ps: when positive, the control variable is created under
      `tf.device(utils.ps_device(0))`; otherwise no device scope is used.
    barrier_op: optional barrier used to block workers on a stop request.
    trigger_save: optional zero-argument callable invoked on a save request.
  """

  def __init__(self,
               num_ps=0,
               barrier_op: barrier_ops.BarrierOp = None,
               trigger_save: Callable = None):
    self._barrier_op = barrier_op
    self._trigger_save = trigger_save
    device_ctx = tf.device(
        utils.ps_device(0)) if num_ps > 0 else contextlib.nullcontext()
    with tf.name_scope("monolith_controller_hook"), device_ctx:
      self._control_var = tf.compat.v1.get_local_variable(
          "control_var", initializer=[False, False], trainable=False)
      self._stop_op = self._control_var[0].assign(True)
      self._trigger_save_op = self._control_var[1].assign(True)
      self._reset_trigger_save_op = self._control_var[1].assign(False)

  @property
  def stop_op(self):
    """Op that raises the stop flag."""
    return self._stop_op

  @property
  def trigger_save_op(self):
    """Op that raises the trigger-save flag."""
    return self._trigger_save_op

  def before_run(self, run_context):
    # Fetch both flags with every step.
    return tf.estimator.SessionRunArgs(self._control_var)

  def after_run(self, run_context, run_values):
    """Handles a raised stop or trigger-save flag.

    A stop request takes precedence over a save request. On stop, a barrier
    tagged with STOP_ACTION is placed, the hook waits up to 30 seconds for
    all workers to become blocked, and the barrier is removed again. On a
    save request the flag is reset and `trigger_save` is invoked.
    """
    if run_values.results[0]:
      if self._barrier_op:
        session = run_context.session
        self._barrier_op.place_barrier(session, action=STOP_ACTION)
        try:
          logging.info("Trying to stop all workers.")
          start_time = time.time()
          while (time.time() - start_time < 30 and
                 not self._barrier_op.is_all_blocked(session)):
            time.sleep(2)
        finally:
          # Always lift the barrier we placed, even if polling failed;
          # otherwise workers would stay blocked behind it.
          self._barrier_op.remove_barrier(session)
    elif run_values.results[1]:
      run_context.session.run(self._reset_trigger_save_op)
      if self._trigger_save:
        self._trigger_save()


class _StopHook(tf.estimator.SessionRunHook):
  """Requests a session stop as soon as `should_stop_fn()` returns true."""

  def __init__(self, should_stop_fn):
    self._should_stop_fn = should_stop_fn

  def after_run(self, run_context, run_values):
    if self._should_stop_fn():
      run_context.request_stop()


class StopHelper:
  """Connects a barrier callback to a hook that stops the local session.

  The callback from `create_barrier_callback` records a STOP_ACTION; the
  hook from `create_stop_hook` requests a stop once that was recorded.
  """

  def __init__(self):
    self._should_stop = False

  def create_barrier_callback(self):
    """Returns a `callback(action, sess)` that records STOP_ACTION."""

    def callback(action: str, sess: tf.compat.v1.Session):
      if action != STOP_ACTION:
        return
      self._should_stop = True
      logging.info("Receive the request to stop the training.")

    return callback

  def create_stop_hook(self):
    """Returns a hook that stops the session after a recorded stop."""

    def should_stop():
      return self._should_stop

    return _StopHook(should_stop)


QUERY_INTERVAL = 60


class QueryActionHook(tf.estimator.SessionRunHook):
  """Polls `<model_dir>/monolith_action` for text-proto control requests.

  A daemon thread wakes up every QUERY_INTERVAL seconds. If the query file
  exists it is parsed as a `ControllerHooksProto`, the requested action is
  forwarded to the given `ControllerHook` ops, a response is written to
  `<model_dir>/monolith_action_response`, and the query file is removed.

  Args:
    model_dir: directory holding the query and response files.
    hook: the ControllerHook whose ops are run for each action.
  """

  def __init__(self, model_dir: str, hook: ControllerHook):
    self._query_path = os.path.join(model_dir, "monolith_action")
    self._resp_path = os.path.join(model_dir, "monolith_action_response")
    self._hook = hook
    self._session = None
    self._th = None
    self._close = threading.Event()

  def after_create_session(self, session, coord):
    self._session = session
    self._th = threading.Thread(name="QuertActionHookThread",
                                target=self._query_loop,
                                daemon=True)
    self._th.start()

  def end(self, session):
    # Wake the polling thread up and wait for it to finish.
    self._close.set()
    if self._th:
      self._th.join()

  def _query_loop(self):
    """Polls until `end` sets the close event; errors are logged only."""
    # QUERY_INTERVAL is read on every iteration so that it can be patched.
    while not self._close.wait(timeout=QUERY_INTERVAL):
      try:
        self._query()
      except Exception:  # Keep polling; a bare except would also trap exits.
        logging.error(traceback.format_exc())

  def _query(self):
    """Processes one pending query file, if there is one."""
    if not tf.io.gfile.exists(self._query_path):
      return
    with tf.io.gfile.GFile(self._query_path, "r") as f:
      text_proto = f.read()
    try:
      proto = controller_hooks_pb2.ControllerHooksProto()
      try:
        text_format.Parse(text_proto, proto)
      except text_format.ParseError as e:
        self._write_resp(str(e))
        return
      if proto.action == controller_hooks_pb2.ControllerHooksProto.TRIGGER_SAVE:
        self._session.run(self._hook.trigger_save_op)
      elif proto.action == controller_hooks_pb2.ControllerHooksProto.STOP:
        self._session.run(self._hook.stop_op)
      else:
        # _write_resp takes a single string; passing the two parts as
        # separate arguments raised TypeError and no response was written.
        self._write_resp("Unknown action: " + text_proto)
        return
      self._write_resp("OK")
    finally:
      # The request is consumed whether or not it could be handled.
      tf.io.gfile.remove(self._query_path)

  def _write_resp(self, content: str):
    """Overwrites the response file with `content`."""
    with tf.io.gfile.GFile(self._resp_path, "w") as f:
      f.write(content)
class QueryActionHookTest(tf.test.TestCase):
  """Drives QueryActionHook through the on-disk query file."""

  @mock.patch("monolith.native_training.hooks.controller_hooks.QUERY_INTERVAL",
              0.1)
  def testStop(self):
    model_dir = os.path.join(os.environ["TEST_TMPDIR"],
                             "QueryActionHookTest_testStop")
    save_callback = mock.MagicMock()
    controller_hook = controller_hooks.ControllerHook(
        trigger_save=save_callback)
    query_hook = controller_hooks.QueryActionHook(model_dir, controller_hook)
    dummy = tf.constant(0)

    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[controller_hook, query_hook]) as sess:
      tf.io.gfile.makedirs(model_dir)
      query_path = os.path.join(model_dir, "monolith_action")
      with tf.io.gfile.GFile(query_path, "w") as query_file:
        query_file.write("action: TRIGGER_SAVE")

      # The hook deletes the query file once it has handled the request;
      # give it at most a minute.
      started = time.time()
      while time.time() - started < 60 and tf.io.gfile.exists(query_path):
        time.sleep(0.1)

      # Two steps: the flag is observed and the save callback fires.
      for _ in range(2):
        sess.run(dummy)

    save_callback.assert_called_once()
# See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf import uuid as gen_uid import os from struct import pack from absl import logging from idl.matrix.proto.example_pb2 import ExampleBatch, FeatureListType class FeatureEngineeringSaveHook(tf.estimator.SessionRunHook): def __init__(self, config, nxt_elem, cap=100): self._config = config self._nxt_elem = nxt_elem self._cap = cap def begin(self): self._batch_list = [] # List[Dict[str, tf.Tensor]] self._steps = 0 def before_run(self, run_context): self._steps += 1 # skip iter init if self._steps > 1: return tf.compat.v1.train.SessionRunArgs(self._nxt_elem) def _save_features(self): base_dir = os.path.join(self._config.model_dir, "features") try: tf.io.gfile.makedirs(base_dir) except tf.errors.OpError: pass file_path = "" if self._config.server_type == "worker" and self._config.index == 0: file_path = os.path.join(base_dir, "chief_" + str(gen_uid.uuid1()) + ".pb") else: file_path = os.path.join( base_dir, "worker" + str(self._config.index) + "_" + str(gen_uid.uuid1()) + ".pb") results = [] for batch in self._batch_list: # batch to ExampleBatch example_batch = ExampleBatch() for k, v in batch.items(): named_feature_list = example_batch.named_feature_list.add() named_feature_list.name = k named_feature_list.type = FeatureListType.INDIVIDUAL if isinstance(v, tf.compat.v1.ragged.RaggedTensorValue): lv = v.to_list() else: # np.ndarray lv = v.tolist() for fids in lv: feature = named_feature_list.feature.add() if len(fids) > 0 and isinstance(fids[0], float): feature.float_list.value.extend(fids) else: feature.fid_v2_list.value.extend(fids) example_batch.batch_size = len(lv) results.append(example_batch) with tf.io.gfile.GFile(file_path, "w") as f: for example_batch in results: ss = example_batch.SerializeToString() sz = len(ss) f.write(pack(' 1: self._batch_list.append(run_values.results) if len(self._batch_list) >= self._cap: self._save_features() 
self._batch_list.clear() def end(self, session): if len(self._batch_list) >= 0: self._save_features() self._batch_list.clear() ================================================ FILE: monolith/native_training/hooks/hook_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf class BeforeSaveListener(tf.estimator.CheckpointSaverListener): """Only calls before save in the listener""" def __init__(self, listener: tf.estimator.CheckpointSaverListener): self._listener = listener def before_save(self, session, global_step_value): self._listener.before_save(session, global_step_value) def __repr__(self): return super().__repr__() + repr(self._listener) class AfterSaveListener(tf.estimator.CheckpointSaverListener): """Only calls after save in the listener""" def __init__(self, listener: tf.estimator.CheckpointSaverListener): self._listener = listener def after_save(self, session, global_step_value): self._listener.after_save(session, global_step_value) def __repr__(self): return super().__repr__() + repr(self._listener) ================================================ FILE: monolith/native_training/hooks/hook_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
class HookUtilsTest(tf.test.TestCase):
  """Smoke test for the BeforeSaveListener / AfterSaveListener wrappers."""

  def testBeforeAfterSaverListener(self):
    # This is mainly for testing compiling
    wrapped = tf.estimator.CheckpointSaverListener()
    wrappers = (
        hook_utils.BeforeSaveListener(wrapped),
        hook_utils.AfterSaveListener(wrapped),
    )
    with self.session() as sess:
      for listener in wrappers:
        listener.before_save(sess, 0)
        listener.after_save(sess, 0)
def get_ps_machine_info_shared_name(index: int):
  """Returns the shared name of the machine info resource of PS `index`."""
  return f"ps_machine_info_{index}"


def _default_report(results: Dict[int, logging_ops_pb2.MachineHealthResult]):
  """Logs one error line per unhealthy PS.

  Args:
    results: maps a PS index to its parsed health result.
  """
  debugging_strs = []
  for idx, result in results.items():
    debugging_strs.append(
        f"PS {idx}: {text_format.MessageToString(result, as_one_line=True)}")
  logging.error("PS are not healthy:\n%s", "\n".join(debugging_strs))
  # TODO(leqi.zou): Give some alerts


class Config(NamedTuple):
  """Settings for the PS health checker."""
  # Barrier that is placed once an unhealthy PS has been detected.
  barrier_op: barrier_ops.BarrierOp
  # Number of parameter servers to watch (indices 0 .. num_ps - 1).
  num_ps: int
  # Maps a PS index to the device string its health check op is built under.
  ps_device_fn: Callable[[int], str] = utils.ps_device
  # Receives {ps index: MachineHealthResult} for every unhealthy PS. The
  # value type used to be annotated as `str`, which matched neither
  # `_default_report` nor what the checker passes.
  report_fn: Callable[[Dict[int, logging_ops_pb2.MachineHealthResult]],
                      None] = _default_report


class _PsHealthChecker:
  """Builds one health-check tensor per PS and polls them in a thread."""

  def __init__(self, config: Config):
    self._config = config
    self._machine_status_tensors = []
    for i in range(config.num_ps):
      with tf.device(config.ps_device_fn(i)):
        handle = logging_ops.machine_info(
            shared_name=get_ps_machine_info_shared_name(i))
        self._machine_status_tensors.append(
            logging_ops.check_machine_health(handle))

  def create_threads(self, sess, coord: tf.train.Coordinator):
    """Starts the polling thread and registers it with `coord`."""
    # Daemon is important. It seems that if we have the error in the
    # after_create_session phase, the coordinator will never stop so
    # the process will be stuck forever.
    th = threading.Thread(target=self._run, args=(sess, coord), daemon=True)
    coord.register_thread(th)
    th.start()

  def _run(self, sess, coord: tf.train.Coordinator):
    """Polls PS health roughly every 30 seconds until `coord` stops.

    A non-empty status string marks an unhealthy PS. When at least one is
    found, the results are reported, the barrier is placed and the thread
    then just waits for the coordinator to stop.
    """
    while not coord.should_stop():
      status_list = sess.run(self._machine_status_tensors)
      results = {}
      should_stop = False
      for idx, status in enumerate(status_list):
        if len(status) > 0:
          should_stop = True
          result = logging_ops_pb2.MachineHealthResult()
          result.ParseFromString(status)
          results[idx] = result
      if should_stop:
        self._config.report_fn(results)
        try:
          self._config.barrier_op.place_barrier(sess)
        except barrier_ops.BarrierAlreadyPlacedError:
          # A barrier that somebody else placed already blocks the workers;
          # do not let this thread die with an unhandled exception.
          logging.warning(
              "Barrier is placed already; PS health checker keeps waiting.")
        coord.wait_for_stop()
      coord.wait_for_stop(timeout=30.0)


class PsHealthCheckerHook(tf.estimator.SessionRunHook):
  """Session hook that runs a _PsHealthChecker next to training."""

  def __init__(self, config: Config):
    self._config = config
    self._checker = None

  def begin(self):
    # The health-check ops are added to the graph here.
    self._checker = _PsHealthChecker(self._config)

  def after_create_session(self, session, coord):
    self._checker.create_threads(session, coord)
class PrepareMachineInfoHook(tf.estimator.SessionRunHook):
  """Runs the machine info op once the session exists."""

  def __init__(self, machine_info):
    self._machine_info = machine_info

  def after_create_session(self, session, coord):
    session.run(self._machine_info)


class RaiseErrorHook(tf.estimator.SessionRunHook):
  """Raises a DeadlineExceededError in the selected hook phase(s)."""

  def __init__(self,
               raise_in_after_create_session=False,
               raise_in_before_run=False):
    self.raise_in_after_create_session = raise_in_after_create_session
    self.raise_in_before_run = raise_in_before_run
    self.exc = tf.errors.DeadlineExceededError(None, None, "Test exception")

  def after_create_session(self, session, coord):
    if not self.raise_in_after_create_session:
      return
    raise self.exc

  def before_run(self, run_context):
    if not self.raise_in_before_run:
      return
    print("RAISEd")
    raise self.exc


class PsCheckHooksTest(tf.test.TestCase):
  """Tests for ps_check_hooks.PsHealthCheckerHook."""

  def _set_up_hook(self, report_fn=None, mem_limit=1 << 60):
    """Returns [machine info hook, health checker hook] for a single PS."""
    barrier = barrier_ops.BarrierOp(1)
    effective_report_fn = (report_fn
                           if report_fn else ps_check_hooks._default_report)
    config = ps_check_hooks.Config(barrier_op=barrier,
                                   num_ps=1,
                                   ps_device_fn=lambda idx: None,
                                   report_fn=effective_report_fn)
    machine_info = logging_ops.machine_info(
        mem_limit=mem_limit,
        shared_name=ps_check_hooks.get_ps_machine_info_shared_name(0))
    checker_hook = ps_check_hooks.PsHealthCheckerHook(config)
    return [PrepareMachineInfoHook(machine_info), checker_hook]

  def test_basic(self):
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=self._set_up_hook()):
      time.sleep(1)

  def test_oom(self):
    # A zero memory limit makes the PS count as unhealthy.
    reporter = mock.MagicMock()
    hooks = self._set_up_hook(reporter, mem_limit=0)
    with tf.compat.v1.train.SingularMonitoredSession(hooks=hooks):
      time.sleep(1)
    reporter.assert_called_once()

  def test_raise_in_after_create_session(self):
    hooks = self._set_up_hook()
    with self.assertRaises(tf.errors.DeadlineExceededError):
      failing = RaiseErrorHook(raise_in_after_create_session=True)
      with tf.compat.v1.train.SingularMonitoredSession(hooks=hooks +
                                                       [failing]):
        pass

  def test_raise_in_before_run(self):
    hooks = self._set_up_hook()
    with self.assertRaises(tf.errors.DeadlineExceededError):
      t = tf.constant(1.0)
      failing = RaiseErrorHook(raise_in_before_run=True)
      with tf.compat.v1.train.SingularMonitoredSession(
          hooks=hooks + [failing]) as sess:
        sess.run(t)

  def test_default_report(self):
    # This mainly for grammar check
    ps_check_hooks._default_report({1: logging_ops_pb2.MachineHealthResult()})
def get_stub_from_model_dir(model_dir: str):
  """Builds a Controller stub from the address file stored in `model_dir`.

  Args:
    model_dir: directory containing `constants.SERVER_ADDR_FILENAME`.

  Returns:
    A `ControllerStub` on an insecure channel to the address read from
    that file.
  """
  addr_path = os.path.join(model_dir, constants.SERVER_ADDR_FILENAME)
  with tf.io.gfile.GFile(addr_path, "r") as addr_file:
    server_addr = addr_file.read()
  return service_pb2_grpc.ControllerStub(grpc.insecure_channel(server_addr))
class ControllerServicer(service_pb2_grpc.ControllerServicer):
  """Implements the Controller service on top of a live training session.

  Args:
    sess: session used to run the barrier ops and to read the global step.
    barrier_op: barrier that is placed / removed to stop / resume training.
    saver_hook: hook whose `trigger_save` performs an on-demand save.
  """

  def __init__(self, sess: tf.compat.v1.Session,
               barrier_op: barrier_ops.BarrierOp,
               saver_hook: save_utils.NoFirstSaveCheckpointSaverHook):
    self._sess = sess
    self._saver_hook = saver_hook
    self._barrier_op = barrier_op

  def StopTraining(self, req, ctx):
    """Places the barrier; aborts with ALREADY_EXISTS if one is in place."""
    try:
      self._barrier_op.place_barrier(self._sess)
    except barrier_ops.BarrierAlreadyPlacedError:
      ctx.abort(grpc.StatusCode.ALREADY_EXISTS, "Barrier is placed already.")
    return service_pb2.StopTrainingResponse()

  def ResumeTraining(self, req, ctx):
    """Removes the barrier."""
    self._barrier_op.remove_barrier(self._sess)
    return service_pb2.ResumeTrainingResponse()

  def GetBlockStatus(self, req, ctx):
    """Reports which barrier indices are blocked and which are not."""
    blocked = self._barrier_op.get_blocked_indices(self._sess)
    # Everything within the barrier capacity that is not blocked.
    unblocked = list(set(range(self._barrier_op.capacity)) - set(blocked))
    response = service_pb2.GetBlockStatusResponse()
    response.blocked_indices.extend(blocked)
    response.unblocked_indices.extend(unblocked)
    return response

  def SaveCheckpoint(self, req, ctx):
    """Triggers a checkpoint save through the saver hook."""
    response = service_pb2.SaveCheckpointResponse()
    self._saver_hook.trigger_save(self._sess)
    return response

  def GetTrainingStatus(self, req, ctx):
    """Returns the current global step."""
    response = service_pb2.GetTrainingStatusResponse()
    with self._sess.graph.as_default():
      global_step = tf.compat.v1.train.get_global_step()
      response.global_step = self._sess.run(global_step)
    return response


class ServerHook(tf.estimator.SessionRunHook):
  """Starts the Controller gRPC server once the session is created.

  The server listens on a free port and its address is written atomically
  to `<model_dir>/<constants.SERVER_ADDR_FILENAME>`.
  """

  def __init__(self, model_dir: str, barrier_op: barrier_ops.BarrierOp,
               saver_hook: save_utils.NoFirstSaveCheckpointSaverHook):
    self._model_dir = model_dir
    self._barrier_op = barrier_op
    self._saver_hook = saver_hook
    self._server = None

  def after_create_session(self, session, coord):
    servicer = ControllerServicer(session, self._barrier_op, self._saver_hook)
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
    self._server = grpc.server(executor)
    service_pb2_grpc.add_ControllerServicer_to_server(servicer, self._server)
    # Port 0 lets the OS choose a free port.
    port = self._server.add_insecure_port("[::]:0")
    addr = net_utils.get_local_server_addr(port)
    self._server.start()

    tf.io.gfile.makedirs(self._model_dir)
    addr_path = os.path.join(self._model_dir, constants.SERVER_ADDR_FILENAME)
    file_io.atomic_write_string_to_file(addr_path, addr)

  def end(self, session):
    # Stop with a 20 second grace period.
    self._server.stop(20)
class ServerTest(tf.test.TestCase):
  """Round-trips every Controller RPC against a live ServerHook."""

  def test_basic(self):
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "basic")
    barrier = barrier_ops.BarrierOp(1)
    saver_hook = save_utils.NoFirstSaveCheckpointSaverHook(model_dir,
                                                           save_secs=10000)
    server_hook = server_lib.ServerHook(model_dir, barrier, saver_hook)
    tf.compat.v1.train.create_global_step()

    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[server_hook, saver_hook]) as sess:
      stub = client_lib.get_stub_from_model_dir(model_dir)

      # The first stop succeeds, a second one is rejected.
      stub.StopTraining(service_pb2.StopTrainingRequest())
      with self.assertRaises(grpc.RpcError):
        stub.StopTraining(service_pb2.StopTrainingRequest())

      status = stub.GetBlockStatus(service_pb2.GetBlockStatusRequest())
      self.assertAllEqual(status.blocked_indices, [0])

      stub.ResumeTraining(service_pb2.ResumeTrainingRequest())
      status = stub.GetBlockStatus(service_pb2.GetBlockStatusRequest())
      self.assertAllEqual(status.blocked_indices, [])

      stub.SaveCheckpoint(service_pb2.SaveCheckpointRequest())
      stub.GetTrainingStatus(service_pb2.GetTrainingStatusRequest())
syntax = "proto3";

// Request of Controller.StopTraining. Carries no fields.
message StopTrainingRequest {
}

// Response of Controller.StopTraining. Carries no fields.
message StopTrainingResponse {
}

// Request of Controller.ResumeTraining. Carries no fields.
message ResumeTrainingRequest {
}

// Response of Controller.ResumeTraining. Carries no fields.
message ResumeTrainingResponse {
}

// Request of Controller.GetBlockStatus. Carries no fields.
message GetBlockStatusRequest {
}

// Response of Controller.GetBlockStatus.
message GetBlockStatusResponse {
  // Indices that are currently blocked.
  repeated int32 blocked_indices = 1;
  // Indices that are currently not blocked.
  repeated int32 unblocked_indices = 2;
}

// Request of Controller.SaveCheckpoint. Carries no fields.
message SaveCheckpointRequest {
}

// Response of Controller.SaveCheckpoint. Carries no fields.
message SaveCheckpointResponse {
}

// Request of Controller.GetTrainingStatus. Carries no fields.
message GetTrainingStatusRequest {
}

// Response of Controller.GetTrainingStatus.
message GetTrainingStatusResponse {
  // Global step of the training at the time of the call.
  int64 global_step = 1;
}
@dataclasses.dataclass
class _Info:
  """Process-wide holder of the session set by SetCurrentSessionHook."""
  session: tf.compat.v1.Session = None


_INFO = _Info()


class SetCurrentSessionHook(tf.estimator.SessionRunHook):
  """Publishes the created session through `get_current_session`."""

  def after_create_session(self, session, coord):
    _INFO.session = session

  def end(self, session):
    # Forget the session once it has ended.
    _INFO.session = None


def get_current_session():
  """Returns the current session.

  The session recorded by SetCurrentSessionHook wins when there is one;
  otherwise TensorFlow's default session (possibly None) is returned.
  """
  recorded = _INFO.session
  return recorded if recorded else tf.compat.v1.get_default_session()
class SessionHooksTest(tf.test.TestCase):
  """Checks that SetCurrentSessionHook exposes the running session."""

  def testBasic(self):
    # No hook has run and no default session exists yet.
    self.assertIsNone(session_hooks.get_current_session())
    hook = session_hooks.SetCurrentSessionHook()
    with tf.compat.v1.train.MonitoredSession(hooks=[hook]) as sess:
      sess.run([])
      session_hooks.get_current_session()
      self.assertIsNotNone(session_hooks.get_current_session())
import importlib import os import threading class _Lib: """A lib that will delay import when used.""" def __init__(self): self._lib = None self._lock = threading.Lock() @property def lib(self): with self._lock: if self._lib is None: if self.enable_bps: self._lib = importlib.import_module("byteps.tensorflow") else: self._lib = importlib.import_module("horovod.tensorflow") return self._lib @property def enable_bps(self): return int(os.getenv("MONOLITH_WITH_BYTEPS", "0")) def init(self): return self.lib.init() def rank(self): return self.lib.rank() def size(self): return self.lib.size() def allgather(self, *args, **kwargs): return self.lib.allgather(*args, **kwargs) def broadcast(self, *args, **kwargs): return self.lib.broadcast(*args, **kwargs) def BroadcastGlobalVariablesHook(self, *args, **kwargs): return self.lib.BroadcastGlobalVariablesHook(*args, **kwargs) _lib = _Lib() def __getattr__(name): """Export all method in _Lib class""" return getattr(_lib, name) ================================================ FILE: monolith/native_training/input.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np from typing import List import tensorflow as tf def slot_to_key(slot: int): return "feature_{}".format(slot) def generate_ffm_example(vocab_sizes: List[int], length=5) -> str: """Generate a random training example.""" def _int64_feature(values): return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def _float32_feature(values): return tf.train.Feature(float_list=tf.train.FloatList(value=values)) feature = {} feature["label"] = _float32_feature([np.random.randint(low=0, high=1)]) max_vocab = max(vocab_sizes) for i, vocab_size in enumerate(vocab_sizes): num_ids = np.random.randint(length) + 1 ids = np.random.randint(max_vocab * i, max_vocab * i + vocab_size, size=num_ids).tolist() feature[slot_to_key(i)] = _int64_feature(ids) example_proto = tf.train.Example(features=tf.train.Features(feature=feature)) return example_proto.SerializeToString() ================================================ FILE: monolith/native_training/layers/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_kernel_library") package( default_visibility = ["//visibility:public"], ) py_library( name = "utils", srcs = ["utils.py"], srcs_version = "PY3", deps = [ "//monolith:utils", "//monolith/native_training:monolith_export", "//monolith/native_training:utils", "//monolith/native_training/summary:summary_ops", ], ) py_library( name = "add_bias", srcs = ["add_bias.py"], srcs_version = "PY3", deps = [ ":utils", ], ) py_test( name = "add_bias_test", srcs = ["add_bias_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":add_bias", "//monolith/core:testing_utils", ], ) py_library( name = "dense", srcs = ["dense.py"], srcs_version = "PY3", deps = [ ":utils", ], ) py_test( name = "dense_test", srcs = ["dense_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":dense", 
"//monolith/core:testing_utils", ], ) py_library( name = "advanced_activations", srcs = ["advanced_activations.py"], srcs_version = "PY3", deps = [ ":utils", ], ) py_test( name = "advanced_activations_test", srcs = ["advanced_activations_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":advanced_activations", "//monolith/core:testing_utils", ], ) py_library( name = "agru", srcs = ["agru.py"], srcs_version = "PY3", deps = [ ":utils", ], ) py_test( name = "agru_test", srcs = ["agru_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":agru", "//monolith/core:testing_utils", ], ) py_library( name = "norms", srcs = [ "norms.py", ], srcs_version = "PY3", deps = [ ":utils", ], ) py_test( name = "norms_test", srcs = ["norms_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":norms", "//monolith/core:testing_utils", ], ) py_library( name = "mlp", srcs = [ "mlp.py", ], srcs_version = "PY3", deps = [ ":advanced_activations", ":dense", ":norms", ":utils", ], ) py_test( name = "mlp_test", srcs = ["mlp_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":mlp", "//monolith/core:testing_utils", ], ) py_library( name = "feature_cross", srcs = [ "feature_cross.py", ], srcs_version = "PY3", deps = [ ":agru", ":layer_ops", ":mlp", ":utils", "//monolith/core:base_layer", ], ) py_test( name = "feature_cross_test", srcs = ["feature_cross_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":feature_cross", "//monolith/core:testing_utils", ], ) py_library( name = "feature_trans", srcs = [ "feature_trans.py", ], srcs_version = "PY3", deps = [ ":agru", ":mlp", ":utils", "//monolith/core:base_layer", ], ) py_test( name = "feature_trans_test", srcs = ["feature_trans_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":feature_trans", "//monolith/core:testing_utils", ], ) py_library( name = "feature_seq", srcs = [ "feature_seq.py", ], srcs_version = "PY3", deps = [ ":agru", ":mlp", ":utils", 
"//monolith/core:base_layer", ], ) py_test( name = "feature_seq_test", srcs = ["feature_seq_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":feature_seq", "//monolith/core:testing_utils", ], ) py_library( name = "pooling", srcs = [ "pooling.py", ], srcs_version = "PY3", deps = [ ":utils", ], ) py_test( name = "pooling_test", srcs = ["pooling_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":pooling", "//monolith/core:testing_utils", ], ) py_library( name = "logit_correction", srcs = [ "logit_correction.py", ], srcs_version = "PY3", deps = [ ":mlp", ":utils", ], ) py_test( name = "logit_correction_test", srcs = ["logit_correction_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":logit_correction", "//monolith/core:testing_utils", ], ) py_library( name = "lhuc", srcs = [ "lhuc.py", ], srcs_version = "PY3", deps = [ ":advanced_activations", ":dense", ":mlp", ":utils", ], ) py_test( name = "lhuc_test", srcs = ["lhuc_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":lhuc", "//monolith/core:testing_utils", ], ) py_library( name = "multi_task", srcs = [ "multi_task.py", ], srcs_version = "PY3", deps = [ ":advanced_activations", ":dense", ":mlp", ":utils", ], ) py_test( name = "multi_task_test", srcs = ["multi_task_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":multi_task", "//monolith/core:testing_utils", ], ) py_library( name = "layers", srcs = [ "__init__.py", ], srcs_version = "PY3", deps = [ ":add_bias", ":advanced_activations", ":agru", ":dense", ":feature_cross", ":feature_seq", ":feature_trans", ":lhuc", ":logit_correction", ":mlp", ":multi_task", ":norms", ":pooling", ":sparse_nas", "//monolith/native_training:utils", ], ) cc_library( name = "internal_kernels", alwayslink = 1, ) tf_kernel_library( name = "layer_tf_ops", srcs = [ "kernels/ffm_kernels.cc", "kernels/ffm_kernels.h", "kernels/feature_insight_kernels.cc", "kernels/fid_counter_kernel.cc", "ops/ffm_ops.cc", 
"ops/nas_ops.cc", "ops/feature_insight_ops.cc", "ops/fid_counter_op.cc", ], copts = ["-DNDEBUG"], gpu_srcs = [ "kernels/ffm_kernels.h", "kernels/ffm_kernels.cu.cc", ], deps = [ ":internal_kernels", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:gpu_device_array_for_custom_op", ], ) py_library( name = "layer_ops", srcs = ["layer_ops.py"], deps = [ "//monolith:utils", "//monolith/core:testing_utils", "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "layer_ops_test", srcs = ["layer_ops_test.py"], python_version = "PY3", srcs_version = "PY3", deps = [ ":layer_ops", ], ) py_library( name = "sparse_nas", srcs = ["sparse_nas.py"], deps = [ ":layer_ops", ":utils", "//monolith/native_training/data:feature_list", "@org_tensorflow//tensorflow:tensorflow_py", ], ) exports_files([ "kernels/ffm_kernels.cc", "kernels/ffm_kernels.cu.cc", "kernels/feature_insight_kernels.cc", "ops/ffm_ops.cc", "kernels/nas_kernels.cc", "ops/nas_ops.cc", "ops/feature_insight_ops.cc", "ops/fid_counter_ops.cc", ]) ================================================ FILE: monolith/native_training/layers/README.md ================================================ The layers in Monolith are a super set of tensorflow keras layers. Monolith adds/enhances the following layers: - Dense - MLP - AddBias - LayerNorm/GradNorm - GroupInt/AllInt/CDot/CAN/DCN/CIN/AutoInt/SeNet/iRazor/DIN/DIEN/DMR_U2I - LogitCorrection - SumPooling/AvgPooling/MaxPooling Monolith layers are compatible with keras layers, that means you can mix usage of keras layers and monolith layers. 
Here is an example of creating a Monolith layer:

```python
import tensorflow as tf
from monolith.native_training.layers import Dense

# the first way to create a monolith layer, which is the same as in keras
dense = Dense(units=100, activation=tf.keras.activations.relu)

# the second way to create a monolith layer
dense_p = Dense.params()
dense_p.units=100
dense_p.activation=tf.keras.activations.relu
dense2 = dense_p.instantiate()

model = tf.keras.Sequential([
  dense,  # created from the constructor
  dense2,  # created from params via instantiate()
  tf.keras.layers.Dense(units=100, activation=tf.keras.activations.relu)  # mixed use
])
```

As shown above, there are two ways to create a layer: one is to call the constructor directly, the other is to use the `instantiate` method of `InstantiableParams` (obtained from `params()`).

In most cases, just replace:

```python
from tensorflow.keras import layers
```

with

```python
from monolith.native_training import layers
```

We prefer the Monolith layers.


================================================
FILE: monolith/native_training/layers/__init__.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types import tensorflow as tf from tensorflow.keras.layers import * from monolith.native_training.layers.mlp import MLP from monolith.native_training.layers.feature_cross import * from monolith.native_training.layers.feature_trans import * from monolith.native_training.layers.feature_seq import * from monolith.native_training.layers.advanced_activations import * from monolith.native_training.layers.add_bias import AddBias from monolith.native_training.layers.lhuc import LHUCTower from monolith.native_training.layers.logit_correction import LogitCorrection from monolith.native_training.layers.norms import LayerNorm, GradNorm from monolith.native_training.layers.pooling import SumPooling, AvgPooling, MaxPooling from monolith.native_training.layers.utils import MergeType, DCNType from monolith.native_training.layers.multi_task import MMoE, SNR from monolith.native_training.utils import params as _params del globals()['Dense'] from monolith.native_training.layers.dense import Dense keras_layers = {} for name in dir(tf.keras.layers): if name.startswith("_") or name == "Layer": continue cls = getattr(tf.keras.layers, name) try: if issubclass(cls, Layer) and not hasattr(cls, 'params'): cls.params = types.MethodType(_params, cls) keras_layers[name] = cls except: pass ================================================ FILE: monolith/native_training/layers/add_bias.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf from tensorflow.keras.layers import Layer, InputSpec from tensorflow.keras import initializers from tensorflow.python.keras import regularizers from monolith.native_training.utils import get_ndim, int_shape, with_params from monolith.native_training.monolith_export import monolith_export from monolith.native_training.layers.utils import check_dim, dim_size @monolith_export @with_params class AddBias(Layer): r"""AddBias 执行 :math:`y = x + b`, 与直接用`+`相比, AddBias处理了更多的shape问题 例如image有两种表示方式NWHC, NCWH, 对于时间序列也有类似的问题. AddBias可以让用户透明增加Bias >>> add_bias = AddBias(initializer=tf.initializers.Zeros()) >>> y = add_bias(x, data_format='channels_first') Args: initializer (:obj:`tf.initializer`): bias的初始化器 regularizer (:obj:`tf.regularizer`): bias的正则化器 """ def __init__(self, initializer=None, regularizer=None, **kwargs): super(AddBias, self).__init__(**kwargs) self.initializer = initializers.get(initializer) or tf.initializers.Zeros() self.regularizer = regularizers.get(regularizer) # allowed input specification self.input_spec = InputSpec(min_ndim=2) self.bias = None def build(self, input_shape): shape = list(map(check_dim, input_shape[1:])) self.bias = self.add_weight(name='bias', shape=shape, dtype=tf.float32, initializer=self.initializer, regularizer=self.regularizer) def call(self, inputs, **kwargs): data_format = kwargs.get('data_format', 'channels_last') if data_format not in {'channels_first', 'channels_last'}: raise ValueError('Unknown data_format: ' + str(data_format)) bias_shape = int_shape(self.bias) if len(bias_shape) != 1 and len(bias_shape) != get_ndim(inputs) - 1: raise ValueError( 'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' % (len(bias_shape), get_ndim(inputs))) if get_ndim(inputs) == 5: if data_format == 'channels_first': if len(bias_shape) == 1: inputs += tf.reshape(self.bias, (1, bias_shape[0], 1, 1, 1)) else: 
inputs += tf.reshape(self.bias, (1, bias_shape[3]) + bias_shape[:3]) elif data_format == 'channels_last': if len(bias_shape) == 1: inputs += tf.reshape(self.bias, (1, 1, 1, bias_shape[0])) else: inputs += tf.reshape(self.bias, (1,) + bias_shape) elif get_ndim(inputs) == 4: if data_format == 'channels_first': if len(bias_shape) == 1: inputs += tf.reshape(self.bias, (1, bias_shape[0], 1, 1)) else: inputs += tf.reshape(self.bias, (1, bias_shape[2]) + bias_shape[:2]) elif data_format == 'channels_last': if len(bias_shape) == 1: inputs = tf.nn.bias_add(inputs, self.bias, data_format='NHWC') else: inputs += tf.reshape(self.bias, (1,) + bias_shape) elif get_ndim(inputs) == 3: if data_format == 'channels_first': if len(bias_shape) == 1: inputs += tf.reshape(self.bias, (1, bias_shape[0], 1)) else: inputs += tf.reshape(self.bias, (1, bias_shape[1], bias_shape[0])) elif data_format == 'channels_last': if len(bias_shape) == 1: inputs += tf.reshape(self.bias, (1, 1, bias_shape[0])) else: inputs += tf.reshape(self.bias, (1,) + bias_shape) else: inputs = tf.nn.bias_add(inputs, self.bias) return inputs def get_config(self): config = { 'initializer': tf.keras.initializers.serialize(self.initializer), 'regularizer': regularizers.serialize(self.regularizer), } base_config = super(AddBias, self).get_config() return dict(list(base_config.items()) + list(config.items())) ================================================ FILE: monolith/native_training/layers/add_bias_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import tensorflow as tf from monolith.native_training.layers.add_bias import AddBias class AddBiasTest(tf.test.TestCase): def test_ab_instantiate(self): layer_template = AddBias.params() test_params0 = layer_template.copy() test_params0.initializer = tf.initializers.Zeros() ins1 = test_params0.instantiate() print(ins1) ins2 = AddBias(initializer=tf.initializers.Zeros()) print(ins2) def test_ab_serde(self): layer_template = AddBias.params() test_params0 = layer_template.copy() test_params0.initializer = tf.initializers.Zeros() ins1 = test_params0.instantiate() print(ins1) cfg = ins1.get_config() ins2 = AddBias.from_config(cfg) print(ins1, ins2) def test_ab_call(self): layer_template = AddBias.params() test_params0 = layer_template.copy() test_params0.name = 'test_dense0' test_params0.initializer = tf.initializers.Zeros() layer = test_params0.instantiate() data = tf.keras.backend.variable(np.random.uniform(size=(100, 10))) sum_out = tf.reduce_sum(layer(data)) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/layers/advanced_activations.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import types import tensorflow.keras.initializers as initializers import tensorflow.keras.constraints as constraints from tensorflow.python.keras.activations import exponential from tensorflow.python.keras.activations import gelu from tensorflow.python.keras.activations import hard_sigmoid from tensorflow.python.keras.activations import linear from tensorflow.python.keras.activations import selu from tensorflow.python.keras.activations import sigmoid from tensorflow.python.keras.activations import softplus from tensorflow.python.keras.activations import softsign from tensorflow.python.keras.activations import swish from tensorflow.python.keras.activations import tanh from tensorflow.python.keras.layers import Layer from tensorflow.python.keras.layers.advanced_activations import ReLU from tensorflow.python.keras.layers.advanced_activations import LeakyReLU from tensorflow.python.keras.layers.advanced_activations import ELU from tensorflow.python.keras.layers.advanced_activations import Softmax from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU from tensorflow.python.keras.layers.advanced_activations import PReLU from monolith.native_training.utils import params as _params from monolith.native_training.monolith_export import monolith_export __all__ = [ 'ReLU', 'LeakyReLU', 'ELU', 'Softmax', 'ThresholdedReLU', 'PReLU', 'Exponential', 'Gelu', 'HardSigmoid', 'Linear', 'Selu', 'Sigmoid', 'Sigmoid2', 'Softplus', 'Softsign', 'Swish', 'Tanh' ] Tanh = type('Tanh', (Layer,), {'call': lambda self, x: tanh(x)}) Sigmoid = type('Sigmoid', (Layer,), 
{'call': lambda self, x: sigmoid(x)}) Sigmoid2 = type('Sigmoid2', (Layer,), {'call': lambda self, x: sigmoid(x) * 2}) Linear = type('Linear', (Layer,), {'call': lambda self, x: linear(x)}) Gelu = type('Gelu', (Layer,), {'call': lambda self, x: gelu(x)}) Selu = type('Selu', (Layer,), {'call': lambda self, x: selu(x)}) Softsign = type('Softsign', (Layer,), {'call': lambda self, x: softsign(x)}) Softplus = type('Softplus', (Layer,), {'call': lambda self, x: softplus(x)}) Exponential = type('Exponential', (Layer,), {'call': lambda self, x: exponential(x)}) HardSigmoid = type('HardSigmoid', (Layer,), {'call': lambda self, x: hard_sigmoid(x)}) Swish = type('Swish', (Layer,), {'call': lambda self, x: swish(x)}) ReLU.params = types.MethodType(_params, ReLU) PReLU.params = types.MethodType(_params, PReLU) LeakyReLU.params = types.MethodType(_params, LeakyReLU) ELU.params = types.MethodType(_params, ELU) Softmax.params = types.MethodType(_params, Softmax) ThresholdedReLU.params = types.MethodType(_params, ThresholdedReLU) Tanh.params = types.MethodType(_params, Tanh) Sigmoid.params = types.MethodType(_params, Sigmoid) Sigmoid2.params = types.MethodType(_params, Sigmoid2) Linear.params = types.MethodType(_params, Linear) Gelu.params = types.MethodType(_params, Gelu) Selu.params = types.MethodType(_params, Selu) Softsign.params = types.MethodType(_params, Softsign) Softplus.params = types.MethodType(_params, Softplus) Exponential.params = types.MethodType(_params, Exponential) HardSigmoid.params = types.MethodType(_params, HardSigmoid) Swish.params = types.MethodType(_params, Swish) __all_activations = { 'exponential': Exponential, 'gelu': Gelu, 'hard_sigmoid': HardSigmoid, 'hardsigmoid': HardSigmoid, 'linear': Linear, 'selu': Selu, 'sigmoid': Sigmoid, 'sigmoid2': Sigmoid2, 'softplus': Softplus, 'softsign': Softsign, 'swish': Swish, 'tanh': Tanh, 'leakyrelu': LeakyReLU, 'relu': ReLU, 'elu': ELU, 'softmax': Softmax, 'thresholdedrelu': ThresholdedReLU, 'prelu': PReLU } 
ALL_ACTIVATION_NAMES = set(__all_activations.keys()) @monolith_export def get(identifier): """获取函数, 可以用名字获取, 也可以用序列化的Json获取 Args: identifier (:obj:`Any`): 标识, 可以是name, 获序列化的Json, None等 Returns: 激活函数 """ if identifier is None: return None if isinstance(identifier, str): if identifier.lower() in __all_activations: return __all_activations[identifier.lower()]() else: evaled = eval(identifier) if isinstance(evaled, dict): return deserialize(evaled) raise TypeError( 'Could not interpret activation function identifier: {}'.format( identifier)) elif isinstance(identifier, dict): return deserialize(identifier) elif callable(identifier): if hasattr(identifier, 'params'): try: if issubclass(identifier, Layer): return identifier() else: return identifier except: return identifier elif isinstance(identifier, Layer): name = identifier.__class__.__name__.lower() return __all_activations[name]() else: try: name = identifier.__name__ return __all_activations[name]() except: return identifier else: raise TypeError( 'Could not interpret activation function identifier: {}'.format( identifier)) @monolith_export def serialize(activation): """序列化激活函数 Args: activation (:obj:`tf.activation`): keras激活函数 Returns: Dict/Json 获序列化的激活函数 """ if isinstance(activation, (Linear, Exponential, Selu, HardSigmoid, Gelu, Sigmoid, Softplus, Softsign, Swish, Tanh)): return repr({'name': activation.__class__.__name__}) elif isinstance(activation, (LeakyReLU, ELU)): return repr({ 'name': activation.__class__.__name__, 'alpha': float(activation.alpha) }) elif isinstance(activation, ReLU): return repr({ 'name': 'ReLU', 'max_value': activation.max_value, 'negative_slope': float(activation.negative_slope), 'threshold': float(activation.threshold) }) elif isinstance(activation, PReLU): return repr({ 'name': 'PReLU', 'alpha_initializer': initializers.serialize(activation.alpha_initializer), 'alpha_regularizer': initializers.serialize(activation.alpha_regularizer), 'alpha_constraint': 
constraints.serialize(activation.alpha_constraint), 'shared_axes': activation.shared_axes }) elif isinstance(activation, Softmax): return repr({'name': 'Softmax', 'axis': activation.axis}) elif isinstance(activation, ThresholdedReLU): return repr({'name': 'ThresholdedReLU', 'theta': float(activation.theta)}) else: return None @monolith_export def deserialize(identifier): """反序列化激活函数 Args: identifier (:obj:`Any`): 标识, 可以是name, 获序列化的Json, None等 Returns: 激活函数 """ if identifier is None: return None else: if not isinstance(identifier, dict): identifier = eval(identifier) assert isinstance(identifier, dict) name = identifier['name'].lower() assert name in __all_activations identifier.pop('name') return __all_activations[name](**identifier) ================================================ FILE: monolith/native_training/layers/advanced_activations_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import tensorflow as tf import tensorflow.keras.activations as acts import tensorflow.keras.layers as lyacts from monolith.native_training.layers.advanced_activations import get, serialize def serde(act): _act = get(act) sered_act = serialize(_act) get(sered_act) all_acts = [ 'relu', 'leakyrelu', 'elu', 'softmax', 'thresholdedrelu', 'prelu', 'exponential', 'gelu', 'hardsigmoid', 'linear', 'selu', 'sigmoid', 'softplus', 'softsign', 'swish', 'tanh' ] raw_acts = [ acts.tanh, acts.sigmoid, acts.softsign, acts.softplus, acts.softmax, acts.exponential, acts.elu, acts.gelu, acts.hard_sigmoid, acts.selu, acts.swish, acts.relu, acts.linear ] lay_acts = [ lyacts.ReLU(), lyacts.PReLU(), lyacts.ThresholdedReLU(), lyacts.ELU(), lyacts.Softmax(), lyacts.LeakyReLU() ] class ActivationsTest(tf.test.TestCase): def test_get_from_str(self): for act in all_acts: serde(act) def test_get_from_layers(self): for act in lay_acts: serde(act) def test_get_from_func(self): for act in lay_acts: serde(act) def test_params(self): for act in all_acts: cls = get(act).__class__ p = cls.params() # print(p.new_instance()) def test_call(self): inp = tf.random.uniform(shape=(100, 200)) out = [] for act in all_acts: out.append(get(act)(inp)) sum_out = tf.reduce_sum(tf.add_n(out)) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/layers/agru.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf from tensorflow.python.eager import context from tensorflow.python.keras import backend from tensorflow.python.keras import activations from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_spec import InputSpec import tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl as rnn_impl from monolith.native_training.utils import with_params from monolith.native_training.monolith_export import monolith_export from monolith.native_training.layers.utils import check_dim, dim_size _hasattr = rnn_impl._hasattr _concat = rnn_cell_impl._concat _zero_state_tensors = rnn_cell_impl._zero_state_tensors _BIAS_VARIABLE_NAME = "bias" _WEIGHTS_VARIABLE_NAME = "kernel" # This can be used with self.assertRaisesRegexp for assert_like_rnncell. ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell" __all__ = ['AGRUCell', 'dynamic_rnn_with_attention'] @monolith_export @with_params class AGRUCell(Layer): """带attention的GRU单元, 用于DIEN中. 
Args: units (:obj:`int`): GRU隐含层大小 att_type (:obj:`str`): attention方式, 支持两种AGRU/AUGRU activation (:obj:`tf.activation`): 激活函数 initializer (:obj:`tf.initializer`): kernel初始化器 regularizer (:obj:`tf.regularizer`): kernel正则化 """ def __init__(self, units, att_type='AGRU', activation=None, initializer=None, regularizer=None, **kwargs): super(AGRUCell, self).__init__(**kwargs) # Inputs must be 2-dimensional. assert att_type.upper() in {'AGRU', 'AUGRU'} self.input_spec = [ InputSpec(ndim=2), InputSpec(ndim=2), InputSpec(max_ndim=2) ] self.units = units self.att_type = att_type self.activation = activations.get(activation or math_ops.tanh) self.initializer = tf.initializers.get( initializer) or tf.initializers.HeNormal() self.regularizer = regularizers.get(regularizer) @property def state_size(self): return self.units @property def output_size(self): return self.units def build(self, inputs_shape): input_shape, state_shape, att_shape = inputs_shape assert check_dim(state_shape[-1]) == self.units input_depth = check_dim(input_shape[-1]) if input_shape[-1] == -1: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % str(inputs_shape)) self._gate_kernel = self.add_weight( name="gates/{}".format(_WEIGHTS_VARIABLE_NAME), dtype=self.dtype, shape=[input_depth + self.units, 2 * self.units], initializer=self.initializer, regularizer=self.regularizer) self._gate_bias = self.add_weight( name="gates/{}".format(_BIAS_VARIABLE_NAME), dtype=self.dtype, shape=[2 * self.units], initializer=initializers.Ones(), regularizer=self.regularizer) self._candidate_kernel = self.add_weight( name="candidate/{}".format(_WEIGHTS_VARIABLE_NAME), dtype=self.dtype, shape=[input_depth + self.units, self.units], initializer=self.initializer, regularizer=self.regularizer) self._candidate_bias = self.add_weight( name="candidate/{}".format(_BIAS_VARIABLE_NAME), dtype=self.dtype, shape=[self.units], initializer=initializers.Ones(), regularizer=self.regularizer) super(AGRUCell, 
self).build(inputs_shape) def call(self, inputs, **kwargs): x, state, att_score = inputs gate_inputs = math_ops.matmul(array_ops.concat([x, state], 1), self._gate_kernel) gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) value = math_ops.sigmoid(gate_inputs) r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) r_state = r * state candidate = math_ops.matmul(array_ops.concat([x, r_state], 1), self._candidate_kernel) candidate = nn_ops.bias_add(candidate, self._candidate_bias) c = self.activation(candidate) if att_score is None: new_h = (1.0 - u) * state + u * c else: if self.att_type.upper() == 'AUGRU': # GRU with attentional update gate(AUGRU) u = (1.0 - att_score) * u new_h = u * state + (1 - u) * c else: # self.att_type.upper() == 'AGRU': # Attention based GRU(AGRU) new_h = (1. - att_score) * state + att_score * c return new_h, new_h def zero_state(self, batch_size, dtype): # Try to use the last cached zero_state. This is done to avoid recreating # zeros, especially when eager execution is enabled. 
state_size = self.state_size is_eager = context.executing_eagerly() if is_eager and _hasattr(self, "_last_zero_state"): (last_state_size, last_batch_size, last_dtype, last_output) = getattr(self, "_last_zero_state") if (last_batch_size == batch_size and last_dtype == dtype and last_state_size == state_size): return last_output with backend.name_scope(type(self).__name__ + "ZeroState"): output = _zero_state_tensors(state_size, batch_size, dtype) if is_eager: self._last_zero_state = (state_size, batch_size, dtype, output) return output def get_config(self): config = { "units": self.units, "att_type": self.att_type, "initializer": initializers.serialize(self.initializer), "activation": activations.serialize(self.activation), 'regularizer': regularizers.serialize(self.regularizer) } base_config = super(AGRUCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) @monolith_export def create_ta(name, size, dtype): """创建Tensor Array, 一般用于while循环中存放中间结果. Args: name (:obj:`str`): Array名称 size (:obj:`int`): Array大小 dtype (:obj:`tf.DType`): 数据类型 """ return tensor_array_ops.TensorArray(dtype=dtype, size=size, tensor_array_name=name) @monolith_export def static_rnn_with_attention(cell, inputs, att_scores, init_state=None): """带Attention的静态RNN, 利用python for循环直接将时间维度静态展开, 模型大小会增大 Args: cell (:obj:`RNNCell`): RNN单元 inputs (:obj:`tf.Tensor`): 输入数据, shape为(batch_size, seq_len, emb_size) att_scores (:obj:`tf.Tensor`): attention权重, shape为(batch_size, seq_len) init_state (:obj:`tf.Tensor`): 初始化状态 """ assert isinstance(cell, AGRUCell) if init_state is None: batch_size = dim_size(inputs, 0) if getattr(cell, "get_initial_state", None) is not None: state = cell.get_initial_state(inputs=None, batch_size=batch_size, dtype=dtype) else: state = cell.zero_state(batch_size, inputs.dtype) else: state = init_state inputs, outputs = tf.unstack(tf.transpose(inputs, [1, 0, 2])), [] for time, inp in enumerate(inputs): attr = tf.reshape(att_scores[:, time], shape=(-1, 1)) 
cell_out, new_state = cell((inp, state, attr)) state = new_state outputs.append(state) outputs = tf.transpose(tf.stack(outputs), [1, 0, 2]) return outputs, state @monolith_export def dynamic_rnn_with_attention(cell, inputs, att_scores, parallel_iterations=1, swap_memory=True, init_state=None): """带Attention的动态RNN, 得用tf.while实现, 模型大小不会增大 Args: cell (:obj:`RNNCell`): RNN单元 inputs (:obj:`tf.Tensor`): 输入数据, shape为(batch_size, seq_len, emb_size) att_scores (:obj:`tf.Tensor`): attention权重, shape为(batch_size, seq_len) parallel_iterations (:obj:`int`): 并行迭代次数, 具体请参考`control_flow_ops.while_loop` swap_memory (:obj:`bool`): 是否swap内存, 具体请参考`control_flow_ops.while_loop` init_state (:obj:`tf.Tensor`): 初始化状态 """ assert isinstance(cell, AGRUCell) batch_size, time_steps = dim_size(inputs, 0), dim_size(inputs, 1) time = array_ops.constant(0, dtype=tf.dtypes.int32, name="time") if init_state is None: if getattr(cell, "get_initial_state", None) is not None: state = cell.get_initial_state(inputs=None, batch_size=batch_size, dtype=dtype) else: state = cell.zero_state(batch_size, inputs.dtype) else: state = init_state with ops.name_scope("dynamic_rnn"): output_ta = create_ta("output_ta", time_steps, inputs.dtype) input_ta = create_ta("input_ta", time_steps, inputs.dtype) # [batch_size, time, emb_dim] -> [time, batch_size, emb_dim] input_ta = input_ta.unstack(tf.transpose(inputs, [1, 0, 2])) def _body(time, output_ta, state, att_scores): att_score = tf.reshape(att_scores[:, time], shape=(-1, 1)) # [bz, 1] cell_out, new_state = cell((input_ta.read(time), state, att_score)) output_ta = output_ta.write(time, cell_out) return (time + 1, output_ta, new_state, att_scores) _, output_final, final_state, _ = control_flow_ops.while_loop( cond=lambda time, *_: time < time_steps, body=_body, loop_vars=(time, output_ta, state, att_scores), parallel_iterations=parallel_iterations, swap_memory=swap_memory) outputs = output_final.stack() outputs = tf.transpose(outputs, [1, 0, 2]) outputs.set_shape([None, 
time_steps, dim_size(outputs, -1)]) return outputs, final_state ================================================ FILE: monolith/native_training/layers/agru_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import tensorflow as tf from monolith.native_training.layers.agru import AGRUCell, \ dynamic_rnn_with_attention, static_rnn_with_attention class AGRUTest(tf.test.TestCase): def test_agru_instantiate(self): dense_layer_template = AGRUCell.params() test_params0 = dense_layer_template.copy() test_params0.name = 'test_dense0' test_params0.units = 10 test_params0.activation = tf.keras.activations.sigmoid test_params0.initializer = tf.keras.initializers.GlorotNormal() mlp1 = test_params0.instantiate() print(mlp1) mlp2 = AGRUCell(units=10, activation=tf.keras.activations.sigmoid, initializer=tf.keras.initializers.HeUniform()) print(mlp2) def test_agru_serde(self): mlp1 = AGRUCell(units=10, activation=tf.keras.activations.sigmoid, initializer=tf.keras.initializers.HeUniform()) cfg = mlp1.get_config() mlp2 = AGRUCell.from_config(cfg) print(mlp1, mlp2) def test_agru_call(self): dense_layer_template = AGRUCell.params() test_params0 = dense_layer_template.copy() test_params0.name = 'test_dense0' test_params0.units = 10 test_params0.activation = tf.keras.activations.sigmoid test_params0.initializer = tf.keras.initializers.GlorotNormal() layer = 
test_params0.instantiate() print(layer) data = tf.keras.backend.variable(np.ones((100, 100))) state = tf.keras.backend.variable(np.ones((100, 10))) attr = tf.keras.backend.variable(np.ones((100, 1))) _, out = layer((data, state, attr)) sum_out = tf.reduce_sum(out) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_agru_static_rnn_call(self): dense_layer_template = AGRUCell.params() test_params0 = dense_layer_template.copy() test_params0.name = 'test_dense0' test_params0.units = 10 test_params0.activation = tf.keras.activations.sigmoid test_params0.initializer = tf.keras.initializers.GlorotNormal() cell = test_params0.instantiate() print(cell) data = tf.keras.backend.variable(np.ones((100, 20, 10))) attr = tf.keras.backend.variable(np.ones((100, 20))) _, out = static_rnn_with_attention(cell, inputs=data, att_scores=attr) sum_out = tf.reduce_sum(out) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_agru_dynamic_rnn_call(self): dense_layer_template = AGRUCell.params() test_params0 = dense_layer_template.copy() test_params0.name = 'test_dense0' test_params0.units = 10 test_params0.activation = tf.keras.activations.sigmoid test_params0.initializer = tf.keras.initializers.GlorotNormal() cell = test_params0.instantiate() print(cell) data = tf.random.uniform(shape=(100, 20, 10)) attr = tf.random.uniform(shape=(100, 20)) _, out = dynamic_rnn_with_attention(cell, inputs=data, att_scores=attr) sum_out = tf.reduce_sum(out) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.compat.v1.disable_v2_behavior() tf.test.main() ================================================ FILE: monolith/native_training/layers/dense.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf from tensorflow.python.ops import variables as variable_ops from tensorflow.python.keras.layers import Layer from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import activations from tensorflow.python.keras import regularizers from tensorflow.python.keras import backend as K from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.layers.ops import core as core_ops from monolith.native_training.utils import with_params, get_uname from monolith.native_training.monolith_export import monolith_export @monolith_export @with_params class Dense(Layer): """Dense Layer实现 :math:`y = active(wx + b)`. 之所以要重新实现一个Dense Layer, 是因为增加的了些额外的操作, 如kernel_norm, 论文可参考 https://arxiv.org/pdf/1602.07868.pdf kernel_norm的计算方式为: .. math:: y = active( norm_{kernel} * l2_{normalize}(W) x + b) 先对W求 :math:`l2_{normalize}`, 将其取值限制在[-1, 1]之间, 然后乘以 :math:`norm_{kernel}`, 这样 :math:`norm_{kernel} * l2_{normalize}(W)` 的取值在 [-kernel_norm, kernel_norm]之间, 可以有效地防止梯度爆炸. :math:`norm_{kernel}` 一般由W的初值决定, 有 :math:`norm_{kernel} = morm(W_{init})`. 也可让 :math:`norm_{kernel}` 成为trainable, 让算法自已调节. 
Args: units (:obj:`tf.Tensor`): 输入, 也就是x activation (:obj:`tf.activation`, `str`): 激活函数, 可以用str表示, 也可以用TF中的activation use_bias (:obj:`bool`): 是否使用bias kernel_initializer (:obj:`tf.initializer`): kernel, 也就是W的初始化器 bias_initializer (:obj:`tf.initializer`): bias, 也就是b的初始化器 bias_regularizer (:obj:`tf.regularizer`): bias正侧化 allow_kernel_norm (:obj:`bool`): 是否开启kernel_norm kernel_norm_trainable (:obj:`bool`): 是否让kernel_norm可训练 partitioner (:obj:`tf.partitioner`, optional): 分区器, 可以将一个大变量分到不同的PS机器上 inactive_relu_monitor (:obj:`bool`): 是否开启relu_monitor inactive_relu_monitor_decay (:obj:`float`): 因为relu的非0率是用指数平均来计算的, decay就是衰减因子 optimizer (:obj:`tf.optimizer`): 优化器, 请参考TF >>> dense = Dense(units=100, >>> activation=tf.keras.activations.sigmoid, >>> kernel_initializer=tf.keras.initializers.GlorotNormal()) >>> y = dense(x) """ def __init__(self, units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, allow_kernel_norm=False, kernel_norm_trainable=False, partitioner=None, inactive_relu_monitor=False, inactive_relu_monitor_decay=0.1, optimizer=None, **kwargs): if 'input_shape' not in kwargs and 'input_dim' in kwargs: kwargs['input_shape'] = (kwargs.pop('input_dim'),) # Call the _init__() function for tf.keras.layers.Dense super(Dense, self).__init__(**kwargs) # Change/Add some class properties to the tf.keras.layers.Dense # properties. Note that this Dense layer does not support regularizers # and constraints. 
self.units = units self.activation = activations.get(activation) self.use_bias = use_bias self.kernel_initializer = initializers.get(kernel_initializer or 'glorot_uniform') self.bias_initializer = initializers.get(bias_initializer) self.kernel_var = None self.supports_masking = True self.input_spec = InputSpec(min_ndim=2) self.kernel_regularizer = regularizers.get(kernel_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) self.allow_kernel_norm = allow_kernel_norm self.kernel_norm_trainable = kernel_norm_trainable self.partitioner = partitioner self.inactive_relu_monitor = inactive_relu_monitor self.inactive_relu_monitor_decay = inactive_relu_monitor_decay self.optimizer = optimizer def add_weight(self, name=None, shape=None, dtype=None, initializer=None, regularizer=None, trainable=None, constraint=None, use_resource=None, synchronization=tf.VariableSynchronization.AUTO, aggregation=tf.VariableAggregation.NONE, **kwargs): var = super().add_weight(name=name, shape=shape, dtype=dtype, initializer=initializer, regularizer=regularizer, trainable=trainable, constraint=constraint, use_resource=use_resource, synchronization=synchronization, aggregation=aggregation, **kwargs) if isinstance(var, tf.Variable): var.optimizer = self.optimizer elif isinstance(var, variable_ops.PartitionedVariable): for var_p in var: var_p.optimizer = self.optimizer return var def get_variable(self, name, shape=None, dtype=None, initializer=None, regularizer=None, trainable=None, collections=None, caching_device=None, partitioner=None, validate_shape=True, use_resource=None, custom_getter=None, constraint=None, synchronization=tf.VariableSynchronization.AUTO, aggregation=tf.VariableAggregation.NONE): cur_name_scope = tf.compat.v1.get_default_graph().get_name_scope() with tf.compat.v1.variable_scope(cur_name_scope, reuse=tf.compat.v1.AUTO_REUSE): var = tf.compat.v1.get_variable(name=name, shape=shape, dtype=dtype, initializer=initializer, regularizer=regularizer, 
trainable=trainable, collections=collections, caching_device=caching_device, partitioner=partitioner, validate_shape=validate_shape, use_resource=use_resource, custom_getter=custom_getter, constraint=constraint, synchronization=synchronization, aggregation=aggregation) if isinstance(var, tf.Variable): var.optimizer = self.optimizer elif isinstance(var, variable_ops.PartitionedVariable): for var_p in var: var_p.optimizer = self.optimizer if base_layer_utils.is_split_variable(var) or isinstance( var, variable_ops.PartitionedVariable): for v in var: K.track_variable(v) if trainable: self._trainable_weights.append(v) else: self._non_trainable_weights.append(v) else: K.track_variable(var) if trainable: self._trainable_weights.append(var) else: self._non_trainable_weights.append(var) return var def build(self, input_shape): dtype = tf.dtypes.as_dtype(self.dtype or K.floatx()) if not (dtype.is_floating or dtype.is_complex): raise TypeError('Unable to build `Dense` layer with non-floating point ' 'dtype %s' % (dtype,)) input_shape = tensor_shape.TensorShape(input_shape) if tensor_shape.dimension_value(input_shape[-1]) is None: raise ValueError('The last dimension of the inputs to `Dense` ' 'should be defined. 
Found `None`.') last_dim = tensor_shape.dimension_value(input_shape[-1]) self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim}) kernel_shape = [last_dim, self.units] init_kernel = self.kernel_initializer(shape=kernel_shape, dtype=self.dtype) self.kernel_var = self.get_variable(initializer=init_kernel, trainable=True, name='kernel', shape=None, dtype=dtype, regularizer=self.kernel_regularizer, partitioner=self.partitioner) self.kernel = self.kernel_var # Add the option for allow_kernel_norm if self.allow_kernel_norm: self.kernel = tf.nn.l2_normalize(self.kernel, axis=0, epsilon=1e-6, name='normalized_kernel') if self.kernel_norm_trainable: init_trainable_kernel_norm = tf.linalg.norm(init_kernel, axis=0) self.trainable_kernel_norm = self.get_variable( initializer=init_trainable_kernel_norm, shape=None, trainable=True, name='trainable_kernel_norm', dtype=dtype, partitioner=self.partitioner) self.kernel = tf.multiply(self.kernel, self.trainable_kernel_norm, name='mul_of_kernel_and_trainable_norm') if self.use_bias: self.bias = self.add_weight(name='bias', shape=[self.units], initializer=self.bias_initializer, regularizer=self.bias_regularizer, dtype=dtype, trainable=True) else: self.bias = None if self.inactive_relu_monitor and self.activation.__name__ == 'relu': self.inactive_relu_count_moving_avg = self.get_variable( initializer=tf.keras.initializers.zeros, trainable=False, name='inactive_relu_count_moving_avg', shape=[self.units], dtype=tf.float32, collections=[ tf.compat.v1.GraphKeys.METRIC_VARIABLES, tf.compat.v1.GraphKeys.GLOBAL_VARIABLES ]) super(Dense, self).build(input_shape) def call(self, inputs, **kwargs): output = core_ops.dense(inputs, self.kernel, self.bias, self.activation, dtype=self._compute_dtype_object) if self.inactive_relu_monitor: inactive_relu_count = self.units - tf.math.count_nonzero(output, axis=0) tf.compat.v1.summary.histogram('inactive_relu_count_moving_avg', self.inactive_relu_count_moving_avg) update_op = tf.compat.v1.assign( 
self.inactive_relu_count_moving_avg, (1. - self.inactive_relu_monitor_decay) * self.inactive_relu_count_moving_avg + self.inactive_relu_monitor_decay * tf.cast(inactive_relu_count, dtype=tf.float32), ) with tf.control_dependencies([update_op]): output = tf.identity(output) return output def compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) input_shape = input_shape.with_rank_at_least(2) if tensor_shape.dimension_value(input_shape[-1]) is None: raise ValueError( 'The innermost dimension of input_shape must be defined, but saw: %s' % input_shape) return input_shape[:-1].concatenate(self.units) def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'allow_kernel_norm': self.allow_kernel_norm, 'kernel_norm_trainable': self.kernel_norm_trainable, 'partitioner': self.partitioner, } base_config = super(Dense, self).get_config() return dict(list(base_config.items()) + list(config.items())) ================================================ FILE: monolith/native_training/layers/dense_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import numpy as np import os import tensorflow as tf from monolith.native_training.layers.dense import Dense class DenseTest(tf.test.TestCase): def test_dense_instantiate(self): dense_layer_template = Dense.params() test_params0 = dense_layer_template.copy() test_params0.name = 'test_dense0' test_params0.units = 100 test_params0.activation = tf.keras.activations.sigmoid test_params0.kernel_initializer = tf.keras.initializers.GlorotNormal() ins1 = test_params0.instantiate() print(ins1) ins2 = Dense(units=100, activation=tf.keras.activations.sigmoid, kernel_initializer=tf.keras.initializers.GlorotNormal()) print(ins2) def test_dense_serde(self): dense_layer_template = Dense.params() test_params0 = dense_layer_template.copy() test_params0.name = 'test_dense0' test_params0.units = 100 test_params0.activation = tf.keras.activations.sigmoid test_params0.kernel_initializer = tf.keras.initializers.GlorotNormal() ins1 = test_params0.instantiate() print(ins1) cfg = ins1.get_config() ins2 = Dense.from_config(cfg) print(ins1, ins2) def test_dense_call(self): layer = Dense(units=100, activation=tf.keras.activations.sigmoid, kernel_initializer=tf.keras.initializers.GlorotNormal()) data = tf.keras.backend.variable(np.ones((100, 100))) sum_out = tf.reduce_sum(layer(data)) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_dense_kernel_norm_call(self): layer = Dense(units=100, allow_kernel_norm=True, kernel_norm_trainable=True, activation=tf.keras.activations.sigmoid, kernel_initializer=tf.keras.initializers.GlorotNormal()) data = tf.keras.backend.variable(np.ones((100, 100))) sum_out = tf.reduce_sum(layer(data)) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_inactive_relu_monitor(self): dense_layer_template = Dense.params() test_params = 
dense_layer_template.copy() test_params.units = 10 test_params.activation = tf.keras.activations.relu test_params.inactive_relu_monitor = True layer = test_params.instantiate() with tf.Graph().as_default(): x = tf.constant([[1., 1., 1., 1., 1.]]) _ = layer(x) graph = tf.compat.v1.get_default_graph() self.assertIn('Dense/inactive_relu_count_moving_avg_1', [node.name for node in graph.as_graph_def().node]) def test_dense_with_explicit_partition(self): layer = Dense(units=1024, allow_kernel_norm=True, kernel_norm_trainable=True, activation=tf.keras.activations.sigmoid, kernel_initializer=tf.keras.initializers.GlorotNormal(), partitioner=tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=5)) data = tf.keras.backend.variable(np.ones((100, 294))) sum_out = layer(data) partition_dims = [] expected_dims = [59, 59, 59, 59, 58] for var in layer.kernel_var: partition_dims.append(var.shape[0]) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) sum_out = sess.run(sum_out) self.assertEqual(sum_out.shape, (100, 1024)) def test_dense_with_implicit_partition(self): with tf.compat.v1.variable_scope( "", partitioner=tf.compat.v1.variable_axis_size_partitioner( max_shard_bytes=1 << 17, max_shards=5)): # The dense kernel's shape is [294, 1024] and will be # partitioned into five shards(unevenly) layer = Dense(units=1024, allow_kernel_norm=True, kernel_norm_trainable=True, activation=tf.keras.activations.sigmoid, kernel_initializer=tf.keras.initializers.GlorotNormal(), partitioner=None) data = tf.keras.backend.variable(np.ones((100, 294))) sum_out = layer(data) partition_dims = [] expected_dims = [59, 59, 59, 59, 58] for var in layer.kernel_var: partition_dims.append(var.shape[0]) self.assertEqual(partition_dims, expected_dims) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) sum_out = sess.run(sum_out) self.assertEqual(sum_out.shape, (100, 1024)) if __name__ == '__main__': 
tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/layers/feature_cross.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List from absl import logging import tensorflow as tf from tensorflow.keras.layers import Layer, Conv1D from tensorflow.python.keras import activations import tensorflow.keras.initializers as initializers from tensorflow.python.keras import regularizers from monolith.native_training.layers.mlp import MLP from monolith.native_training.utils import with_params, get_uname from monolith.native_training.layers.utils import merge_tensor_list, DCNType from monolith.native_training.monolith_export import monolith_export from monolith.native_training.layers.layer_ops import ffm from tensorflow.python.ops import variables as variable_ops from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras import backend as K from monolith.native_training.layers.utils import check_dim, dim_size @monolith_export @with_params class GroupInt(Layer): """Group Interaction的缩写, 一种简单的特征交叉方式, 同时支持attention. 论文可参考 https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf 特征交叉可以在多个层面做, 一种方法是在特征工程中做, 即在特征工程阶段直接生成一个新特征, 这个特征是由多个原始征特拼接起来的, 然后再做Embedding. 这样做的好处是记忆性较好, 但由于稀疏性, 有时训练不够充分, 也存在过拟合的风险. 
另一种是在模型层面做, 代表算法为FM, DeepFM等 在模型中做二阶特征交叉存在如下问题: - 输出维度高: FM用点积表示特征交叉, 如果输入有n个特征, 输出有 n(n-1)/2 维, 当特征较多时, 给训练/推理带来很大的负担 - 重复交叉: 特征交叉可以在两个地方做, 现实中往往同时做. FM等算法并不区分参与交叉的是原始特征还是交叉特征. 所以存在重复交叉. 不过, 也有人认为 重复交叉会生成更高阶的特征, 不是重复 为了克服FM等算法的不足, 可以使用GroupInt. 它先将特征分组(Group), 哪些特征属于一个组由算法开发人员确定. 然后用sumpooling来将特征聚合 得到group embedding. 最后用group embedding做两两交叉输出 GroupInt输出有如下几种形式: - 交叉用dot, 直接输出. 此时输出的大小远小于原始FM, 而且, 人工确定group, 减少了重复交叉 - 交叉用multiply, 输出有两种选择: - 直接concat输出 - 用attention, 将所以结果线性组合后输出(与AFM一样, 论文可参考 https://www.ijcai.org/proceedings/2017/0435.pdf) Args: interaction_type (:obj:`str`): Interaction的方式有两种, dot和multiply use_attention (:obj:`bool`): 是否使用attention, 当interaction_type为'multiply'时才可用 attention_units (:obj:`List[int]`): 使用一个MLP生成attention, attention_units表示MLP每一层的dim, 最后一维必须是1 activation (:obj:`tf.activation`): MLP的激活函数 initializer (:obj:`tf.initializer`): MLP的初始化器 regularizer (:obj:`tf.regularizer`): MLP的正则化器 out_type (:obj:`str`): 输出类型, 可以为stack, concat, None keep_list (:obj:`bool`): 输出是否保持list """ def __init__(self, interaction_type='multiply', use_attention: bool = False, attention_units: List[int] = None, activation='relu', initializer=None, regularizer=None, out_type='concat', keep_list: bool = False, **kwargs): super(GroupInt, self).__init__(**kwargs) assert interaction_type in ['multiply', 'dot'] self.interaction_type = interaction_type self.use_attention = use_attention if use_attention: assert interaction_type == 'multiply' self.attention_units = attention_units self.activation = activations.get(activation) self.initializer = initializers.get( initializer) or initializers.GlorotNormal() self.regularizer = regularizers.get(regularizer) self.out_type = out_type self.keep_list = keep_list def build(self, input_shape): if self.use_attention: assert self.attention_units[-1] == 1 self.mlp = MLP(name='groupint_attention_mlp', output_dims=self.attention_units, activations=self.activation, initializers=self.initializer, kernel_regularizer=self.regularizer) else: 
self.mlp = None return super().build(input_shape) def call(self, inputs, **kwargs): left_fields, right_fields = inputs left, right = tf.concat(left_fields, axis=1), tf.concat(right_fields, axis=1) last_dim_size = dim_size(left_fields[0], -1) ffm_embeddings = ffm(left=left, right=right, dim_size=last_dim_size, int_type=self.interaction_type) if self.interaction_type == 'multiply': if self.use_attention: num_feature = len(left_fields) * len( right_fields ) #int(dim_size(left, 1) * dim_size(right, 1) / last_dim_size) stacked = tf.reshape(ffm_embeddings, shape=(-1, num_feature, last_dim_size)) attention = self.mlp(stacked) # (bs, num_feature, 1) ffm_embeddings = tf.reshape(stacked * attention, shape=(-1, num_feature * last_dim_size)) return [ffm_embeddings] if self.keep_list else ffm_embeddings def get_config(self): config = { 'interaction_type': self.interaction_type, 'use_attention': self.use_attention, 'attention_units': self.attention_units, 'activation': activations.serialize(self.activation), 'initializer': initializers.serialize(self.initializer), 'regularizer': regularizers.serialize(self.regularizer), 'out_type': self.out_type, 'keep_list': self.keep_list } base_config = super(GroupInt, self).get_config() return dict(list(base_config.items()) + list(config.items())) FFM = GroupInt @monolith_export @with_params class AllInt(Layer): r"""AllInt是All Interaction的缩写, 是一种简单的特征交叉方式, 通过引入压缩矩阵, 减少输出大小. 论文可参考 https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf GroupInt虽然能克服FM带来的输出膨胀的问题, 但也有其它问题, 如Group要人工决定, 给算法开发人员带来较大的负担. AllInt将所有特征都做交叉, 不用人工选择, 同时引入压缩矩阵来减少输出大小 All Interaction中引入压缩矩阵. 如下: .. 
math:: O_{n, c} = X_{n, k} * X_{n, k}^T * C_{n, c} 为了避免生成(n, n)的大中间矩阵, 在计算上进行了一些优化, 即先算 :math:`X_{n, k}^T * C_{n, c}`, 这样得到的(k, c)矩阵小很多, 计算效率高 Args: cmp_dim (:obj:`int`): 压缩维的维度 initializer (:obj:`tf.initializer`): 初始化器 regularizer (:obj:`tf.regularizer`): kernel正则化器 use_bias (:obj:`bool`) 是否启用bias out_type (:obj:`str`): 输出类型, 可以为stack, concat, None keep_list (:obj:`bool`): 输出是否保持list """ def __init__(self, cmp_dim, initializer=None, regularizer=None, use_bias=True, out_type='concat', keep_list=False, **kwargs): super(AllInt, self).__init__(**kwargs) self.cmp_dim = cmp_dim self.initializer = initializers.get( initializer) or initializers.GlorotNormal() self.regularizer = regularizers.get(regularizer) self.use_bias = use_bias self.out_type = out_type self.keep_list = keep_list def build(self, input_shape): num_feat = check_dim(input_shape[1]) self.kernel = self.add_weight(name='allint_kernel', shape=(num_feat, self.cmp_dim), dtype=tf.float32, initializer=self.initializer, regularizer=self.regularizer, trainable=True) if self.use_bias: self.bias = self.add_weight(name='allint_bias', shape=(self.cmp_dim,), dtype=tf.float32, initializer=initializers.Zeros(), trainable=True) return super(AllInt, self).build(input_shape) def call(self, embeddings, **kwargs): # embeddings: [batch_size, num_feat, emb_size] transposed = tf.transpose(embeddings, perm=[0, 2, 1]) # [batch_size, emb_size, num_feat] feature_comp = tf.matmul(transposed, self.kernel) # [batch_size, emb_size, cmp_dim] if self.use_bias: feature_comp += self.bias # [batch_size, num_feat, emb_size] * [batch_size, emb_size, cmp_dim] -> [batch_size, num_feat, cmp_dim] interaction = tf.matmul(embeddings, feature_comp) # [batch_size, num_feat, cmp_dim] return merge_tensor_list(interaction, merge_type=self.out_type, keep_list=self.keep_list) def get_config(self): config = { 'cmp_dim': self.cmp_dim, 'initializer': initializers.serialize(self.initializer), 'regularizer': regularizers.serialize(self.regularizer), 'use_bias': 
self.use_bias, 'out_type': self.out_type, 'keep_list': self.keep_list } base_config = super(AllInt, self).get_config() return dict(list(base_config.items()) + list(config.items())) @monolith_export @with_params class CDot(Layer): """Compression and Dot Interaction, CDot. 可以看成是Allint的升级版, 也是一种自动做特征交叉的方法. 论文可参考 https://arxiv.org/pdf/1803.05170.pdf Allint通过引入压缩矩阵, 减少相对FM的输出大小, 同时移除了GroupInt中人工定义Group的不足, CDot与Allint十分相似 CDot相对Allint的改进在于: - AllInt引入的压缩矩阵与输入无关, 在CDot中, 压缩矩阵是与输入数据相关, 可以根据输入, 自适应地调节压缩矩阵. - CDot输出时, 会将压缩后的中间特征也输出, 作为上层MLP的输入, Allint不会做这一步 一般提取高阶特征交叉时使用MLP, MLP的输入是直接接拼起来的Embedding. 一些实验表明, 可以先用CDot提取二阶特征, 再在二阶特征基础上提取高阶 特征效果更好. 所以CDot也可以与MLP联用, 用于高阶特征提取 Args: project_dim (:obj:`int`): 投影dim compress_units (:obj:`List[int]`): 用一个MLP来压缩, 压缩MLP的各层dims activation (:obj:`tf.activation`): MLP的激活函数 initializer (:obj:`tf.initializer`): 初始化器 regularizer (:obj:`tf.regularizer`): kernel正则化器 """ def __init__(self, project_dim, compress_units, activation='relu', initializer=None, regularizer=None, **kwargs): super(CDot, self).__init__(**kwargs) self.activation = activations.get(activation) self.initializer = initializers.get( initializer) or initializers.GlorotNormal() self.regularizer = regularizers.get(regularizer) self.project_dim = project_dim self.compress_units = compress_units def build(self, input_shape): (_, num_feature, emd_size) = input_shape self._num_feature = check_dim(num_feature) self._emd_size = check_dim(emd_size) self.project_weight = self.add_weight(name="project_weight", shape=(num_feature, self.project_dim), dtype=tf.float32, initializer=self.initializer, regularizer=self.regularizer) self.compress_tower = MLP(output_dims=self.compress_units + [emd_size * self.project_dim], activations=self.activation, initializers=self.initializer, kernel_regularizer=self.regularizer, name="compress_tower") self._trainable_weights.extend(self.compress_tower.trainable_weights) self._non_trainable_weights.extend( self.compress_tower.non_trainable_weights) return 
super(CDot, self).build(input_shape) def call(self, inputs, **kwargs): # 1) project the origin feature into raw compressed space transed_input = tf.transpose(inputs, perm=[0, 2, 1 ]) # (batch_size, emd_size, num_feature) # (batch_size, emd_size, num_feature) * (num_feature, project_dim) -> (batch_size, emd_size, project_dim) projected = tf.matmul(transed_input, self.project_weight) # 2) concat the raw compressed features, and go through mlp to cast to compressed space concated = tf.reshape( projected, shape=(-1, self._emd_size * self.project_dim)) # (batch_size, emd_size * project_dim) compressed = self.compress_tower( concated) # (batch_size, emd_size * project_dim) # 3) feature cross # (batch_size, num_feature, emd_size) * (batch_size, emd_size, project_dim) -> (batch_size, num_feature, project_dim) crossed = tf.matmul( inputs, tf.reshape(compressed, shape=(-1, self._emd_size, self.project_dim))) crossed = tf.reshape( crossed, shape=(-1, self._num_feature * self.project_dim)) # (batch_size, num_feature * project_dim) # 4) concat the compressed features and crossed features return tf.concat([crossed, compressed], axis=1) def get_config(self): config = { 'project_dim': self.project_dim, 'compress_units': self.compress_units, 'activation': activations.serialize(self.activation), 'initializer': initializers.serialize(self.initializer), 'regularizer': regularizers.serialize(self.regularizer), } base_config = super(CDot, self).get_config() return dict(list(base_config.items()) + list(config.items())) @monolith_export @with_params class CAN(Layer): """Co-action Network, CAN, 协同作用网络 论文可参考 https://arxiv.org/pdf/2011.05625.pdf 在模型中做特征交叉, 同一份Embedding, 同时要拟合原始特征/交叉特征, 容易两个都拟合不好. 
@monolith_export
@with_params
class CAN(Layer):
  """Co-Action Network (CAN) unit.

  Paper: https://arxiv.org/pdf/2011.05625.pdf

  When features are crossed inside a model, one embedding has to fit both the
  raw feature and the crossed feature, and tends to fit neither well. CAN
  expands the parameters so that crossed and raw features are learned
  relatively independently.

  A CAN unit splits the modelled feature pair into a weight side (item) and
  an input side (user):

  - the weight side is reshaped into the parameters of an MLP
  - the input side is fed through that MLP to model the co-action

  The layer is called with ``(user_emb, item_emb)``. The last dimension of
  ``item_emb`` must equal ``layer_num * u_emb_size * (u_emb_size + 1)``
  (one ``u_emb_size x u_emb_size`` weight plus one bias per layer).

  Args:
    layer_num (:obj:`int`): number of MLP layers
    activation (:obj:`tf.activation`): activation function
    is_seq (:obj:`bool`): whether the user side is a sequence feature
    is_stacked (:obj:`bool`): whether the user side stacks several features

  """

  def __init__(self,
               layer_num: int = 2,
               activation='relu',
               is_seq: bool = False,
               is_stacked: bool = True,
               **kwargs):
    super(CAN, self).__init__(**kwargs)
    self.layer_num = layer_num
    self.activation = activations.get(activation)
    self.is_seq = is_seq
    self.is_stacked = is_stacked

  def build(self, input_shape):
    """Validates the (user, item) shapes and precomputes the split sizes."""
    user_emb_sh, item_emb_sh = input_shape
    self._batch_size = check_dim(user_emb_sh[0])
    assert user_emb_sh[0] == item_emb_sh[0]

    u_emb_size = check_dim(user_emb_sh[-1])
    iemb_size = check_dim(item_emb_sh[-1])
    assert iemb_size == (u_emb_size * (u_emb_size + 1)) * self.layer_num

    # Per layer: a flattened (u_emb_size x u_emb_size) weight, then a bias.
    self._splits = [u_emb_size * u_emb_size, u_emb_size] * self.layer_num
    return super(CAN, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Runs the item-parameterised MLP over the user embedding.

    Returns:
      - is_seq and is_stacked:         reduce-sum over axis 2
      - not is_seq and is_stacked:     the MLP output unchanged
      - is_seq and not is_stacked:     reduce-sum over axis 1
      - not is_seq and not is_stacked: (bs, u_emb_size)
    """
    user_emb, item_emb = inputs

    # Resolve an unknown batch size into a local variable. The original code
    # stored it back on `self`, so a value derived from one call leaked into
    # later calls (and graphs).
    # Assumes check_dim maps an unknown dimension to -1 -- TODO confirm.
    batch_size = self._batch_size
    if batch_size == -1:
      batch_size = dim_size(user_emb, 0)

    dims = self._splits[1]
    if self.is_seq and self.is_stacked:
      # user_emb shape: (bs, num_feat, seq_len, u_emb_size)
      weight_shape = (batch_size, 1, dims, dims)
      bias_shape = (batch_size, 1, 1, dims)
    else:
      if not self.is_seq and not self.is_stacked:
        # user_emb shape: (bs, u_emb_size) -> (bs, 1, u_emb_size)
        user_emb = tf.expand_dims(user_emb, axis=1)
      # user_emb shape: (bs, num_feat | seq_len | 1, u_emb_size)
      weight_shape = (batch_size, dims, dims)
      bias_shape = (batch_size, 1, dims)

    params = tf.split(item_emb, num_or_size_splits=self._splits, axis=1)
    for i in range(self.layer_num):
      weight = tf.reshape(params[2 * i], shape=weight_shape)
      bias = tf.reshape(params[2 * i + 1], shape=bias_shape)
      if self.activation is not None:
        user_emb = self.activation(tf.matmul(user_emb, weight) + bias)
      else:
        user_emb = tf.matmul(user_emb, weight) + bias

    if self.is_seq and self.is_stacked:
      # (bs, num_feat, seq_len, u_emb_size) -> (bs, num_feat, u_emb_size)
      return tf.reduce_sum(user_emb, axis=2)
    elif not self.is_seq and self.is_stacked:
      # (bs, num_feat, u_emb_size)
      return user_emb
    elif self.is_seq and not self.is_stacked:
      # (bs, seq_len, u_emb_size) -> (bs, u_emb_size)
      return tf.reduce_sum(user_emb, axis=1)
    else:
      # (bs, 1, u_emb_size) -> (bs, u_emb_size). Only drop the axis added by
      # expand_dims above: a bare tf.squeeze would also remove a batch or
      # embedding dimension that happens to be 1.
      return tf.squeeze(user_emb, axis=1)

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments."""
    config = {
        'layer_num': self.layer_num,
        'activation': activations.serialize(self.activation),
        "is_seq": self.is_seq,
        "is_stacked": self.is_stacked
    }
    base_config = super(CAN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@monolith_export
@with_params
class DCN(Layer):
  r"""Deep & Cross Network (DCN).

  Second-order feature crosses can be extracted explicitly (e.g. with FM);
  higher-order crosses are usually extracted implicitly by an MLP. DCN can
  replace the MLP for higher-order crossing and, thanks to its residual
  connections, can work better than a plain MLP.

  Three variants are supported (paper: https://arxiv.org/pdf/1708.05123.pdf):

  - vector, :math:`x_{l+1} = x_0 * x_l w + b + x_l`, w has shape (dim, 1)
  - matrix, :math:`x_{l+1} = x_0 * (x_l w + b) + x_l`, w has shape (dim, dim)
  - mixed, :math:`x_{l+1} = \sum_i x_0 * (x_l V C U^T + b) * softmax(x_l g) + x_l`

  Args:
    layer_num (:obj:`int`): number of cross layers
    dcn_type (:obj:`str`): DCN variant, one of vector/matrix/mixed
    initializer (:obj:`tf.initializer`): weight initializer
    regularizer (:obj:`tf.regularizer`): regularizer
    num_experts (:obj:`int`): number of experts, only used in mixed mode
    low_rank (:obj:`int`): rank of the low-rank projection, only used in mixed mode
    allow_kernel_norm (:obj:`bool`): create variables through the weight
      normalization path of :meth:`get_variable`
    use_dropout (:obj:`bool`): whether to apply dropout
    keep_prob (:obj:`float`): keep probability of the dropout
    mode (:obj:`str`): run mode, one of train/eval/predict

  """

  def __init__(self,
               layer_num: int = 1,
               dcn_type: str = DCNType.Matrix,
               initializer=None,
               regularizer=None,
               num_experts: int = 1,
               low_rank: int = 0,
               allow_kernel_norm: bool = False,
               use_dropout=False,
               keep_prob=0.95,
               mode: str = tf.estimator.ModeKeys.TRAIN,
               **kwargs):
    super(DCN, self).__init__(**kwargs)
    self.layer_num = layer_num
    self.dcn_type = dcn_type
    self.num_experts = num_experts
    self.low_rank = low_rank
    self.initializer = initializers.get(
        initializer) or initializers.GlorotNormal()
    self.regularizer = regularizers.get(regularizer)
    self.allow_kernel_norm = allow_kernel_norm
    self.use_dropout = use_dropout
    self.keep_prob = keep_prob
    self.mode = mode

  def build(self, input_shape):
    """Creates the per-layer kernels (or U/V/C/G in mixed mode) and biases."""
    dims = check_dim(input_shape[-1])

    def new_kernel(name, shape):
      # All kernels share dtype, initializer and regularizer.
      return self.get_variable(name=name,
                               shape=shape,
                               dtype=tf.float32,
                               initializer=self.initializer,
                               regularizer=self.regularizer,
                               trainable=True)

    def per_expert(prefix, shape):
      # One variable per (layer, expert), named '<prefix>_<layer>_<expert>'.
      return [[
          new_kernel('{}_{}_{}'.format(prefix, i, j), shape)
          for j in range(self.num_experts)
      ]
              for i in range(self.layer_num)]

    if self.dcn_type == DCNType.Vector:
      self.kernel = [
          new_kernel('kernel_{}'.format(i), [dims, 1])
          for i in range(self.layer_num)
      ]
    elif self.dcn_type == DCNType.Matrix:
      self.kernel = [
          new_kernel('kernel_{}'.format(i), [dims, dims])
          for i in range(self.layer_num)
      ]
    else:
      self.U = per_expert('U', [dims, self.low_rank])
      self.V = per_expert('V', [dims, self.low_rank])
      self.C = per_expert('C', [self.low_rank, self.low_rank])
      self.G = per_expert('G', [dims, 1])

    # NOTE(review): the bias also goes through get_variable, so with
    # allow_kernel_norm=True a zero-initialised bias is weight-normalised
    # (zero direction times zero norm) -- confirm that this is intended.
    self.bias = [
        self.get_variable(name='bias_{}'.format(i),
                          shape=[1, dims],
                          dtype=tf.float32,
                          initializer=initializers.Zeros(),
                          regularizer=None,
                          trainable=True) for i in range(self.layer_num)
    ]
    return super(DCN, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Applies `layer_num` cross layers to `inputs` and returns x_L."""
    x0 = inputs
    xl = x0
    for i in range(self.layer_num):
      if self.dcn_type == DCNType.Vector:
        xl = x0 * tf.matmul(xl, self.kernel[i]) + self.bias[i] + xl
      elif self.dcn_type == DCNType.Matrix:
        xl = x0 * (tf.matmul(xl, self.kernel[i]) + self.bias[i]) + xl
      else:
        output_of_experts = []
        gating_score_of_experts = []
        for expert_id in range(self.num_experts):
          # (1) G(x_l): gating score of this expert, (batch_size, 1)
          gating_score_of_experts.append(tf.matmul(xl, self.G[i][expert_id]))

          # (2) E(x_l): project x_l to the low-rank space R^r
          v_x = tf.matmul(xl, self.V[i][expert_id])  # (batch_size, low_rank)
          v_x = tf.tanh(v_x)
          # nonlinear transform inside the low-rank space
          cv_x = tf.matmul(v_x, self.C[i][expert_id])  # (batch_size, low_rank)
          cv_x = tf.tanh(cv_x)
          # project back to R^d
          ucv_x = tf.matmul(cv_x, self.U[i][expert_id],
                            transpose_b=True)  # (batch_size, num_feat)
          out = x0 * (ucv_x + self.bias[i])
          output_of_experts.append(out)

        # (3) mixture of low-rank experts
        output_of_experts = tf.stack(output_of_experts,
                                     -1)  # (batch_size, num_feat, num_experts)
        gating_score_of_experts = tf.stack(gating_score_of_experts,
                                           -2)  # (bs, num_experts, 1)
        # Normalise across the experts axis. The previous axis=-1 ran the
        # softmax over the trailing size-1 axis, which makes every gate 1.0
        # and turns the mixture into an unweighted sum of the experts.
        gating_score_of_experts = tf.nn.softmax(gating_score_of_experts,
                                                axis=-2)
        moe_out = tf.matmul(output_of_experts, gating_score_of_experts)
        xl = tf.squeeze(moe_out, -1) + xl

      # NOTE(review): applied after every cross layer; the placement is
      # reconstructed from a whitespace-flattened source -- confirm.
      if self.use_dropout and self.mode == tf.estimator.ModeKeys.TRAIN:
        xl = tf.nn.dropout(xl, rate=1 - self.keep_prob)

    return xl

  def _register_norm_variable(self, var, trainable):
    """Tracks `var` (or each of its partitions) as a weight of this layer."""
    if base_layer_utils.is_split_variable(var) or isinstance(
        var, variable_ops.PartitionedVariable):
      shards = list(var)
    else:
      shards = [var]
    for shard in shards:
      K.track_variable(shard)
      if trainable:
        self._trainable_weights.append(shard)
      else:
        self._non_trainable_weights.append(shard)

  def get_variable(self, name, shape, dtype, initializer, regularizer,
                   trainable):
    """Creates a layer variable, optionally weight-normalised.

    With `allow_kernel_norm` the returned tensor is
    ``l2_normalize(var, axis=0) * trainable_norm`` (weight normalization,
    ref https://arxiv.org/pdf/1602.07868.pdf), where `trainable_norm` is a
    second variable initialised with the axis-0 norm of the initial value.
    Otherwise this is a plain `add_weight`.

    Args:
      name: variable name (scoped under the current graph name scope when
        kernel norm is enabled).
      shape: variable shape.
      dtype: variable dtype.
      initializer: initializer callable, invoked as initializer(shape, dtype)
        on the kernel-norm path.
      regularizer: regularizer applied to the direction variable.
      trainable: which weight list the created variables are registered in.

    Returns:
      The variable (plain path) or the normalised product tensor.
    """
    if self.allow_kernel_norm:
      upper_ns = tf.compat.v1.get_default_graph().get_name_scope()
      var_init = initializer(shape, dtype)
      with tf.compat.v1.name_scope(f'{upper_ns}/{name}/') as name_scope:
        var_name = name_scope.strip('/')
        # An empty variable scope lets the full name-scope path be used as
        # the variable name.
        with tf.compat.v1.variable_scope('', reuse=tf.compat.v1.AUTO_REUSE):
          var = tf.compat.v1.get_variable(initializer=var_init,
                                          name=var_name,
                                          dtype=dtype,
                                          regularizer=regularizer,
                                          trainable=trainable)

        normalized = tf.nn.l2_normalize(var,
                                        axis=0,
                                        epsilon=1e-6,
                                        name='normalized_var')
        var_norm_init = tf.norm(var_init, axis=0, name='init_trainable_norm')
        self._register_norm_variable(var, trainable)

        with tf.compat.v1.variable_scope('', reuse=tf.compat.v1.AUTO_REUSE):
          trainable_var_norm = tf.compat.v1.get_variable(
              initializer=var_norm_init,
              name=f'{var_name}/trainable_norm',
              dtype=dtype)
        self._register_norm_variable(trainable_var_norm, trainable)

        var = tf.multiply(normalized, trainable_var_norm, name='mul_var_norm')
    else:
      var = self.add_weight(initializer=initializer,
                            shape=shape,
                            name=name,
                            dtype=dtype,
                            regularizer=regularizer,
                            trainable=trainable)
    return var

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments."""
    config = {
        'layer_num': self.layer_num,
        'dcn_type': self.dcn_type,
        'initializer': initializers.serialize(self.initializer),
        'regularizer': regularizers.serialize(self.regularizer),
        'num_experts': self.num_experts,
        'low_rank': self.low_rank,
        'allow_kernel_norm': self.allow_kernel_norm,
        'use_dropout': self.use_dropout,
        'keep_prob': self.keep_prob,
        'mode': self.mode
    }
    base_config = super(DCN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@monolith_export
@with_params
class CIN(Layer):
  r"""Compressed Interaction Network (CIN).

  CIN extracts high-order (above second order) feature crosses. In form it
  combines DCN and FM, and it is the core of xDeepFM. The original docstring
  cites https://arxiv.org/pdf/1703.04247.pdf
  (NOTE(review): that id appears to be DeepFM; xDeepFM is arXiv:1803.05170 --
  confirm).

  DCN computes:
    - :math:`x_{l+1} = f_{\theta}(x_0, x_l) + x_l`, i.e. a residual network in
      which every layer depends on :math:`x_0`

  FM computes:
    - second-order cross terms on top of LR, compressed through embeddings,
      with the dot product as the crossing operation

  CIN computes:
    - like DCN, every layer depends on :math:`x_0`, but without the residual.
      :math:`f_{\theta}(x,y)` is not linear: as in FM it is computed from
      embeddings, but with an element-wise product followed by a linear
      combination (vector-wise) instead of a dot product (bit-wise).
      Because :math:`f_{\theta}(x,y)` crosses explicitly, CIN is an explicit
      high-order crossing method
    - computationally, CIN can be expressed as a CNN for efficiency

  .. math::

    X_{h,*}^k = \sum_{i=1}^{H_{k-1}} \sum_{j=1}^m W_{ij}^{k,k} (x_{i,*}^{k-1} \circ x_{j,*}^0)

  Main characteristics:
    - interactions happen at the vector-wise level, not the bit-wise level
    - high-order crosses are explicit rather than implicit
    - model size does not grow exponentially with the order of the cross

  Args:
    hidden_uints (:obj:`List[int]`): number of units of each CIN hidden layer
    activation (:obj:`tf.activation`): activation function
    initializer (:obj:`tf.initializer`): weight initializer
    regularizer (:obj:`tf.regularizer`): regularizer

  """

  def __init__(self,
               hidden_uints,
               activation=None,
               initializer='glorot_uniform',
               regularizer=None,
               **kwargs):
    super(CIN, self).__init__(**kwargs)
    self.hidden_uints = hidden_uints
    self._layer_num = len(self.hidden_uints)
    self.activation = activations.get(activation)
    self.initializer = initializers.get(initializer)
    self.regularizer = regularizers.get(regularizer)

    # Filled in by build().
    self._batch_size = None
    self._num_feat = None
    self._emb_size = None

  def build(self, input_shape):
    """Creates one kernel-size-1 Conv1D per hidden layer.

    Args:
      input_shape: must have rank 3: (batch_size, num_feat, emb_size).
    """
    assert len(input_shape) == 3
    batch_size, num_feat, emb_size = input_shape
    self._batch_size = check_dim(batch_size)
    self._num_feat = check_dim(num_feat)
    self._emb_size = check_dim(emb_size)

    last_index = self._layer_num - 1
    self._conv1d = []
    for idx, filters in enumerate(self.hidden_uints):
      prev_hidden = num_feat if idx == 0 else self.hidden_uints[idx - 1]
      # Only the last layer is built without the configured activation.
      layer_activation = self.activation if idx != last_index else None
      conv = Conv1D(filters=filters,
                    kernel_size=1,
                    strides=1,
                    activation=layer_activation,
                    kernel_initializer=self.initializer,
                    kernel_regularizer=self.regularizer,
                    input_shape=(emb_size, prev_hidden * num_feat))
      self._conv1d.append(conv)
      self._trainable_weights.extend(conv.trainable_weights)
      self._non_trainable_weights.extend(conv.non_trainable_weights)
    return super(CIN, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Returns the sum-pooled outputs of all CIN layers, concatenated."""
    x0 = tf.transpose(inputs, perm=[0, 2, 1])  # (batch_size, emb_size, num_feat)
    hidden = x0
    layer_outputs = []
    for conv in self._conv1d:
      hidden_dim = dim_size(hidden, -1)
      # Pairwise products of the current hidden columns with the x0 columns,
      # flattened to (batch_size, emb_size, hidden_dim * num_feat).
      pairwise = tf.einsum('bdh,bdm->bdhm', hidden, x0)
      flat_shape = (self._batch_size, self._emb_size,
                    hidden_dim * self._num_feat)
      interactions = tf.reshape(pairwise, shape=flat_shape)
      hidden = conv(interactions)  # (batch_size, emb_size, num_hidden)
      layer_outputs.append(hidden)

    # Sum-pool every layer over the embedding axis, then concatenate.
    pooled = [tf.reduce_sum(out, axis=1) for out in layer_outputs]
    return tf.concat(pooled, axis=1)

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments."""
    merged = dict(super(CIN, self).get_config())
    merged.update({
        'hidden_uints': self.hidden_uints,
        'activation': activations.serialize(self.activation),
        "initializer": initializers.serialize(self.initializer),
        "regularizer": regularizers.serialize(self.regularizer)
    })
    return merged
class FeatureCrossTest(tf.test.TestCase):
  """Smoke tests for the feature-cross layers.

  Each layer is checked three ways: construction (through `params()` and
  directly), a get_config/from_config round trip, and a forward pass whose
  reduced sum is printed. No numerical values are asserted.
  """

  def _print_total(self, output):
    """Initialises all variables and prints the sum of `output`."""
    total = tf.reduce_sum(output)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(total))

  # ---- GroupInt ----

  def test_groupint_instantiate(self):
    params = GroupInt.params().copy()
    params.interaction_type = 'dot'
    params.use_attention = False
    params.attention_units = [128, 256, 1]
    params.activation = 'relu'
    from_params = params.instantiate()
    print(from_params)

    direct = GroupInt(interaction_type='multiply',
                      use_attention=True,
                      attention_units=[128, 256, 1],
                      activation='relu')
    print(direct)

  def test_groupint_serde(self):
    original = GroupInt(interaction_type='multiply',
                        use_attention=True,
                        attention_units=[128, 256, 1],
                        activation='relu')
    restored = GroupInt.from_config(original.get_config())
    print(original, restored)

  def test_groupint_call(self):
    params = GroupInt.params().copy()
    params.name = 'test_dense0'
    params.out_type = 'concat'
    layer = params.instantiate()

    left = [tf.keras.backend.variable(np.ones((100, 10))) for _ in range(5)]
    right = [tf.keras.backend.variable(np.ones((100, 10))) for _ in range(3)]
    self._print_total(layer((left, right)))

  def test_groupint_attention_call(self):
    layer = GroupInt(interaction_type='multiply',
                     use_attention=True,
                     attention_units=[15, 10, 1],
                     activation='relu')

    left = [tf.keras.backend.variable(np.ones((100, 10))) for _ in range(5)]
    right = [tf.keras.backend.variable(np.ones((100, 10))) for _ in range(3)]
    self._print_total(layer((left, right)))

  # ---- AllInt ----

  def test_allint_instantiate(self):
    params = AllInt.params().copy()
    params.cmp_dim = 4
    from_params = params.instantiate()
    print(from_params)

    direct = AllInt(cmp_dim=4)
    print(direct)

  def test_allint_serde(self):
    params = AllInt.params().copy()
    params.cmp_dim = 4
    original = params.instantiate()
    print(original)

    restored = AllInt.from_config(original.get_config())
    print(original, restored)

  def test_allint_call(self):
    params = AllInt.params().copy()
    params.name = 'test_dense0'
    params.cmp_dim = 4
    layer = params.instantiate()

    data = tf.keras.backend.variable(np.ones((100, 10, 10)))
    self._print_total(layer(data))

  # ---- CDot ----

  def test_cdot_instantiate(self):
    params = CDot.params().copy()
    params.project_dim = 8
    params.compress_units = [128, 256]
    params.activation = 'tanh'
    from_params = params.instantiate()
    print(from_params)

    direct = CDot(project_dim=8, compress_units=[128, 256], activation='tanh')
    print(direct)

  def test_cdot_serde(self):
    original = CDot(project_dim=8,
                    compress_units=[128, 256],
                    activation='tanh')
    restored = CDot.from_config(original.get_config())
    print(original, restored)

  def test_cdot_call(self):
    layer = CDot(project_dim=8, compress_units=[128, 256], activation='tanh')
    data = tf.keras.backend.variable(np.ones((100, 10, 10)))
    self._print_total(layer(data))

  # ---- CAN ----

  def test_can_instantiate(self):
    params = CAN.params().copy()
    params.layer_num = 8
    params.activation = 'sigmoid'
    params.is_seq = False
    params.is_stacked = True
    from_params = params.instantiate()
    print(from_params)

    direct = CAN(layer_num=8, activation='tanh', is_seq=False, is_stacked=True)
    print(direct)

  def test_can_serde(self):
    original = CAN(layer_num=8,
                   activation='tanh',
                   is_seq=False,
                   is_stacked=True)
    restored = CAN.from_config(original.get_config())
    print(original, restored)

  def test_can_seq_call(self):
    layer = CAN(layer_num=2, activation='relu', is_seq=True, is_stacked=True)
    # item width 220 == layer_num * u_emb * (u_emb + 1) == 2 * 10 * 11
    user = tf.keras.backend.variable(np.ones((128, 10, 12, 10)))
    item = tf.keras.backend.variable(np.ones((128, 220)))
    self._print_total(layer((user, item)))

  def test_can_call(self):
    layer = CAN(layer_num=2, activation='relu', is_seq=False, is_stacked=True)
    user = tf.keras.backend.variable(np.ones((128, 10, 10)))
    item = tf.keras.backend.variable(np.ones((128, 220)))
    self._print_total(layer((user, item)))

  # ---- DCN ----

  def test_dcn_instantiate(self):
    params = DCN.params().copy()
    params.layer_num = 8
    params.dcn_type = 'matrix'
    params.use_dropout = True
    params.keep_prob = 0.5
    from_params = params.instantiate()
    print(from_params)

    direct = DCN(layer_num=8, dcn_type='matrix', use_dropout=True, keep_prob=0.5)
    print(direct)

  def test_dcn_serde(self):
    original = DCN(layer_num=8,
                   dcn_type='matrix',
                   use_dropout=True,
                   keep_prob=0.5)
    restored = DCN.from_config(original.get_config())
    print(original, restored)

  def test_dcn_vector_call(self):
    layer = DCN(layer_num=2,
                dcn_type='vector',
                allow_kernel_norm=True,
                use_dropout=True,
                keep_prob=0.5)
    data = tf.keras.backend.variable(np.ones((128, 10, 10)))
    self._print_total(layer(data))

  def test_dcn_matrix_call(self):
    layer = DCN(layer_num=2,
                dcn_type='matrix',
                allow_kernel_norm=True,
                use_dropout=True,
                keep_prob=0.5)
    data = tf.keras.backend.variable(np.ones((128, 10, 10)))
    self._print_total(layer(data))

  def test_dcn_mixed_call(self):
    layer = DCN(layer_num=2,
                dcn_type='mixed',
                num_experts=2,
                low_rank=5,
                allow_kernel_norm=True,
                use_dropout=True,
                keep_prob=0.5)
    data = tf.keras.backend.variable(np.ones((128, 10, 10)))
    self._print_total(layer(data))

  # ---- CIN ----

  def test_cin_instantiate(self):
    params = CIN.params().copy()
    params.hidden_uints = [10, 5]
    params.activation = 'sigmoid'
    from_params = params.instantiate()
    print(from_params)

    direct = CIN(hidden_uints=[10, 5], activation='tanh')
    print(direct)

  def test_cin_serde(self):
    original = CIN(hidden_uints=[10, 5], activation='tanh')
    restored = CIN.from_config(original.get_config())
    print(original, restored)

  def test_cin_call(self):
    layer = CIN(hidden_uints=[10, 5], activation='relu')
    data = tf.keras.backend.variable(np.ones((128, 10, 10)))
    self._print_total(layer(data))
@monolith_export
@with_params
class DIN(Layer):
  """Deep Interest Network (DIN): interest-based sequence aggregation.

  Proposed by Alibaba; paper: https://arxiv.org/pdf/1706.06978.pdf

  Static features alone do not describe a user well, so behaviour features
  are added. Behaviour features are usually sequences (clicked apps,
  purchased items, ...). A behaviour results from the user's inner interest
  together with an external condition (the target). For simplicity the
  behaviour itself is used to represent the interest.

  DIN makes three assumptions:
    - Behavior/Interest: the behaviour sequence is an embedding sequence that
      also represents the user's interests
    - Target Representation: the target item is an embedding in the same
      space as the behaviours/interests
    - Interest Match: the user acts on an item because it satisfies `some` of
      the interests, which is expressed with attention

  For a single feature:
    - queries: the candidate item (target), shape (k, )
    - keys: the user sequence feature (interest), length t, shape (t, k)

  queries is tiled t times to shape (t, k) so that it matches keys, then

    din_all = concat([queries, keys, queries - keys, queries * keys])

  is fed to an MLP that yields the attention weight (how well the item
  satisfies each interest):

    attention_weight = mlp(din_all)

  Finally the keys are combined linearly with these weights:

    output = matmul(attention_weight * keys)

  The result has shape (k, ), the same as the original queries.

  Args:
    hidden_units (:obj:`list`): hidden units of the attention MLP; the last
      entry must be 1
    activation (:obj:`tf.activation`): activation function
    initializer (:obj:`tf.initializer`): kernel/bias initializer
    regularizer (:obj:`tf.regularizer`): kernel regularizer
    mode (:obj:`str`): output mode. With `sum` the keys are linearly combined
      and the result has the shape of queries; otherwise the keys are only
      scaled by the weights and the result has the shape of keys
    decay (:obj:`bool`): whether to scale the attention weight by
      1/sqrt(emb_size); defaults to False

  """

  def __init__(self,
               hidden_units,
               activation=None,
               initializer=None,
               regularizer=None,
               mode: str = 'sum',
               decay: bool = False,
               **kwargs):
    super(DIN, self).__init__(**kwargs)
    assert hidden_units[-1] == 1
    # inputs are (queries: rank 2, keys: rank 3)
    self.input_spec = [InputSpec(ndim=2), InputSpec(ndim=3)]
    self.hidden_units = hidden_units
    self.activation = activations.get(activation)
    self.initializer = initializers.get(
        initializer) or initializers.GlorotNormal()
    self.regularizer = regularizers.get(regularizer)
    self.dense_tower = None
    self.mode = mode
    self.decay = decay

  def build(self, input_shape):
    """Creates the attention MLP."""
    self.dense_tower = MLP(name='compress_tower',
                           activations=self.activation,
                           output_dims=self.hidden_units,
                           initializers=self.initializer,
                           kernel_regularizer=self.regularizer)
    self._trainable_weights.extend(self.dense_tower.trainable_weights)
    self._non_trainable_weights.extend(self.dense_tower.non_trainable_weights)
    self.add_loss(self.dense_tower.losses)
    super(DIN, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Aggregates `keys` with attention weights conditioned on `queries`.

    Args:
      inputs: pair (queries [B, H], keys [B, T, H]).
      **kwargs: optional `mask` of shape [B, T]; positions whose mask value
        is below 1 get a zero attention weight.

    Returns:
      [B, H] when `mode == 'sum'`, otherwise [B, T, H].
    """
    queries, keys = inputs
    mask = kwargs.get('mask', None)
    T, H = dim_size(keys, 1), dim_size(keys, 2)

    # Tile the query along the sequence axis so it lines up with the keys:
    # [B, H] -> [B, T * H] -> [B, T, H]
    queries = tf.reshape(tf.tile(queries, [1, T]), [-1, T, H])

    # DIN features: [B, T, 4 * H]
    din_all = tf.concat([queries, keys, queries - keys, queries * keys],
                        axis=-1)
    # The tower acts on the last dim: [B, T, 4 * H] -> [B, T, 1]
    attention_weight = self.dense_tower(din_all)
    if self.decay:
      attention_weight /= (H**0.5)

    if mask is not None:
      mask = tf.greater_equal(mask, tf.ones_like(mask))
      key_masks = tf.expand_dims(mask, 2)  # [B, T, 1]
      attention_weight = tf.where(
          key_masks, attention_weight,
          tf.zeros_like(attention_weight))  # [B, T, 1]

    tf.compat.v1.summary.histogram(
        '{name}_attention_outputs'.format(name=self.name), attention_weight)

    if self.mode == 'sum':
      # Weighted sum: [B, T, 1]^T * [B, T, H] -> [B, 1, H]
      attention_out = tf.matmul(attention_weight, keys, transpose_a=True)
      outputs = tf.squeeze(attention_out, [1])  # [B, 1, H] -> [B, H]
    else:
      # [B, T, H] * [B, T, 1] -> [B, T, H]
      outputs = keys * attention_weight
    return outputs

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments.

    `mode` and `decay` are included so that `from_config(get_config())`
    rebuilds an equivalent layer instead of silently resetting them to their
    defaults.
    """
    config = {
        'hidden_units': self.hidden_units,
        'activation': activations.serialize(self.activation),
        'initializer': initializers.serialize(self.initializer),
        'regularizer': regularizers.serialize(self.regularizer),
        'mode': self.mode,
        'decay': self.decay,
    }
    base_config = super(DIN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@monolith_export
@with_params
class DIEN(Layer):
  """Deep Interest Evolution Network (DIEN), the successor of DIN.

  Proposed by Alibaba; paper: https://arxiv.org/pdf/1809.03672.pdf

  In recommendation the user does not type a query to express intent, so
  capturing the user's interest and how it changes over time is key. Most
  models treat the behaviour directly as the interest, although behaviours
  rarely express the latent interest completely; the real interest behind
  the behaviour and its dynamics have to be mined.

  Assumptions of DIEN:
    - Behavior Layer: the behaviour sequence is an embedding sequence that
      describes the behaviour itself and, unlike DIN, does not directly stand
      for the interest
    - Interest Extractor Layer: a GRU extracts interests from the behaviours;
      interests evolve over time, which DIN ignores
    - Interest Evolving Layer: interests keep changing with the external
      environment (target attention) and inner cognition (interest), and
      eventually trigger a behaviour
      - the item is represented as in DIN, in the same space as the interest
      - the item/interest relation differs from DIN: DIN statically checks
        whether the item satisfies the interest, whereas in DIEN the interest
        evolves and the item induces it. Structurally this is AGRU, i.e.
        attention + GRU

  Args:
    num_units (:obj:`int`): size of the GRU hidden state
    att_type (:obj:`str`): attention type, AGRU or AUGRU
    activation (:obj:`tf.activation`): activation function
    initializer (:obj:`tf.initializer`): kernel/bias initializer
    regularizer (:obj:`tf.regularizer`): kernel regularizer

  """

  def __init__(self,
               num_units,
               att_type='AGRU',
               activation=tf.keras.activations.relu,
               initializer=tf.initializers.HeUniform,
               regularizer=None,
               **kwargs):
    super(DIEN, self).__init__(**kwargs)
    self.num_units, self.att_type = num_units, att_type
    self.activation = tf.keras.activations.get(activation)
    self.initializer = initializers.get(
        initializer) or initializers.GlorotNormal()
    self.regularizer = regularizers.get(regularizer)

  def build(self, input_shape):
    """Creates the extractor GRU cell, the evolving cell and the attention weight."""
    self.gru_cell = tf.keras.layers.GRUCell(
        name='gru_cell',
        units=self.num_units,
        activation=self.activation,
        kernel_initializer=self.initializer,
        bias_initializer=tf.initializers.Zeros(),
        kernel_regularizer=self.regularizer)
    self._trainable_weights.extend(self.gru_cell.trainable_weights)
    self._non_trainable_weights.extend(self.gru_cell.non_trainable_weights)
    self.add_loss(self.gru_cell.losses)

    # Honour the configured attention type. It used to be hard-coded to
    # 'AGRU', so the `att_type` argument had no effect.
    self.augru_cell = AGRUCell(name='augru_cell',
                               units=self.num_units,
                               activation=self.activation,
                               att_type=self.att_type,
                               initializer=self.initializer,
                               regularizer=self.regularizer)
    self._trainable_weights.extend(self.augru_cell.trainable_weights)
    self._non_trainable_weights.extend(self.augru_cell.non_trainable_weights)
    self.add_loss(self.augru_cell.losses)

    self.weight = self.add_weight(name='attention_weight',
                                  dtype=tf.float32,
                                  shape=(self.num_units, self.num_units),
                                  initializer=self.initializer,
                                  regularizer=self.regularizer)
    super(DIEN, self).build(input_shape)

  def _attention(self, queries, keys):
    """Returns softmax-normalised bilinear scores between queries and keys."""
    emb_size = dim_size(keys, 2)
    # queries x W^T, reshaped to a column vector per example
    query_weight = tf.reshape(tf.matmul(queries, self.weight, transpose_b=True),
                              [-1, emb_size, 1])
    logit = tf.squeeze(tf.matmul(keys, query_weight), [2])
    return tf.nn.softmax(logit)

  def call(self, inputs, **kwargs):
    """Runs interest extraction and interest evolution.

    Args:
      inputs: `(queries, keys)`, `(queries, keys, mask)`, a one-element
        sequence `(queries,)`, or a bare `queries` tensor. In the last two
        cases `keys` must be passed as a keyword argument. The mask, when
        given, is currently not used.

    Returns:
      The final state returned by `dynamic_rnn_with_attention`.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 3:
        queries, keys, mask = inputs[:]
      elif len(inputs) == 2:
        queries, keys = inputs[:]
      else:
        queries = inputs[0]
        keys = kwargs['keys']
    else:
      queries = inputs
      keys = kwargs['keys']

    # Interest extractor layer: capture temporal interests.
    outputs, _ = tf.compat.v1.nn.dynamic_rnn(self.gru_cell,
                                             keys,
                                             dtype=tf.float32)  # [B, T, H]

    # Interest evolving layer: capture the interest evolution that is
    # relative to the target item.
    attn_scores = self._attention(queries, outputs)  # [B, T]
    _, final_state = dynamic_rnn_with_attention(self.augru_cell, outputs,
                                                attn_scores)
    return final_state

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments."""
    config = {
        'num_units': self.num_units,
        'att_type': self.att_type,
        'activation': activations.serialize(self.activation),
        'initializer': initializers.serialize(self.initializer),
        'regularizer': regularizers.serialize(self.regularizer),
    }
    base_config = super(DIEN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@monolith_export
@with_params
class DMR_U2I(Layer):
  """User-to-item part of Deep Match to Rank (DMR).

  Unlike an RNN, DMR mainly accounts for the order of the sequence. Like DIN
  it aggregates the sequence feature with attention; the difference is that
  it adds a position embedding to encode the order of the user sequence.

  In the original paper the final output is a dot product, so only a single
  value is back-propagated and training is insufficient; the paper therefore
  adds an auxiliary loss. That loss needs negative sampling, which is
  awkward to implement, so here an element-wise product replaces the dot
  product and the auxiliary loss is dropped.

  Paper: https://ojs.aaai.org/index.php/AAAI/article/view/5346/5202

  Args:
    cmp_dim (:obj:`int`): compression dimension
    activation (:obj:`tf.activation`): activation function
    initializer (:obj:`tf.initializer`): kernel/bias initializer
    regularizer (:obj:`tf.regularizer`): kernel regularizer

  """

  def __init__(self,
               cmp_dim: int,
               activation="PReLU",
               initializer="glorot_uniform",
               regularizer=None,
               **kwargs):
    super(DMR_U2I, self).__init__(**kwargs)
    self.cmp_dim = cmp_dim
    self.activation = activations.get(activation)
    resolved_initializer = initializers.get(initializer)
    self.initializer = resolved_initializer or initializers.GlorotNormal()
    self.regularizer = regularizers.get(regularizer)

  def build(self, input_shape):
    """Creates the attention weights and the output projection.

    Args:
      input_shape: pair (item shape (bs, ie_size),
        user sequence shape (bs, seq_length, ue_size)).
    """
    item_sh, user_seq_sh = input_shape
    seq_batch, seq_length, ue_size = tuple(
        check_dim(d) for d in user_seq_sh)
    item_batch, ie_size = tuple(check_dim(d) for d in item_sh)
    assert seq_batch == item_batch

    # position embedding, one row per sequence position
    self.pos_emb = self.add_weight(name="pos_emb",
                                   shape=(seq_length, self.cmp_dim),
                                   initializer=self.initializer,
                                   regularizer=self.regularizer)
    # projects user embeddings into the compressed space
    self.emb_weight = self.add_weight(name="emb_weight",
                                      shape=(ue_size, self.cmp_dim),
                                      initializer=self.initializer,
                                      regularizer=self.regularizer)
    # reduces the compressed representation to one attention logit
    self.z_weight = self.add_weight(name="z_weight",
                                    shape=(self.cmp_dim, 1),
                                    initializer=initializers.Ones())
    self.bias = self.add_weight(name="bias",
                                shape=(self.cmp_dim,),
                                initializer=initializers.Zeros())

    # maps the merged user vector to the item embedding size
    self.linear = Dense(name="dense",
                        units=ie_size,
                        activation=self.activation,
                        kernel_initializer=self.initializer,
                        kernel_regularizer=self.regularizer,
                        use_bias=True)
    dense = self.linear
    self._trainable_weights.extend(dense.trainable_weights)
    self._non_trainable_weights.extend(dense.non_trainable_weights)

  def call(self, inputs, **kwargs):
    """Returns the attended user vector multiplied element-wise with items."""
    items, user_seq = inputs

    # 1) compressed, position-aware representation: (bs, seq_length, cmp_dim)
    projected = tf.matmul(user_seq, self.emb_weight)
    position_aware = self.pos_emb + projected + self.bias

    # 2) attention weights, normalised over the sequence axis:
    #    (bs, seq_length, cmp_dim) x (cmp_dim, 1) -> (bs, seq_length, 1)
    logits = tf.matmul(position_aware, self.z_weight)
    alpha = tf.nn.softmax(logits, axis=1)

    # 3) weighted sum of the sequence:
    #    (bs, ue_size, seq_length) x (bs, seq_length, 1) -> (bs, ue_size)
    seq_last = tf.transpose(user_seq, perm=(0, 2, 1))
    merged = tf.squeeze(tf.matmul(seq_last, alpha), axis=-1)

    # 4) linear transform to the item space: (bs, ie_size)
    merged = self.linear(merged)
    return merged * items

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments."""
    merged = dict(super(DMR_U2I, self).get_config())
    merged.update({
        'cmp_dim': self.cmp_dim,
        'activation': activations.serialize(self.activation),
        'initializer': initializers.serialize(self.initializer),
        'regularizer': regularizers.serialize(self.regularizer),
    })
    return merged
import numpy as np import tensorflow as tf from monolith.native_training.layers.feature_seq import DIN, DIEN, DMR_U2I class FeatureSeqTest(tf.test.TestCase): def test_din_instantiate(self): layer_template = DIN.params() test_params0 = layer_template.copy() test_params0.hidden_units = [10, 1] test_params0.initializer = tf.initializers.GlorotNormal() ins1 = test_params0.instantiate() print(ins1) ins2 = DIN(hidden_units=[10, 1], initializer=tf.initializers.HeUniform()) print(ins2) def test_din_serde(self): ins1 = DIN(hidden_units=[10, 1], initializer=tf.initializers.HeUniform()) cfg = ins1.get_config() ins2 = DIN.from_config(cfg) print(ins1, ins2) def test_din_call(self): layer = DIN(hidden_units=[10, 1], initializer=tf.initializers.HeUniform()) query = tf.keras.backend.variable(np.ones((100, 10))) keys = tf.keras.backend.variable(np.ones((100, 15, 10))) out = layer((query, keys)) sum_out = tf.reduce_sum(out) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_dien_instantiate(self): layer_template = DIEN.params() test_params0 = layer_template.copy() test_params0.num_units = 10 test_params0.initializer = tf.initializers.GlorotNormal() ins1 = test_params0.instantiate() print(ins1) ins2 = DIEN(num_units=10, initializer=tf.initializers.HeUniform()) print(ins2) def test_dien_serde(self): ins1 = DIEN(num_units=10, initializer=tf.initializers.HeUniform()) cfg = ins1.get_config() ins2 = DIEN.from_config(cfg) print(ins1, ins2) def test_dien_call(self): layer = DIEN(num_units=10, initializer=tf.initializers.HeUniform()) query = tf.keras.backend.variable(np.ones((100, 10))) keys = tf.keras.backend.variable(np.ones((100, 15, 10))) out = layer((query, keys)) sum_out = tf.reduce_sum(out) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_dmr_instantiate(self): layer_template = DMR_U2I.params() test_params0 = layer_template.copy() test_params0.cmp_dim = 
10 test_params0.activation = 'relu' test_params0.initializer = tf.initializers.GlorotNormal() ins1 = test_params0.instantiate() print(ins1) ins2 = DMR_U2I(cmp_dim=10, activation='relu', initializer=tf.initializers.HeUniform()) print(ins2) def test_dmr_serde(self): ins1 = DMR_U2I(cmp_dim=10, activation='relu', initializer=tf.initializers.HeUniform()) cfg = ins1.get_config() ins2 = DMR_U2I.from_config(cfg) print(ins1, ins2) def test_dmr_call(self): layer = DMR_U2I(cmp_dim=5, activation='relu', initializer=tf.initializers.HeUniform()) query = tf.keras.backend.variable(np.ones((100, 10))) keys = tf.keras.backend.variable(np.ones((100, 15, 10))) out = layer((query, keys)) sum_out = tf.reduce_sum(out) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/layers/feature_trans.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np import tensorflow as tf from tensorflow.keras.layers import Layer, InputSpec import tensorflow.keras.initializers as initializers from tensorflow.python.keras import regularizers from monolith.core.base_layer import add_layer_loss from monolith.native_training.layers.mlp import MLP from monolith.native_training.layers.utils import merge_tensor_list from monolith.native_training.utils import with_params from monolith.native_training.monolith_export import monolith_export from monolith.native_training.layers.utils import check_dim, dim_size @monolith_export @with_params class AutoInt(Layer): r"""Auto-Interaction的缩写, 基于Self-attention的特征变换. 论文可参考 https://arxiv.org/pdf/1810.11921.pdf 一个样本有n个特征, 每个特征用一个k维的embedding表示, 则样本可以表示为(n, k)的矩阵. 所谓attention, 本质上是一种线性组合, 关键是确定组合系数 AutoInt中确定组合系数的方式为: .. math:: coeff_{n, n} = softmax( X_{n, k} * X_{n, k}^T ) 即先计算自相关, 确定特征与其它特征的`相似性`, 然后用softmax的方式归一化, 得到组合系数. 最后是组性组合, 计算attention: .. math:: O_{n, k} = coeff_{n, n} * X_{n, k} 在AutoInt中, 上述过程可以迭代进行多次, 一次为一个layer Args: layer_num (:obj:`int`): auto int layer的层数, 一层为一个完整的auto int out_type (:obj:`str`): 输出类型, 可以为stack, concat, None keep_list (:obj:`bool`): 输出是否保持list """ def __init__(self, layer_num=1, out_type='concat', keep_list: bool = False, **kwargs): super(AutoInt, self).__init__(**kwargs) self.layer_num = layer_num self.out_type = out_type self.keep_list = keep_list def call(self, embeds, **kwargs): assert len(embeds.shape) == 3 autoint_input = embeds for i in range(self.layer_num): layer_name = '{name}_{idx}'.format(name=self.name, idx=i) with tf.name_scope(layer_name): # [batch_size, num_feat, emb_dim] -> [batch_size, num_feat, num_feat] attn = tf.nn.softmax(tf.matmul(autoint_input, autoint_input, transpose_b=True), axis=-1) autoint_input = tf.matmul(attn, autoint_input) # [batch, num_feats, emb_dim] return merge_tensor_list(autoint_input, merge_type=self.out_type, keep_list=self.keep_list) def get_config(self): config = { 'layer_num': self.layer_num, 
'out_type': self.out_type, 'keep_list': self.keep_list } base_config = super(AutoInt, self).get_config() return dict(list(base_config.items()) + list(config.items())) @with_params class iRazor(Layer): """特征选择和Embedding维度搜索 一个样本有n个特征, 每个特征用一个k维的embedding表示. 可以为每一个Embedding分配一个先择概率(一个0~1之间的数), 训练出完成后, 如果概率较大, 则保留, 否则去除, 从而实现Embedding维度搜索. 也可以给特征分配一个移除概率, 这个概率与Embedding分配的概率可以归一化, 如果概率越大, 则将特征移除. 当训练完成 后, 可以用后剪枝算法CPT(cumulative probability threshold, 累积概率阈值)来对网络裁剪, 从而达到特征选择和Embedding维度搜索的目的 .. note:: 从另一个角度看, 不同的特征, 重要程度不一样, 同一个特征Embedding, 不同的维度重要程度也不一样. 常用内积, 在Euclid空间中计算内积就是点乘, 因此每个维度的重要性一样. 所以可以引入一个`度量空间`, 在这个空间算内积. 为了简单, 度量矩阵用对角阵(半正定), 此时, 直观理解就是每个embedding维度权重不 一样, 而且权重匀为正数, 是可以学习的. 前面的分析是假设`度量空间`存在, 可以为`度量空间`情况也分配权重, 而且这个权重与Embedding权重是归一化的, 从而实现不同特征重要性不一样. 此时, iRazor的目的是做特征变换 给定一个 nas_space, 假设emb_size=8, 则nas_space=[0, 1, 3, 5, 8], 是对embedding的一个划分:{}, {0}, {1, 2}, {3, 4}, {5, 6, 7}, 共5段, 每段出现的概率为p_i .. code-block:: text rigid_masks = [ [0, 0, 0, 0, 0, 0, 0, 0], -> p_0, 表示`度量空间`不存在的概率 [1, 0, 0, 0, 0, 0, 0, 0], -> p_1, 表示0号位置的概率/重要性 [0, 1, 1, 0, 0, 0, 0, 0], -> p_2, 表示1-2号位置的概率/重要性 [0, 0, 0, 1, 1, 0, 0, 0], -> p_3, 表示3-4号位置的概率/重要性 [0, 0, 0, 0, 0, 1, 1, 1] -> p_4, 表示5-8号位置的概率/重要性 ] P = (p_0, p_1, p_2, p_3, p_4), 且有 p_0 + ... + p_4 = 1 soft_masks = P * rigid_masks = (p_1, p_2, p_2, p_3, p_3, p_4, p_4, p_4) 从上面可以看出, nas_space是对测度空间的限制, 强制某几个维度(分组)有相同的权重. 如果 nas_space = [0,1,2,3,4,5,6,7,8], 可以去除这种强制. rigid_masks中第一行全为0, 表示表示`度量空间`不存在. 可以加一个辅助loss, 强制`度量空间`不存在, 因为可以减少参数/省内存/评估特征重要性 .. code-block:: text loss = feature_weight * sum(soft_masks) Args: nas_space (:obj:`list`): 用于定义embedding特征分组, 第一个元素是0, 最后一个元素是emb_size, 元素是有序的. 
0表示`度量空间`不存在, nas_space[i-1]:nas_space[i] 表示一个分组, 位于同一组内的元素有相同的权重 t (:obj:`float`): softmax平滑因子 initializer (:obj:`tf.initializer`): kernel/bias初始化器 regularizer (:obj:`tf.regularizer`): kernel正则化 feature_weight (:obj:`tf.Tensor`): 特征权重, 用于计算辅助loss out_type (:obj:`str`): 输出类型, 可以为stack, concat, None keep_list (:obj:`bool`): 输出是否保持list """ def __init__(self, nas_space, t=0.05, initializer=None, regularizer=None, feature_weight=None, out_type='concat', keep_list=False, **kwargs): super(iRazor, self).__init__(**kwargs) self.out_type = out_type self.keep_list = keep_list self.nas_space = nas_space self.t = t self.nas_logits = None self.emb_size = max(self.nas_space) self.nas_len = len(self.nas_space) self.initializer = initializers.get(initializer) self.regularizer = regularizers.get(regularizer) if feature_weight is not None: if isinstance(feature_weight, (tf.Tensor, tf.Variable)): self.feature_weight = tf.reshape(feature_weight, shape=(1, -1)) else: self.feature_weight = tf.constant(feature_weight, shape=(1, -1), dtype=tf.float32) else: self.feature_weight = feature_weight @property def rigid_masks(self): masks = np.zeros(shape=(self.nas_len, self.emb_size), dtype=np.float32) for i, j in enumerate(self.nas_space): if i > 0: masks[i, self.nas_space[i - 1]:j] = 1.0 return tf.constant(masks, name="masks", dtype=tf.float32) def build(self, input_shape): # input_shape: [bath_size, num_feat, emb_dim] shape = (check_dim(input_shape[1]), self.nas_len) self.nas_logits = self.add_weight(name="nas_weight", shape=shape, dtype=tf.float32, initializer=self.initializer, regularizer=self.regularizer) super(iRazor, self).build(input_shape) def call(self, embeds, **kwargs): assert check_dim(embeds.shape[-1]) == max(self.nas_space) nas_weight = tf.nn.softmax(self.nas_logits / self.t, axis=1, name="nas_concat_choice_probs") tf.compat.v1.summary.histogram(name='nas_weight', values=nas_weight) # create soft mask for each embedding dim with nas soft_masks = tf.matmul(nas_weight, 
self.rigid_masks, name="choice_matrix") if self.feature_weight is not None: nas_loss = tf.matmul(self.feature_weight, tf.reduce_sum(soft_masks, axis=1, keepdims=True)) add_layer_loss(self.name, tf.reduce_sum(nas_loss)) # re-weight embeds out_embeds = embeds * soft_masks return merge_tensor_list(out_embeds, merge_type=self.out_type, keep_list=self.keep_list) def get_config(self): config = { 'nas_space': self.nas_space, 't': self.t, 'initializer': initializers.serialize(self.initializer), 'regularizer': regularizers.serialize(self.regularizer), 'feature_weight': self.feature_weight, 'out_type': self.out_type, 'keep_list': self.keep_list, } base_config = super(iRazor, self).get_config() return dict(list(base_config.items()) + list(config.items())) @monolith_export @with_params class SeNet(Layer): """SeNet最早用于图像中, 这里是借用其概念, 不同特征具有不同重要性. 论文可参考 https://arxiv.org/pdf/1709.01507.pdf 一个样本有n个特征, 每个特征用一个k维的embedding表示. 但是并不是每个特征都一样重要, 所以想给每个特征一个权重, 以调整其重要性. 权重计算是用一个MLP完成的, 一般有三层input - cmp_layer - output. 其中input/output是同shape的, input是通过 reduce_mean输入矩阵(n, k)的最后一维得到. 
最后用 weight(n) * (n, k) 为特征加权 Args: num_feature (:obj:`int`): 输入特征数 cmp_dim (:obj:`int`): 压缩维的维度 initializer (:obj:`tf.initializer`): kernel/bias初始化器 kernel_regularizer (:obj:`tf.regularizer`): kernel正则化 bias_regularizer (:obj:`tf.regularizer`): bias正则化 on_gpu: 计算是否发生在GPU上, 如果是, 则用GPU优化版本 out_type (:obj:`str`): 输出类型, 可以为stack, concat, None keep_list (:obj:`bool`): 输出是否保持list """ def __init__(self, num_feature, cmp_dim, initializer=None, regularizer=None, on_gpu=False, out_type='concat', keep_list=False, **kwargs): super(SeNet, self).__init__(**kwargs) self.num_feat = num_feature self.cmp_dim = cmp_dim self.initializer = initializers.get(initializer) self.regularizer = regularizers.get(regularizer) self.on_gpu = on_gpu self.out_type = out_type self.keep_list = keep_list def build(self, input_shape): if self.cmp_dim is None: self.cmp_tower = lambda x: x else: self.cmp_tower = MLP(name='cmp_tower', output_dims=[self.cmp_dim, self.num_feat], activations=['relu', 'sigmoid'], initializers=self.initializer, kernel_regularizer=self.regularizer) self._trainable_weights.extend(self.cmp_tower.trainable_weights) self._non_trainable_weights.extend(self.cmp_tower.non_trainable_weights) self.add_loss(self.cmp_tower.losses) super(SeNet, self).build(input_shape) def call(self, inputs, **kwargs): senet_input_concat, emb_dim = None, None if isinstance(inputs, (tf.Tensor, tf.Variable)): # [batch_size, slots_num, emb_dim] num_feat, emb_dim = dim_size(inputs, 1), dim_size(inputs, 2) senet_input_concat = tf.reshape(inputs, [-1, num_feat, emb_dim]) sequeeze_embedding = tf.reduce_mean(senet_input_concat, axis=2) # [batch, slots_num] else: # isinstance(inputs, (list, tuple)) num_feat = len(inputs) if self.on_gpu: slots_lens = [dim_size(embed, 1) for embed in inputs] ids = tf.constant( np.concatenate([[i] * length for i, length in enumerate(slots_lens) ])) lens = tf.constant([1.0 / slot_len for slot_len in slots_lens]) concat_trans = tf.transpose(tf.concat(inputs, axis=1)) sequeeze_embedding 
= tf.compat.v1.segment_sum(concat_trans, ids) sequeeze_embedding = tf.transpose(sequeeze_embedding) sequeeze_embedding = tf.reshape(sequeeze_embedding, shape=(-1, num_feat)) sequeeze_embedding = tf.multiply(sequeeze_embedding, lens) else: sequeeze_embedding = tf.concat( [tf.reduce_mean(embed, axis=1, keepdims=True) for embed in inputs], axis=1, name='concat') weight_out = self.cmp_tower(sequeeze_embedding) if isinstance(inputs, (tf.Tensor, tf.Variable)): # [batch, num_feat] -> # [batch, num_feat, 1] weight_out = tf.reshape(weight_out, [-1, num_feat, 1]) senet_weighted = tf.multiply(weight_out, senet_input_concat) else: weight_out = tf.split(weight_out, num_feat, axis=1) senet_weighted = [ tf.multiply(embed, weight) for embed, weight in zip(inputs, weight_out) ] return merge_tensor_list(senet_weighted, merge_type=self.out_type, keep_list=self.keep_list, num_feature=num_feat) def get_config(self): config = { 'num_feature': self.num_feat, 'cmp_dim': self.cmp_dim, 'initializer': initializers.serialize(self.initializer), 'regularizer': regularizers.serialize(self.regularizer), 'on_gpu': self.on_gpu, 'out_type': self.out_type, 'keep_list': self.keep_list } base_config = super(SeNet, self).get_config() return dict(list(base_config.items()) + list(config.items())) ================================================ FILE: monolith/native_training/layers/feature_trans_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import numpy as np import tensorflow as tf from monolith.native_training.layers.feature_trans import AutoInt, iRazor, SeNet class FeatureTransTest(tf.test.TestCase): def test_autoint_instantiate(self): layer_template = AutoInt.params() test_params0 = layer_template.copy() test_params0.layer_num = 1 ins1 = test_params0.instantiate() print(ins1) ins2 = AutoInt(layer_num=1) print(ins2) def test_autoint_serde(self): layer_template = AutoInt.params() test_params0 = layer_template.copy() test_params0.layer_num = 1 ins1 = test_params0.instantiate() print(ins1) cfg = ins1.get_config() ins2 = AutoInt.from_config(cfg) print(ins1, ins2) def test_autoint_call(self): layer_template = AutoInt.params() test_params0 = layer_template.copy() test_params0.name = 'test_dense0' test_params0.layer_num = 2 layer = test_params0.instantiate() data = tf.keras.backend.variable(np.ones((100, 10, 10))) sum_out = tf.reduce_sum(layer(data)) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_senet_instantiate(self): layer_template = SeNet.params() test_params0 = layer_template.copy() test_params0.num_feature = 10 test_params0.cmp_dim = 4 test_params0.initializer = tf.initializers.GlorotNormal() ins1 = test_params0.instantiate() print(ins1) ins2 = SeNet(num_feature=10, cmp_dim=4, initializer=tf.initializers.HeUniform()) print(ins2) def test_senet_serde(self): ins1 = SeNet(num_feature=10, cmp_dim=4, initializer=tf.initializers.HeUniform()) cfg = ins1.get_config() ins2 = SeNet.from_config(cfg) print(ins1, ins2) def test_senet_call(self): layer_template = SeNet.params() test_params0 = layer_template.copy() test_params0.num_feature = 10 test_params0.cmp_dim = 4 test_params0.initializer = tf.initializers.GlorotNormal() layer = test_params0.instantiate() data = tf.keras.backend.variable(np.ones((100, 10, 10))) sum_out = 
tf.reduce_sum(layer(data)) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) def test_irazor_instantiate(self): layer_template = iRazor.params() test_params0 = layer_template.copy() test_params0.nas_space = [0, 2, 5, 7, 10] test_params0.initializer = tf.initializers.GlorotNormal() ins1 = test_params0.instantiate() print(ins1) ins2 = iRazor(nas_space=[0, 2, 5, 7, 10], t=0.08, initializer=tf.initializers.HeUniform()) print(ins2) def test_irazor_serde(self): ins1 = iRazor(nas_space=[0, 2, 5, 7, 10], t=0.08, initializer=tf.initializers.HeUniform()) cfg = ins1.get_config() ins2 = iRazor.from_config(cfg) print(ins1, ins2) def test_irazor_call(self): layer = iRazor(nas_space=[0, 2, 5, 7, 10], t=0.08, initializer=tf.initializers.HeUniform()) data = tf.keras.backend.variable(np.ones((100, 10, 10))) out = layer(data) sum_out = tf.reduce_sum(out) with self.session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(sum_out)) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/layers/kernels/feature_insight_kernels.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include #include #include #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { class FeatureInsightOp : public OpKernel { public: explicit FeatureInsightOp(OpKernelConstruction *ctx) : OpKernel(ctx) { std::vector segment_sizes; OP_REQUIRES_OK(ctx, ctx->GetAttr("segment_sizes", &segment_sizes)); int32 idx = 0; num_feature_ = segment_sizes.size(); for (int32 size : segment_sizes) { for (int i = 0; i < size; ++i) { segment_id_map_.push_back(idx); } idx++; } } void Compute(OpKernelContext *ctx) override { const Tensor *input_tensor; OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); auto input_mat = input_tensor->matrix(); const Tensor *weight_tensor; OP_REQUIRES_OK(ctx, ctx->input("weight", &weight_tensor)); auto weight_mat = weight_tensor->matrix(); int64 batch_size = input_tensor->dim_size(0); int64 out_dim = weight_tensor->dim_size(1); Tensor *out_tensor; OP_REQUIRES_OK( ctx, ctx->allocate_output( "output", {batch_size, num_feature_ * out_dim}, &out_tensor)); auto out_mat = out_tensor->matrix(); out_mat.setZero(); Tensor tmp_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_temp(out_tensor->dtype(), {num_feature_}, &tmp_tensor)); auto tmp_mat = tmp_tensor.flat(); for (size_t i = 0; i < batch_size; ++i) { // batch_size for (size_t k = 0; k < out_dim; ++k) { // out_size tmp_mat.setZero(); for (size_t j = 0; j < input_tensor->dim_size(1); ++j) { // total_embedding_size int32 idx = segment_id_map_[j]; tmp_mat(idx) += input_mat(i, j) * weight_mat(j, k); } for (size_t idx = 0; idx < num_feature_; ++idx) { out_mat(i, idx * out_dim + k) += tmp_mat(idx); } } } } private: int64 num_feature_; std::vector segment_id_map_; }; class FeatureInsightGradOp : public OpKernel { public: explicit FeatureInsightGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) { 
OP_REQUIRES_OK(ctx, ctx->GetAttr("segment_sizes", &segment_sizes_)); int K; OP_REQUIRES_OK(ctx, ctx->GetAttr("K", &K)); num_feature_ = segment_sizes_.size(); int grad_dim = num_feature_ * K; grad_dim_to_k_.reserve(grad_dim); grad_dim_to_feature_idx_.reserve(grad_dim); feature_idx_to_embedding_start_.reserve(num_feature_); for (int i = 0; i < num_feature_; ++i) { for (int j = 0; j < K; ++j) { grad_dim_to_feature_idx_.push_back(i); grad_dim_to_k_.push_back(j); } if (i == 0) { feature_idx_to_embedding_start_.push_back(0); } else { feature_idx_to_embedding_start_.push_back( feature_idx_to_embedding_start_[i - 1] + segment_sizes_[i - 1]); } } } void Compute(OpKernelContext *ctx) override { const Tensor *grad_tensor; OP_REQUIRES_OK(ctx, ctx->input("grad", &grad_tensor)); auto grad_mat = grad_tensor->matrix(); const Tensor *input_tensor; OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); auto input_mat = input_tensor->matrix(); const Tensor *weight_tensor; OP_REQUIRES_OK(ctx, ctx->input("weight", &weight_tensor)); auto weight_mat = weight_tensor->matrix(); Tensor *input_grad_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output("input_grad", input_tensor->shape(), &input_grad_tensor)); auto input_grad_mat = input_grad_tensor->matrix(); input_grad_mat.setZero(); Tensor *weight_grad_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output("weight_grad", weight_tensor->shape(), &weight_grad_tensor)); auto weight_grad_mat = weight_grad_tensor->matrix(); weight_grad_mat.setZero(); int64 batch_size = input_tensor->dim_size(0); int64 grad_dim = grad_tensor->dim_size(1); LOG(INFO) << "get data done! 
batch_size=" << batch_size << ", grad_dim=" << grad_dim; for (size_t i = 0; i < batch_size; ++i) { // batch_size for (size_t g = 0; g < grad_dim; ++g) { int k = grad_dim_to_k_[g]; int feature_idx = grad_dim_to_feature_idx_[g]; int start = feature_idx_to_embedding_start_[feature_idx]; int end = start + segment_sizes_[feature_idx]; float grad_val = grad_mat(i, g); for (int j = start; j < end; ++j) { weight_grad_mat(j, k) += grad_val * input_mat(i, j); input_grad_mat(i, j) += grad_val * weight_mat(j, k); } } } } private: int64 num_feature_; std::vector segment_sizes_; std::vector grad_dim_to_k_; std::vector grad_dim_to_feature_idx_; std::vector feature_idx_to_embedding_start_; }; namespace { REGISTER_KERNEL_BUILDER(Name("FeatureInsight").Device(DEVICE_CPU), FeatureInsightOp) REGISTER_KERNEL_BUILDER(Name("FeatureInsightGrad").Device(DEVICE_CPU), FeatureInsightGradOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/layers/kernels/ffm_kernels.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/layers/kernels/ffm_kernels.h"

// NOTE(review): this copy listed two system #include directives whose targets
// were stripped (angle-bracket contents lost). Only <string>, which the code
// below requires, is restored; confirm the full list against the repository.
#include <string>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace monolith_tf {

using CPUDevice = Eigen::ThreadPoolDevice;

// CPU forward pass of FFM.
//
// `left_matrix` / `right_matrix` hold `left_feat_num` / `right_feat_num`
// embeddings of `dim_size` values per batch row. For every (left, right) pair:
//   int_type == "dot": output(b, pair)                 = <left_emb, right_emb>
//   otherwise        : output(b, pair * dim_size + k)  = left_emb[k] * right_emb[k]
// where pair = left * right_feat_num + right.
template <>
struct FFMImpl<CPUDevice> {
  static void Compute(OpKernelContext *ctx, const std::string &int_type,
                      TTypes<float>::ConstMatrix left_matrix,
                      int left_feat_num,
                      TTypes<float>::ConstMatrix right_matrix,
                      int right_feat_num, int batch_size, int dim_size,
                      TTypes<float>::Matrix output) {
    output.setZero();
    const bool use_dot = (int_type == "dot");
    for (int row = 0; row < batch_size; ++row) {
      for (int left = 0; left < left_feat_num; ++left) {
        const int left_base = left * dim_size;
        for (int right = 0; right < right_feat_num; ++right) {
          const int right_base = right * dim_size;
          const int pair = left * right_feat_num + right;
          if (use_dot) {
            for (int d = 0; d < dim_size; ++d) {
              output(row, pair) += left_matrix(row, left_base + d) *
                                   right_matrix(row, right_base + d);
            }
          } else {
            const int out_base = pair * dim_size;
            for (int d = 0; d < dim_size; ++d) {
              output(row, out_base + d) = left_matrix(row, left_base + d) *
                                          right_matrix(row, right_base + d);
            }
          }
        }
      }
    }
  }
};

// CPU backward pass of FFM: accumulates the gradients with respect to the
// left and right embeddings from the gradient of the forward output.
template <>
struct FFMGradImpl<CPUDevice> {
  static void Compute(OpKernelContext *ctx, const std::string &int_type,
                      TTypes<float>::ConstMatrix grad_matrix,
                      int grad_feat_num,
                      TTypes<float>::ConstMatrix left_matrix,
                      int left_feat_num,
                      TTypes<float>::ConstMatrix right_matrix,
                      int right_feat_num, int batch_size, int dim_size,
                      TTypes<float>::Matrix left_grad_matrix,
                      TTypes<float>::Matrix right_grad_matrix) {
    left_grad_matrix.setZero();
    right_grad_matrix.setZero();
    const bool use_dot = (int_type == "dot");
    for (int row = 0; row < batch_size; ++row) {
      // Pairs are visited in increasing order, so every gradient element
      // receives its contributions in the same order as a pair-major loop.
      for (int pair = 0; pair < grad_feat_num; ++pair) {
        const int left_base = (pair / right_feat_num) * dim_size;
        const int right_base = (pair % right_feat_num) * dim_size;
        // "dot" has one gradient value per pair; otherwise one per dimension.
        const int grad_base = use_dot ? pair : pair * dim_size;
        for (int d = 0; d < dim_size; ++d) {
          const int grad_col = use_dot ? grad_base : grad_base + d;
          left_grad_matrix(row, left_base + d) +=
              grad_matrix(row, grad_col) * right_matrix(row, right_base + d);
          right_grad_matrix(row, right_base + d) +=
              grad_matrix(row, grad_col) * left_matrix(row, left_base + d);
        }
      }
    }
  }
};

namespace {
REGISTER_KERNEL_BUILDER(Name("FFM").Device(DEVICE_CPU), FFMOp<CPUDevice>)

REGISTER_KERNEL_BUILDER(Name("FFMGrad").Device(DEVICE_CPU),
                        FFMGradOp<CPUDevice>)
}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/layers/kernels/ffm_kernels.cu.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "monolith/native_training/layers/kernels/ffm_kernels.h"

// NOTE(review): in this copy the two system include targets below, and the
// template arguments throughout the file (on TTypes, FFMImpl, FFMGradImpl,
// eigen_device, FFMOp, FFMGradOp), appear to have been stripped -- confirm
// against the repository before building.
#include
#include

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

namespace tensorflow {
namespace monolith_tf {

using GPUDevice = Eigen::GpuDevice;

// Element-wise ("multiply") forward kernel. For every batch row b handled by
// the kernel loop and every (left feature l, right feature r) pair, writes the
// dim_size element-wise products of the two embeddings into output columns
// [(l * right_feat_num + r) * dim_size, ... + dim_size).
__global__ void FFMKernelMultiply(TTypes::ConstMatrix left_matrix,
                                  int left_feat_num,
                                  TTypes::ConstMatrix right_matrix,
                                  int right_feat_num, int batch_size,
                                  int dim_size, TTypes::Matrix output) {
  GPU_1D_KERNEL_LOOP(b, batch_size) {
    for (int l = 0; l < left_feat_num; ++l) {
      int l_idx = l * dim_size;
      for (int r = 0; r < right_feat_num; ++r) {
        int r_idx = r * dim_size;
        int o_idx = (l * right_feat_num + r) * dim_size;
        for (int k = 0; k < dim_size; ++k) {
          output(b, o_idx + k) =
              left_matrix(b, l_idx + k) * right_matrix(b, r_idx + k);
        }
      }
    }
  }
}

// Dot-product forward kernel. For every batch row b and every (l, r) pair,
// output column l * right_feat_num + r receives the dot product of the two
// dim_size-long embeddings.
__global__ void FFMKernelDot(TTypes::ConstMatrix left_matrix, int left_feat_num,
                             TTypes::ConstMatrix right_matrix,
                             int right_feat_num, int batch_size, int dim_size,
                             TTypes::Matrix output) {
  // Pass 1: zero every output column of row b, because pass 2 accumulates
  // with +=.
  GPU_1D_KERNEL_LOOP(b, batch_size) {
    for (int j = 0; j < output.dimension(1); ++j) {
      output(b, j) = 0;
    }
  }
  // Block-level barrier between the zeroing pass and the accumulation pass.
  __syncthreads();
  // Pass 2: accumulate the dot products.
  GPU_1D_KERNEL_LOOP(b, batch_size) {
    for (int l = 0; l < left_feat_num; ++l) {
      int l_idx = l * dim_size;
      for (int r = 0; r < right_feat_num; ++r) {
        int r_idx = r * dim_size;
        int o_idx = l * right_feat_num + r;
        for (int k = 0; k < dim_size; ++k) {
          output(b, o_idx) +=
              left_matrix(b, l_idx + k) * right_matrix(b, r_idx + k);
        }
      }
    }
  }
}

// GPU forward entry point: launches FFMKernelDot when int_type is "dot" and
// FFMKernelMultiply otherwise, with a launch configuration sized by
// batch_size, on the stream of the context's Eigen GPU device.
template <>
struct FFMImpl {
  static void Compute(OpKernelContext *ctx, const std::string &int_type,
                      TTypes::ConstMatrix left_matrix, int left_feat_num,
                      TTypes::ConstMatrix right_matrix, int right_feat_num,
                      int batch_size, int dim_size, TTypes::Matrix output) {
    Eigen::GpuDevice gpu_device = ctx->eigen_device();
    auto config = GetGpuLaunchConfig(batch_size, gpu_device);
    if (int_type == "dot") {
      TF_CHECK_OK(GpuLaunchKernel(
          FFMKernelDot, config.block_count, config.thread_per_block, 0,
          gpu_device.stream(), left_matrix, left_feat_num, right_matrix,
          right_feat_num, batch_size, dim_size, output));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(
          FFMKernelMultiply, config.block_count, config.thread_per_block, 0,
          gpu_device.stream(), left_matrix, left_feat_num, right_matrix,
          right_feat_num, batch_size, dim_size, output));
    }
  }
};

// Backward kernel of the element-wise ("multiply") variant. The gradient has
// dim_size values per (l, r) pair; pair g maps to l = g / right_feat_num and
// r = g % right_feat_num.
__global__ void FFMGradKernelMultiply(
    TTypes::ConstMatrix grad_matrix, int grad_feat_num,
    TTypes::ConstMatrix left_matrix, int left_feat_num,
    TTypes::ConstMatrix right_matrix, int right_feat_num, int batch_size,
    int dim_size, TTypes::Matrix left_grad_matrix,
    TTypes::Matrix right_grad_matrix) {
  // Pass 1: zero both gradient rows, because pass 2 accumulates with +=.
  GPU_1D_KERNEL_LOOP(b, batch_size) {
    for (int g = 0; g < left_feat_num * dim_size; ++g) {
      left_grad_matrix(b, g) = 0;
    }
    for (int g = 0; g < right_feat_num * dim_size; ++g) {
      right_grad_matrix(b, g) = 0;
    }
  }
  // Block-level barrier between the zeroing pass and the accumulation pass.
  __syncthreads();
  // Pass 2: each pair contributes grad * right to the left gradient and
  // grad * left to the right gradient.
  GPU_1D_KERNEL_LOOP(b, batch_size) {
    for (int g = 0; g < grad_feat_num; ++g) {
      int l_idx = (g / right_feat_num) * dim_size;
      int r_idx = (g % right_feat_num) * dim_size;
      int g_idx = g * dim_size;
      for (int k = 0; k < dim_size; ++k) {
        left_grad_matrix(b, l_idx + k) +=
            grad_matrix(b, g_idx + k) * right_matrix(b, r_idx + k);
        right_grad_matrix(b, r_idx + k) +=
            grad_matrix(b, g_idx + k) * left_matrix(b, l_idx + k);
      }
    }
  }
}

// Backward kernel of the dot-product variant. Same structure as
// FFMGradKernelMultiply, but the gradient has a single value per (l, r) pair,
// read as grad_matrix(b, g).
__global__ void FFMGradKernelDot(
    TTypes::ConstMatrix grad_matrix, int grad_feat_num,
    TTypes::ConstMatrix left_matrix, int left_feat_num,
    TTypes::ConstMatrix right_matrix, int right_feat_num, int batch_size,
    int dim_size, TTypes::Matrix left_grad_matrix,
    TTypes::Matrix right_grad_matrix) {
  // Pass 1: zero both gradient rows, because pass 2 accumulates with +=.
  GPU_1D_KERNEL_LOOP(b, batch_size) {
    for (int g = 0; g < left_feat_num * dim_size; ++g) {
      left_grad_matrix(b, g) = 0;
    }
    for (int g = 0; g < right_feat_num * dim_size; ++g) {
      right_grad_matrix(b, g) = 0;
    }
  }
  // Block-level barrier between the zeroing pass and the accumulation pass.
  __syncthreads();
  // Pass 2: accumulate the contributions of every pair.
  GPU_1D_KERNEL_LOOP(b, batch_size) {
    for (int g = 0; g < grad_feat_num; ++g) {
      int l_idx = (g / right_feat_num) * dim_size;
      int r_idx = (g % right_feat_num) * dim_size;
      for (int k = 0; k < dim_size; ++k) {
        left_grad_matrix(b, l_idx + k) +=
            grad_matrix(b, g) * right_matrix(b, r_idx + k);
        right_grad_matrix(b, r_idx + k) +=
            grad_matrix(b, g) * left_matrix(b, l_idx + k);
      }
    }
  }
}

// GPU backward entry point: launches FFMGradKernelDot when int_type is "dot"
// and FFMGradKernelMultiply otherwise, with a launch configuration sized by
// batch_size.
template <>
struct FFMGradImpl {
  static void Compute(OpKernelContext *ctx, const std::string &int_type,
                      TTypes::ConstMatrix grad_matrix, int grad_feat_num,
                      TTypes::ConstMatrix left_matrix, int left_feat_num,
                      TTypes::ConstMatrix right_matrix, int right_feat_num,
                      int batch_size, int dim_size,
                      TTypes::Matrix left_grad_matrix,
                      TTypes::Matrix right_grad_matrix) {
    Eigen::GpuDevice gpu_device = ctx->eigen_device();
    auto config = GetGpuLaunchConfig(batch_size, gpu_device);
    if (int_type == "dot") {
      TF_CHECK_OK(GpuLaunchKernel(
          FFMGradKernelDot, config.block_count, config.thread_per_block, 0,
          gpu_device.stream(), grad_matrix, grad_feat_num, left_matrix,
          left_feat_num, right_matrix, right_feat_num, batch_size, dim_size,
          left_grad_matrix, right_grad_matrix));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(
          FFMGradKernelMultiply, config.block_count, config.thread_per_block, 0,
          gpu_device.stream(), grad_matrix, grad_feat_num, left_matrix,
          left_feat_num, right_matrix, right_feat_num, batch_size, dim_size,
          left_grad_matrix, right_grad_matrix));
    }
  }
};

namespace {
REGISTER_KERNEL_BUILDER(Name("FFM").Device(DEVICE_GPU), FFMOp)

REGISTER_KERNEL_BUILDER(Name("FFMGrad").Device(DEVICE_GPU), FFMGradOp)
}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // GOOGLE_CUDA

================================================
FILE: monolith/native_training/layers/kernels/ffm_kernels.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_LAYERS_KERNELS_FFM_KERNELS_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_LAYERS_KERNELS_FFM_KERNELS_H_ #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { template struct FFMImpl { static void Compute(OpKernelContext *ctx, const std::string &int_type, TTypes::ConstMatrix left_matrix, int left_feat_num, TTypes::ConstMatrix right_matrix, int right_feat_num, int batch_size, int dim_size, TTypes::Matrix output); }; template struct FFMGradImpl { static void Compute(OpKernelContext *ctx, const std::string &int_type, TTypes::ConstMatrix grad_matrix, int grad_feat_num, TTypes::ConstMatrix left_matrix, int left_feat_num, TTypes::ConstMatrix right_matrix, int right_feat_num, int batch_size, int dim_size, TTypes::Matrix left_grad_matrix, TTypes::Matrix right_grad_matrix); }; template class FFMOp : public OpKernel { public: explicit FFMOp(OpKernelConstruction *ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_size", &dim_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("int_type", &int_type_)); } void Compute(OpKernelContext *ctx) override { const Tensor *left_tensor; OP_REQUIRES_OK(ctx, ctx->input("left", &left_tensor)); OP_REQUIRES( ctx, left_tensor->dims() == 2, errors::InvalidArgument("the left input 
tensor of ffm is not 2D")); int64 batch_size = left_tensor->dim_size(0); int64 left_feat_num = left_tensor->dim_size(1) / dim_size_; auto left_matrix = left_tensor->matrix(); const Tensor *right_tensor; OP_REQUIRES_OK(ctx, ctx->input("right", &right_tensor)); OP_REQUIRES( ctx, left_tensor->dims() == 2, errors::InvalidArgument("the right input tensor of ffm is not 2D")); OP_REQUIRES(ctx, batch_size == right_tensor->dim_size(0), errors::InvalidArgument( "the batch size of left and right tensor are not match")); int64 right_feat_num = right_tensor->dim_size(1) / dim_size_; auto right_matrix = right_tensor->matrix(); Tensor *output_tensor = nullptr; int out_last_dim = 0; if (int_type_ == "dot") { out_last_dim = left_feat_num * right_feat_num; } else { out_last_dim = left_feat_num * right_feat_num * dim_size_; } OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {batch_size, out_last_dim}, &output_tensor)); auto output_matrix = output_tensor->matrix(); FFMImpl::Compute(ctx, int_type_, left_matrix, left_feat_num, right_matrix, right_feat_num, batch_size, dim_size_, output_matrix); } private: int dim_size_; std::string int_type_; }; template class FFMGradOp : public OpKernel { public: explicit FFMGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_size", &dim_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("int_type", &int_type_)); } void Compute(OpKernelContext *ctx) override { const Tensor *grad_tensor; OP_REQUIRES_OK(ctx, ctx->input("grad", &grad_tensor)); OP_REQUIRES(ctx, grad_tensor->dims() == 2, errors::InvalidArgument("the grad tensor of ffm is not 2D")); int batch_size = grad_tensor->dim_size(0); int grad_feat_num = 0; if (int_type_ == "dot") { grad_feat_num = grad_tensor->dim_size(1); } else { grad_feat_num = grad_tensor->dim_size(1) / dim_size_; } auto grad_matrix = grad_tensor->matrix(); const Tensor *left_tensor; OP_REQUIRES_OK(ctx, ctx->input("left", &left_tensor)); OP_REQUIRES( ctx, left_tensor->dims() == 2, 
errors::InvalidArgument("the left input tensor of ffm is not 2D")); int64 left_feat_num = left_tensor->dim_size(1) / dim_size_; auto left_matrix = left_tensor->matrix(); const Tensor *right_tensor; OP_REQUIRES_OK(ctx, ctx->input("right", &right_tensor)); OP_REQUIRES( ctx, left_tensor->dims() == 2, errors::InvalidArgument("the right input tensor of ffm is not 2D")); int64 right_feat_num = right_tensor->dim_size(1) / dim_size_; auto right_matrix = right_tensor->matrix(); OP_REQUIRES(ctx, grad_feat_num == left_feat_num * right_feat_num, errors::InvalidArgument("the in/out shape not match")); Tensor *left_grad_tensor = nullptr; OP_REQUIRES_OK( ctx, ctx->allocate_output(0, left_tensor->shape(), &left_grad_tensor)); auto left_grad_matrix = left_grad_tensor->matrix(); Tensor *right_grad_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(1, right_tensor->shape(), &right_grad_tensor)); auto right_grad_matrix = right_grad_tensor->matrix(); FFMGradImpl::Compute(ctx, int_type_, grad_matrix, grad_feat_num, left_matrix, left_feat_num, right_matrix, right_feat_num, batch_size, dim_size_, left_grad_matrix, right_grad_matrix); } private: int dim_size_; std::string int_type_; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_LAYERS_KERNELS_FFM_KERNELS_H_ ================================================ FILE: monolith/native_training/layers/kernels/fid_counter_kernel.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { class MonolithFidCounterOp : public OpKernel { public: explicit MonolithFidCounterOp(OpKernelConstruction *ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("step", &step_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("counter_threshold", &counter_threshold_)); } void Compute(OpKernelContext *ctx) override { ctx->set_output(0, ctx->input(0)); } private: float step_; int counter_threshold_; }; namespace { REGISTER_KERNEL_BUILDER(Name("MonolithFidCounter").Device(DEVICE_CPU), MonolithFidCounterOp) } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/layers/layer_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def feature_insight(input_embedding,
                    weight,
                    segment_sizes,
                    aggregate: bool = False) -> tf.Tensor:
  """Runs the FeatureInsight op, optionally aggregated per feature.

  Args:
    input_embedding: tensor whose last dimension equals `weight.shape[0]`.
    weight: tensor whose first dimension matches the last dimension of
      `input_embedding`; its last dimension `k` is the number of output
      columns produced per segment.
    segment_sizes: non-empty list with the size of each feature segment.
    aggregate: if True, square the op output and sum each feature's `k`
      columns, giving one column per entry of `segment_sizes`.

  Returns:
    The raw op output when `aggregate` is False, otherwise the per-feature
    sum of squares.
  """
  assert segment_sizes
  assert input_embedding.shape.as_list()[-1] == weight.shape.as_list()[0]
  out = layer_ops_lib.FeatureInsight(input=input_embedding,
                                     weight=weight,
                                     segment_sizes=segment_sizes)
  if not aggregate:
    return out

  k, num_feature = weight.shape.as_list()[-1], len(segment_sizes)
  # Column j of `out` belongs to feature j // k: [0]*k + [1]*k + ...
  segment_ids = [i for i in range(num_feature) for _ in range(k)]
  segment_ids_tensor = tf.constant(value=segment_ids,
                                   shape=(k * num_feature,),
                                   dtype=tf.int32)
  # segment_sum reduces along axis 0, hence the transpose round trip.
  return tf.transpose(
      tf.math.segment_sum(tf.transpose(out * out),
                          segment_ids=segment_ids_tensor))
@tf.RegisterGradient('MonolithFidCounter')
def _fid_counter_grad(op, grad: tf.Tensor) -> tf.Tensor:
  """Gradient of MonolithFidCounter.

  The incoming `grad` is ignored. Every element gets a gradient of `-step`,
  except elements whose counter has already reached `counter_threshold`,
  which get 0.
  """
  counter = op.inputs[0]
  step = op.get_attr('step')
  threshold = op.get_attr('counter_threshold')
  step_grad = tf.ones_like(counter) * tf.cast(-step, counter.dtype)
  saturated = counter >= threshold
  return tf.where(saturated, tf.zeros_like(step_grad), step_grad)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import numpy as np import tensorflow as tf from tensorflow.python.framework import test_util from monolith.native_training.layers.layer_ops import ffm from monolith.native_training.layers import layer_ops tf.random.set_seed(0) class LayerOpsTest(tf.test.TestCase): def test_ffm_mul(self): with test_util.use_gpu(): left = tf.random.uniform(shape=(8, 10 * 4), minval=0, maxval=10) right = tf.random.uniform(shape=(8, 12 * 4), minval=0, maxval=10) output_maybe_on_gpu = ffm(left=left, right=right, dim_size=4) if tf.test.is_gpu_available(): self.assertEqual(output_maybe_on_gpu.device, '/job:localhost/replica:0/task:0/device:GPU:0') with tf.device("/device:CPU:0"): output_on_cpu = ffm(left=left, right=right, dim_size=4) self.assertEqual(output_on_cpu.device, '/job:localhost/replica:0/task:0/device:CPU:0') self.assertTrue(output_maybe_on_gpu.shape == (8, 480)) self.assertAllEqual(output_maybe_on_gpu, output_on_cpu) def test_ffm_mul_grad(self): with test_util.use_gpu(): left = tf.random.uniform(shape=(8, 10 * 4), minval=0, maxval=10) right = tf.random.uniform(shape=(8, 12 * 4), minval=0, maxval=10) with tf.GradientTape() as g: g.watch(left) g.watch(right) out = ffm(left=left, right=right, dim_size=4) loss = tf.reduce_sum(out) left_grad_maybe_on_gpu, right_grad_maybe_on_gpu = g.gradient( loss, [left, right]) self.assertTrue(left_grad_maybe_on_gpu.shape == (8, 40)) self.assertTrue(right_grad_maybe_on_gpu.shape == (8, 48)) with tf.device("/device:CPU:0"), tf.GradientTape() as g: g.watch(left) g.watch(right) out = ffm(left=left, right=right, dim_size=4) loss = 
tf.reduce_sum(out) left_grad_on_cpu, right_grad_on_cpu = g.gradient(loss, [left, right]) self.assertEqual(left_grad_on_cpu.device, '/job:localhost/replica:0/task:0/device:CPU:0') self.assertEqual(right_grad_on_cpu.device, '/job:localhost/replica:0/task:0/device:CPU:0') self.assertAllEqual(left_grad_maybe_on_gpu, left_grad_on_cpu) self.assertAllEqual(right_grad_maybe_on_gpu, right_grad_on_cpu) def test_ffm_dot(self): with test_util.use_gpu(): left = tf.random.uniform(shape=(8, 10 * 4), minval=0, maxval=10) right = tf.random.uniform(shape=(8, 12 * 4), minval=0, maxval=10) output_maybe_on_gpu = ffm(left=left, right=right, dim_size=4, int_type='dot') if tf.test.is_gpu_available(): self.assertEqual(output_maybe_on_gpu.device, '/job:localhost/replica:0/task:0/device:GPU:0') with tf.device("/device:CPU:0"): output_on_cpu = ffm(left=left, right=right, dim_size=4, int_type='dot') self.assertEqual(output_on_cpu.device, '/job:localhost/replica:0/task:0/device:CPU:0') self.assertTrue(output_maybe_on_gpu.shape == (8, 120)) self.assertAllEqual(output_maybe_on_gpu, output_on_cpu) def test_ffm_dot_grad(self): with test_util.use_gpu(): left = tf.random.uniform(shape=(8, 10 * 4), minval=0, maxval=10) right = tf.random.uniform(shape=(8, 12 * 4), minval=0, maxval=10) with tf.GradientTape() as g: g.watch(left) g.watch(right) out = ffm(left=left, right=right, dim_size=4, int_type='dot') loss = tf.reduce_sum(out) left_grad_maybe_on_gpu, right_grad_maybe_on_gpu = g.gradient( loss, [left, right]) self.assertTrue(left_grad_maybe_on_gpu.shape == (8, 40)) self.assertTrue(right_grad_maybe_on_gpu.shape == (8, 48)) with tf.device("/device:CPU:0"), tf.GradientTape() as g: g.watch(left) g.watch(right) out = ffm(left=left, right=right, dim_size=4, int_type='dot') loss = tf.reduce_sum(out) left_grad_on_cpu, right_grad_on_cpu = g.gradient(loss, [left, right]) self.assertEqual(left_grad_on_cpu.device, '/job:localhost/replica:0/task:0/device:CPU:0') self.assertEqual(right_grad_on_cpu.device, 
'/job:localhost/replica:0/task:0/device:CPU:0') self.assertAllEqual(left_grad_maybe_on_gpu, left_grad_on_cpu) self.assertAllEqual(right_grad_maybe_on_gpu, right_grad_on_cpu) def test_feature_insight(self): segment_sizes = [3, 2, 4] input_embedding = [ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9 ] input_embedding_tensor = tf.constant(value=input_embedding, shape=(3, 9), dtype=tf.float32) weight = [ 0.1, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0.1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.9 ] weight_tensor = tf.constant(value=weight, shape=(9, 2), dtype=tf.float32) input_embedding_splits = tf.split(input_embedding_tensor, num_or_size_splits=segment_sizes, axis=1) weight_splits = tf.split(weight_tensor, num_or_size_splits=segment_sizes, axis=0) concatenated = tf.concat([ tf.matmul(ip, w) for ip, w in zip(input_embedding_splits, weight_splits) ], axis=1) k, num_feature = 2, 3 segment_ids = [] for i in range(num_feature): segment_ids.extend([i] * k) segment_ids_tensor = tf.constant(value=segment_ids, shape=(k * num_feature,), dtype=tf.int32) res_exp = tf.transpose( tf.math.segment_sum(tf.transpose(concatenated * concatenated), segment_ids=segment_ids_tensor)) out = layer_ops.feature_insight(input_embedding_tensor, weight_tensor, segment_sizes, aggregate=True) self.assertAllClose(out, res_exp) def test_feature_insight_grad(self): segment_sizes = [3, 2, 4] input_embedding = [ 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9 ] input_embedding_tensor = tf.constant(value=input_embedding, shape=(3, 9), dtype=tf.float32) weight = [ 0.1, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0.1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.9 ] weight_tensor = tf.constant(value=weight, shape=(9, 2), dtype=tf.float32) with tf.GradientTape(persistent=True) as g: g.watch(input_embedding_tensor) g.watch(weight_tensor) 
input_embedding_splits = tf.split(input_embedding_tensor, num_or_size_splits=segment_sizes, axis=1) weight_splits = tf.split(weight_tensor, num_or_size_splits=segment_sizes, axis=0) res_exp = tf.concat([ tf.matmul(ip, w) for ip, w in zip(input_embedding_splits, weight_splits) ], axis=1) out = layer_ops.feature_insight(input_embedding_tensor, weight_tensor, segment_sizes) input_embedding_grad_exp = g.gradient(res_exp, input_embedding_tensor) weight_grad_exp = g.gradient(res_exp, weight_tensor) input_embedding_grad = g.gradient(out, input_embedding_tensor) weight_grad = g.gradient(out, weight_tensor) self.assertAllClose(out, res_exp) self.assertAllClose(input_embedding_grad, input_embedding_grad_exp) self.assertAllClose(weight_grad, weight_grad_exp) def test_fid_counter_grad(self): alpha = tf.constant([1.0]) with tf.GradientTape() as g: g.watch(alpha) counter = layer_ops.fid_counter(alpha, counter_threshold=1000, step=1) counter_loss = tf.reduce_sum(counter) var_grad = g.gradient(counter_loss, alpha) self.assertAllClose(counter, [2.0]) self.assertAllClose(var_grad, [-1.0]) print(f"The grad {list(var_grad.numpy())}", flush=True) with tf.GradientTape() as g: g.watch(alpha) counter = layer_ops.fid_counter(alpha, counter_threshold=1000, step=0.01) counter_loss = tf.reduce_sum(counter) var_grad = g.gradient(counter_loss, alpha) self.assertAllClose(counter, [1.01]) self.assertAllClose(var_grad, [-0.01]) print(f"The grad {list(var_grad.numpy())}", flush=True) alpha = tf.constant([1000.0]) with tf.GradientTape() as g: g.watch(alpha) counter = layer_ops.fid_counter(alpha, counter_threshold=1000, step=1) counter_loss = tf.reduce_sum(counter) var_grad = g.gradient(counter_loss, alpha) self.assertAllClose(counter, [1000]) self.assertAllClose(var_grad, [0]) print(f"The grad {list(var_grad.numpy())}", flush=True) if __name__ == '__main__': tf.test.main() ================================================ FILE: monolith/native_training/layers/lhuc.py 
@monolith_export
@with_params
class LHUCTower(Layer):
  """LHUC tower: an MLP in which every layer is gated by an LHUC MLP.

  Each layer of the main tower is multiplied element-wise by the output of a
  dedicated LHUC MLP that acts as a gate. See
  https://arxiv.org/abs/1601.02828 for the paper.

  Args:
    output_dims (:obj:`List[int]`): number of output units of each layer of the
      main tower.
    lhuc_output_dims (:obj:`List[int]`, `List[List[int]]`): output_dims of the
      LHUC MLPs. Two forms are accepted: 1) `List[int]`: every LHUC MLP shares
      these hidden sizes and a final layer matching the corresponding tower
      layer is appended internally; 2) `List[List[int]]`: one list per tower
      layer (same length as output_dims), nothing is appended, and the last
      entry of each list must equal the corresponding tower output dim.
      Defaults to None, which is equivalent to [].
    activations (:obj:`List[tf.activation]`, `List[str]`, `tf.activation`, `str`):
      activation function(s), given by name or as a TF activation.
    initializers (:obj:`List[tf.initializer]`): kernel (W) initializers.
    kernel_regularizer (:obj:`tf.regularizer`): kernel regularizer.
    use_weight_norm (:obj:`bool`): whether to enable kernel_norm.
    use_learnable_weight_norm (:obj:`bool`): whether kernel_norm is trainable.
    use_bias (:obj:`bool`): whether to use a bias, True by default.
    bias_regularizer (:obj:`tf.regularizer`): bias regularizer.
    enable_batch_normalization (:obj:`bool`): whether to enable batch
      normalization. When enabled, BatchNorm is applied to the input and to
      the output of every Dense layer except the last one.
    batch_normalization_momentum (:obj:`float`): BatchNorm momentum.
    batch_normalization_renorm (:obj:`bool`): whether to use renorm
      (see https://arxiv.org/abs/1702.03275).
    batch_normalization_renorm_clipping (:obj:`bool`): renorm clipping, see TF
      `BatchNormalization`_.
    batch_normalization_renorm_momentum (:obj:`float`): renorm momentum, see TF
      `BatchNormalization`_.

  The weight_norm and batch_normalization parameters are shared by the main
  MLP and the LHUC MLPs. To give the LHUC MLPs a different value, pass it as
  "lhuc_{params_name}".

  .. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

  """

  def __init__(self,
               output_dims,
               lhuc_output_dims=None,
               activations='relu',
               initializers=None,
               use_bias=True,
               use_weight_norm=True,
               use_learnable_weight_norm=True,
               kernel_regularizer=None,
               bias_regularizer=None,
               enable_batch_normalization=False,
               batch_normalization_momentum=0.99,
               batch_normalization_renorm=False,
               batch_normalization_renorm_clipping=None,
               batch_normalization_renorm_momentum=0.99,
               **kwargs):
    # "lhuc_"-prefixed options only configure the gate MLPs; strip them before
    # the remaining kwargs reach the Keras Layer constructor.
    self._lhuc_kwargs = {
        k: v for k, v in kwargs.items() if k.startswith('lhuc_')
    }
    for lhuc_key in self._lhuc_kwargs:
      del kwargs[lhuc_key]
    super(LHUCTower, self).__init__(**kwargs)

    self.output_dims = output_dims
    self.n_layers = len(output_dims)

    if activations is None:
      self.activations = [ad_acts.get('relu')] * (self.n_layers - 1) + [None]
    elif isinstance(activations, (list, tuple)):
      assert len(activations) == self.n_layers
      self.activations = [ad_acts.get(act) for act in activations]
    else:
      # A single activation is used for every layer but the last one.
      self.activations = [
          ad_acts.get(activations) if i != self.n_layers - 1 else None
          for i in range(self.n_layers)
      ]
    self.initializers = extend_as_list(initializers, self.n_layers)
    self.use_bias = use_bias
    self.use_weight_norm = use_weight_norm
    self.use_learnable_weight_norm = use_learnable_weight_norm
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.enable_batch_normalization = enable_batch_normalization
    self.batch_normalization_momentum = batch_normalization_momentum
    self.batch_normalization_renorm = batch_normalization_renorm
    self.batch_normalization_renorm_clipping = batch_normalization_renorm_clipping
    self.batch_normalization_renorm_momentum = batch_normalization_renorm_momentum

    if lhuc_output_dims:
      assert isinstance(lhuc_output_dims, (list, tuple))
      if all(isinstance(dims, (list, tuple)) for dims in lhuc_output_dims):
        for i, dims in enumerate(lhuc_output_dims):
          assert dims[-1] == output_dims[
              i], "the last dim of lhuc must be identity with dense output"
        self.lhuc_output_dims = lhuc_output_dims
      elif all(isinstance(dims, int) for dims in lhuc_output_dims):
        # list(...) so that a tuple of hidden sizes (allowed by the assert
        # above) can be extended with the tower's output dim.
        shared_hidden_dims = list(lhuc_output_dims)
        self.lhuc_output_dims = [
            shared_hidden_dims + [dim] for dim in self.output_dims
        ]
      else:
        raise Exception("lhuc_output_dims is error")
    else:
      self.lhuc_output_dims = [[i] for i in self.output_dims]

    # relu inside each gate MLP, 'sigmoid2' on its last layer.
    self.lhuc_activations = [[
        ad_acts.get('relu') if i != len(dims) - 1 else ad_acts.get('sigmoid2')
        for i in range(len(dims))
    ] for dims in self.lhuc_output_dims]

    self.layers = []
    self.lhuc_layers = []
    self.extra_layers = []

  def lhuc_params(self, name):
    """Returns the LHUC-specific value of `name`, else the tower's own value."""
    params = self._lhuc_kwargs.get(f"lhuc_{name}")
    if params is None and hasattr(self, name):
      params = getattr(self, name)
    return params

  def build(self, input_shape):
    """Creates the input BatchNorm, the tower blocks and their LHUC gates."""
    if self.enable_batch_normalization:
      bn_layer = BatchNorm(
          name='batch_norm',
          momentum=self.batch_normalization_momentum,
          renorm=self.batch_normalization_renorm,
          renorm_clipping=self.batch_normalization_renorm_clipping,
          renorm_momentum=self.batch_normalization_renorm_momentum)
      self._trainable_weights.extend(bn_layer.trainable_weights)
      self._non_trainable_weights.extend(bn_layer.non_trainable_weights)
      self.extra_layers.append(bn_layer)

    for i, dim in enumerate(self.output_dims):
      layer_name = f'layer_{i + 1}'
      sequential = Sequential(name=layer_name)

      # one block in dense tower
      dense = Dense(name=f'{layer_name}/dense',
                    units=dim,
                    activation=None,
                    use_bias=self.use_bias,
                    kernel_initializer=self.initializers[i],
                    bias_initializer=tf.initializers.zeros(),
                    allow_kernel_norm=self.use_weight_norm,
                    kernel_norm_trainable=self.use_learnable_weight_norm,
                    kernel_regularizer=regularizers.get(
                        self.kernel_regularizer),
                    bias_regularizer=regularizers.get(self.bias_regularizer))
      self._trainable_weights.extend(dense.trainable_weights)
      self._non_trainable_weights.extend(dense.non_trainable_weights)
      sequential.add(dense)

      # No BatchNorm after the last Dense layer.
      if i != (self.n_layers - 1) and self.enable_batch_normalization:
        bn_layer = BatchNorm(
            name=f'{layer_name}/batch_norm',
            momentum=self.batch_normalization_momentum,
            renorm=self.batch_normalization_renorm,
            renorm_clipping=self.batch_normalization_renorm_clipping,
            renorm_momentum=self.batch_normalization_renorm_momentum)
        self._trainable_weights.extend(bn_layer.trainable_weights)
        self._non_trainable_weights.extend(bn_layer.non_trainable_weights)
        sequential.add(bn_layer)

      if self.activations[i] is not None:
        sequential.add(self.activations[i])
      self.layers.append(sequential)

      # for lhuc tower
      mlp = MLP(name=f'{layer_name}/lhuc',
                output_dims=self.lhuc_output_dims[i],
                activations=self.lhuc_activations[i],
                initializers=self.initializers[i],
                kernel_regularizer=self.lhuc_params('kernel_regularizer'),
                use_weight_norm=self.lhuc_params('use_weight_norm'),
                use_learnable_weight_norm=self.lhuc_params(
                    'use_learnable_weight_norm'),
                use_bias=self.lhuc_params('use_bias'),
                bias_regularizer=self.lhuc_params('bias_regularizer'),
                enable_batch_normalization=self.lhuc_params(
                    'enable_batch_normalization'),
                batch_normalization_momentum=self.lhuc_params(
                    'batch_normalization_momentum'),
                batch_normalization_renorm=self.lhuc_params(
                    'batch_normalization_renorm'),
                batch_normalization_renorm_clipping=self.lhuc_params(
                    'batch_normalization_renorm_clipping'),
                batch_normalization_renorm_momentum=self.lhuc_params(
                    'batch_normalization_renorm_momentum'))
      self._trainable_weights.extend(mlp.trainable_weights)
      self._non_trainable_weights.extend(mlp.non_trainable_weights)
      self.lhuc_layers.append(mlp)

    super(LHUCTower, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Runs the gated tower.

    Args:
      inputs: either a single tensor, used both as tower input and as gate
        input, or a `[dense_input, lhuc_input]` pair.

    Returns:
      The output of the last gated layer.
    """
    if isinstance(inputs, (list, tuple)):
      assert len(inputs) == 2
      dense_input, lhuc_input = inputs
    else:
      inputs = tf.convert_to_tensor(inputs)
      dense_input = inputs
      lhuc_input = inputs

    hidden = dense_input
    for layer in self.extra_layers:
      hidden = layer(hidden)

    # Every gate MLP sees the same lhuc_input; only the tower activations are
    # chained from layer to layer. Tracking a single `hidden` value also keeps
    # the return defined when the tower has no layers.
    for layer, lhuc_layer in zip(self.layers, self.lhuc_layers):
      hidden = layer(hidden) * lhuc_layer(lhuc_input)
    return hidden

  def get_config(self):
    """Returns the serializable config, including the "lhuc_*" overrides."""
    config = {
        "output_dims":
            self.output_dims,
        "lhuc_output_dims":
            self.lhuc_output_dims,
        "activations": [ad_acts.serialize(act) for act in self.activations],
        "initializers": [
            tf.initializers.serialize(init) for init in self.initializers
        ],
        "use_bias":
            self.use_bias,
        "use_weight_norm":
            self.use_weight_norm,
        "use_learnable_weight_norm":
            self.use_learnable_weight_norm,
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        "enable_batch_normalization":
            self.enable_batch_normalization,
        "batch_normalization_momentum":
            self.batch_normalization_momentum,
        'batch_normalization_renorm':
            self.batch_normalization_renorm,
        'batch_normalization_renorm_clipping':
            self.batch_normalization_renorm_clipping,
        'batch_normalization_renorm_momentum':
            self.batch_normalization_renorm_momentum
    }
    config.update(self._lhuc_kwargs)
    base_config = super(LHUCTower, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    """Rebuilds a layer from `get_config()` output.

    Keys known to the class params are copied into a fresh params object
    (deserializing initializers, activations and regularizers) and removed
    from `config`; other keys are left untouched in `config`.
    """
    p = cls.params().copy()
    need_pop = []
    for key, value in config.items():
      if key in p:
        if key == 'initializers':
          p[key] = [tf.initializers.deserialize(init) for init in value]
        elif key == 'activations':
          p[key] = [ad_acts.deserialize(act) for act in value]
        elif key in ('kernel_regularizer', 'bias_regularizer'):
          # The deserialized regularizer has to be stored on the params;
          # evaluating it without assignment drops it silently.
          p[key] = regularizers.deserialize(value)
        else:
          p[key] = value
        need_pop.append(key)
    for key in need_pop:
      config.pop(key)

    return p.instantiate()
class LHUCTowerTest(tf.test.TestCase):
  """Smoke tests for LHUCTower: construction, config round trip, forward pass."""

  def test_lhuc_instantiate(self):
    """Builds the layer both through params and through the constructor."""
    params = LHUCTower.params().copy()
    params.name = 'test_dense0'
    params.output_dims = [1, 3, 4, 5]
    params.activations = None
    params.initializers = tf.keras.initializers.GlorotNormal()
    from_params = params.instantiate()
    print(from_params)

    from_ctor = LHUCTower(output_dims=[1, 3, 4, 5],
                          activations=None,
                          initializers=tf.keras.initializers.HeUniform())
    print(from_ctor)

  def test_lhuc_serde(self):
    """Round-trips a layer through get_config / from_config."""
    params = LHUCTower.params().copy()
    params.name = 'test_dense0'
    params.output_dims = [1, 3, 4, 5]
    params.activations = None
    params.initializers = tf.keras.initializers.GlorotNormal()
    original = params.instantiate()

    restored = LHUCTower.from_config(original.get_config())
    print(original, restored)

  def test_lhuc_call(self):
    """Runs a forward pass with separate dense and LHUC inputs."""
    layer = LHUCTower(output_dims=[50, 20, 1],
                      activations=None,
                      lhuc_output_dims=[[50, 50], [50, 50, 20], [100, 1]],
                      use_bias=True,
                      lhuc_use_bias=False,
                      initializers=tf.keras.initializers.HeUniform())
    dense_data = tf.keras.backend.variable(np.ones((100, 100)))
    lhuc_data = tf.keras.backend.variable(np.ones((100, 50)))
    total = tf.reduce_sum(layer([dense_data, lhuc_data]))
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(total))
@monolith_export
@with_params
class LogitCorrection(Layer):
  """Logit correction for sampled training data.

  Sampling (among other causes) biases CTR/CVR predictions away from the
  posterior mean, and that bias has to be corrected. The correction can be
  applied at training or at inference time; to keep inference cheap it is
  usually done during training, which is what this layer is for.

  Args:
    activation (:obj:`tf.activation`): activation function, None by default
    sample_bias (:obj:`bool`): whether to correct the sample-bias term

  """

  def __init__(self, activation=None, sample_bias: bool = False, **kwargs):
    super(LogitCorrection, self).__init__(**kwargs)
    # compatible with older version forced sumpooling
    # self.input_spec = InputSpec(shape=[None, None, 1])
    self.input_spec = [InputSpec(max_ndim=2), InputSpec(max_ndim=2)]
    self.activation = activations.get(activation)
    self.sample_bias = sample_bias

  def call(self, inputs, **kwargs):
    """Corrects `logits` with `sample_rate`; `inputs` is that pair."""
    logits, sample_rate = inputs
    corrected = self.get_sample_logits(logits, sample_rate, self.sample_bias)
    if self.activation is not None:
      corrected = self.activation(corrected)
    return corrected

  @staticmethod
  def safe_log_sigmoid(logits):
    """Computes log(sigmoid(logits)) without overflowing exp().

    Uses log(sigmoid(x)) = -(max(x, 0) - x + log1p(exp(-|x|))), so exp() only
    ever sees a non-positive argument.
    """
    zeros = tf.zeros_like(logits, dtype=logits.dtype)
    cond = (logits >= zeros)
    relu_logits = tf.where(cond, logits, zeros)
    neg_abs_logits = tf.where(cond, -logits, logits)
    return tf.negative(relu_logits - logits +
                       tf.compat.v1.log1p(tf.exp(neg_abs_logits)))

  @staticmethod
  def get_sample_logits(logits, sample_rate, sample_bias):
    """Applies the corrections selected by `sample_rate` and `sample_bias`.

    With `sample_bias` the logits are replaced by log(sigmoid(logits)); with a
    `sample_rate` log(sample_rate) is subtracted. With neither, `logits` is
    returned unchanged.
    """
    if sample_rate is None and sample_bias:
      return LogitCorrection.safe_log_sigmoid(logits)
    elif sample_rate is not None and not sample_bias:
      return tf.add(logits, tf.negative(tf.compat.v1.log(sample_rate)))
    elif sample_rate is not None and sample_bias:
      return tf.add(LogitCorrection.safe_log_sigmoid(logits),
                    tf.negative(tf.compat.v1.log(sample_rate)))
    else:
      return logits

  def compute_output_shape(self, input_shape):
    """Returns the shape of the corrected logits.

    `input_shape` is the [logits_shape, sample_rate_shape] pair. Every
    operation in `call` is element-wise, so the result has the shape of the
    logits.
    NOTE(review): assumes sample_rate broadcasts against logits without
    enlarging them (e.g. (batch, 1) against (batch, n)) - confirm with callers.
    """
    logits_shape = input_shape[0]
    return tf.TensorShape(logits_shape)

  def get_config(self):
    """Returns the serializable config of the layer."""
    config = {
        'activation': activations.serialize(self.activation),
        'sample_bias': self.sample_bias
    }
    base_config = super(LogitCorrection, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
class SailSpecialTest(tf.test.TestCase):
  """Tests for LogitCorrection: construction, config round trip, forward pass."""

  def test_sr_instantiate(self):
    # Both construction paths (params template and direct constructor) must
    # yield a LogitCorrection layer.
    layer_template = LogitCorrection.params()
    test_params0 = layer_template.copy()
    test_params0.activation = tf.keras.activations.relu
    ins1 = test_params0.instantiate()
    self.assertIsInstance(ins1, LogitCorrection)

    ins2 = LogitCorrection(activation=tf.keras.activations.relu)
    self.assertIsInstance(ins2, LogitCorrection)

  def test_sr_serde(self):
    # A layer rebuilt from its own config keeps its hyper-parameters.
    layer_template = LogitCorrection.params()
    test_params0 = layer_template.copy()
    test_params0.activation = tf.keras.activations.sigmoid
    ins1 = test_params0.instantiate()

    cfg = ins1.get_config()
    ins2 = LogitCorrection.from_config(cfg)
    self.assertIsInstance(ins2, LogitCorrection)
    self.assertEqual(ins2.sample_bias, ins1.sample_bias)

  def test_sr_call(self):
    layer_template = LogitCorrection.params()
    test_params0 = layer_template.copy()
    test_params0.name = 'test_dense0'
    test_params0.activation = tf.keras.activations.tanh
    layer = test_params0.instantiate()

    x = tf.keras.backend.variable(np.random.uniform(size=(100, 10)))
    # Strictly positive sample rates keep log(sample_rate) finite.
    sr = tf.keras.backend.variable(np.random.uniform(low=1e-10, size=(100, 1)))
    out = layer((x, sr))
    sum_out = tf.reduce_sum(out)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      out_value, total = sess.run([out, sum_out])
    # (100, 1) sample rates broadcast onto the (100, 10) logits.
    self.assertEqual(out_value.shape, (100, 10))
    self.assertTrue(np.isfinite(total))
@monolith_export
@with_params
class MLP(Layer):
  """Multilayer perceptron (MLP), the classic artificial neural network, built from a stack of Dense layers.

  Args:
    output_dims (:obj:`List[int]`): number of output neurons of each layer
    activations (:obj:`List[tf.activation]`, `List[str]`, `tf.activation`, `str`): activation functions, given either as str or as TF activations
    initializers (:obj:`List[tf.initializer]`): initializers of the kernel (W), a list
    kernel_regularizer (:obj:`tf.regularizer`): kernel regularizer
    use_weight_norm (:obj:`bool`): whether to enable kernel_norm
    use_learnable_weight_norm (:obj:`bool`): whether kernel_norm is trainable
    use_bias (:obj:`bool`): whether to use a bias, defaults to True
    bias_regularizer (:obj:`tf.regularizer`): bias regularizer
    enable_batch_normalization (:obj:`bool`): whether to enable batch normalization; when enabled, BatchNorm is applied to the input
                                              and to the output of every Dense layer (except the last one).
    batch_normalization_momentum (:obj:`float`): momentum of BatchNorm
    batch_normalization_renorm (:obj:`bool`): whether to use renorm (see https://arxiv.org/abs/1702.03275)
    batch_normalization_renorm_clipping (:obj:`bool`): clipping used by renorm, see TF `BatchNormalization`_
    batch_normalization_renorm_momentum (:obj:`float`): momentum used by renorm, see TF `BatchNormalization`_

  .. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

  """

  def __init__(self,
               output_dims,
               activations=None,
               initializers=None,
               kernel_regularizer=None,
               use_weight_norm=True,
               use_learnable_weight_norm=True,
               use_bias=True,
               bias_regularizer=None,
               enable_batch_normalization=False,
               batch_normalization_momentum=0.99,
               batch_normalization_renorm=False,
               batch_normalization_renorm_clipping=None,
               batch_normalization_renorm_momentum=0.99,
               **kwargs):
    super(MLP, self).__init__(**kwargs)
    self.output_dims = output_dims
    self.use_weight_norm = use_weight_norm
    self.use_learnable_weight_norm = use_learnable_weight_norm
    self.use_bias = use_bias
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.enable_batch_normalization = enable_batch_normalization
    self.batch_normalization_momentum = batch_normalization_momentum
    self.batch_normalization_renorm = batch_normalization_renorm
    self.batch_normalization_renorm_clipping = batch_normalization_renorm_clipping
    self.batch_normalization_renorm_momentum = batch_normalization_renorm_momentum

    self._stacked_layers = []
    self._n_layers = len(self.output_dims)
    self._activations = None
    # One initializer per Dense layer; a single value is extended to all layers.
    self._initializers = [
        tf.initializers.get(init)
        for init in extend_as_list(initializers, self._n_layers)
    ]

    if activations is None:
      # Default: relu on every hidden layer, no activation on the last one.
      self._activations = [ad_acts.get('relu')] * (self._n_layers - 1) + [None]
    elif isinstance(activations, (list, tuple)):
      assert len(activations) == self._n_layers
      self._activations = [ad_acts.get(act) for act in activations]
    else:
      # A single activation is used for every layer but the last.
      self._activations = [
          ad_acts.get(activations) if i != self._n_layers - 1 else None
          for i in range(self._n_layers)
      ]

  def build(self, input_shape):
    """Creates the stack: [BN] -> (Dense -> [BN] -> [activation]) per layer."""
    if self.enable_batch_normalization:
      bn = BatchNorm(
          momentum=self.batch_normalization_momentum,
          renorm=self.batch_normalization_renorm,
          renorm_clipping=self.batch_normalization_renorm_clipping,
          renorm_momentum=self.batch_normalization_renorm_momentum,
          name="BatchNorm/in")
      self._trainable_weights.extend(bn.trainable_weights)
      self._non_trainable_weights.extend(bn.non_trainable_weights)
      self.add_loss(bn.losses)
      self._stacked_layers.append(bn)

    for i, dim in enumerate(self.output_dims):
      is_final_layer = (i == (self._n_layers - 1))
      # The activation is stacked separately below, so Dense itself is linear.
      dense = Dense(name=f"dense_{i}",
                    units=dim,
                    activation=None,
                    use_bias=self.use_bias,
                    kernel_initializer=self._initializers[i],
                    bias_initializer=tf.initializers.zeros(),
                    allow_kernel_norm=self.use_weight_norm,
                    kernel_norm_trainable=self.use_learnable_weight_norm,
                    kernel_regularizer=self.kernel_regularizer,
                    bias_regularizer=self.bias_regularizer)
      self._trainable_weights.extend(dense.trainable_weights)
      self._non_trainable_weights.extend(dense.non_trainable_weights)
      self.add_loss(dense.losses)
      self._stacked_layers.append(dense)

      # The output of the last Dense layer is never batch-normalized.
      if not is_final_layer and self.enable_batch_normalization:
        bn = BatchNorm(
            momentum=self.batch_normalization_momentum,
            renorm=self.batch_normalization_renorm,
            renorm_clipping=self.batch_normalization_renorm_clipping,
            renorm_momentum=self.batch_normalization_renorm_momentum,
            name="BatchNorm/out")
        self._trainable_weights.extend(bn.trainable_weights)
        self._non_trainable_weights.extend(bn.non_trainable_weights)
        self.add_loss(bn.losses)
        self._stacked_layers.append(bn)

      if self._activations[i] is not None:
        self._stacked_layers.append(self._activations[i])

    super(MLP, self).build(input_shape)

  def call(self, input, **kwargs):
    """Runs ``input`` through the stacked layers in order.

    When both ``segment_names`` and ``segment_sizes`` are passed as keyword
    arguments, ``feature_insight_data`` is called on the input and kernel of
    the first Dense layer (``group_info`` and ``label`` are forwarded to it).

    Returns:
      The output of the last stacked layer.
    """
    input_t, output_t = input, None
    for layer in self._stacked_layers:
      output_t = layer(input_t)
      if kwargs and layer.name.endswith('dense_0'):
        segment_names = kwargs.get('segment_names')
        segment_sizes = kwargs.get('segment_sizes')
        group_info = kwargs.get('group_info')
        label = kwargs.get('label')
        if segment_names is not None and segment_sizes is not None:
          feature_insight_data(input_t,
                               segment_names,
                               segment_sizes,
                               weight=layer.kernel,
                               group_info=group_info,
                               label=label)
      input_t = output_t
    return output_t

  def get_config(self):
    """Returns the layer config with activations/initializers/regularizers serialized."""
    config = {
        'output_dims':
            self.output_dims,
        "activations": [ad_acts.serialize(act) for act in self._activations],
        "initializers": [
            tf.initializers.serialize(init) for init in self._initializers
        ],
        "use_weight_norm":
            self.use_weight_norm,
        "use_learnable_weight_norm":
            self.use_learnable_weight_norm,
        "enable_batch_normalization":
            self.enable_batch_normalization,
        "batch_normalization_momentum":
            self.batch_normalization_momentum,
        "use_bias":
            self.use_bias,
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'batch_normalization_renorm':
            self.batch_normalization_renorm,
        'batch_normalization_renorm_clipping':
            self.batch_normalization_renorm_clipping,
        'batch_normalization_renorm_momentum':
            self.batch_normalization_renorm_momentum
    }
    base_config = super(MLP, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    """Rebuilds an MLP from a dict produced by :meth:`get_config`.

    Only keys known to ``cls.params()`` are used; the others are ignored.
    ``config`` is left untouched, so the same dict can be reused for several
    ``from_config`` calls.
    """
    p = cls.params().copy()
    for key, value in config.items():
      if key not in p:
        continue
      if key == 'initializers':
        p[key] = [tf.initializers.deserialize(init) for init in value]
      elif key == 'activations':
        p[key] = [ad_acts.deserialize(act) for act in value]
      else:
        p[key] = value
    return p.instantiate()

  def get_layer(self, index: int):
    """Returns the ``index``-th stacked layer (BatchNorm, Dense or activation)."""
    assert index < len(self._stacked_layers)
    return self._stacked_layers[index]
class MLPTest(tf.test.TestCase):
  """Tests for MLP: construction, config round trip, forward pass."""

  def test_mlp_instantiate(self):
    # Both construction paths (params template and direct constructor) must
    # yield an MLP layer.
    dense_layer_template = MLP.params()
    test_params0 = dense_layer_template.copy()
    test_params0.name = 'test_dense0'
    test_params0.output_dims = [1, 3, 4, 5]
    test_params0.activations = None
    test_params0.initializers = tf.keras.initializers.GlorotNormal()
    mlp1 = test_params0.instantiate()
    self.assertIsInstance(mlp1, MLP)

    mlp2 = MLP(output_dims=[1, 3, 4, 5],
               activations=None,
               initializers=tf.keras.initializers.HeUniform())
    self.assertIsInstance(mlp2, MLP)

  def test_mlp_serde(self):
    # A layer rebuilt from its own config keeps its layer sizes.
    dense_layer_template = MLP.params()
    test_params0 = dense_layer_template.copy()
    test_params0.name = 'test_dense0'
    test_params0.output_dims = [1, 3, 4, 5]
    test_params0.activations = None
    test_params0.initializers = tf.keras.initializers.GlorotNormal()
    mlp1 = test_params0.instantiate()

    cfg = mlp1.get_config()
    mlp2 = MLP.from_config(cfg)
    self.assertIsInstance(mlp2, MLP)
    self.assertEqual(list(mlp2.output_dims), list(mlp1.output_dims))

  def test_mlp_call(self):
    dense_layer_template = MLP.params()
    test_params0 = dense_layer_template.copy()
    test_params0.name = 'test_dense0'
    test_params0.output_dims = [100, 50, 10, 1]
    test_params0.enable_batch_normalization = True
    test_params0.activations = [
        'relu', tf.keras.activations.tanh, tf.keras.layers.PReLU, None
    ]
    test_params0.initializers = tf.keras.initializers.GlorotNormal()
    layer = test_params0.instantiate()

    data = tf.keras.backend.variable(np.ones((100, 100)))
    out = layer(data)
    sum_out = tf.reduce_sum(out)
    # 1 input BN + 4 Dense + 3 hidden BN + 3 activations.
    self.assertEqual(len(layer._stacked_layers), 11)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      out_value, total = sess.run([out, sum_out])
    # The last entry of output_dims is 1, so the output is (batch, 1).
    self.assertEqual(out_value.shape, (100, 1))
    self.assertTrue(np.isfinite(total))
@monolith_export
@with_params
class MMoE(Layer):
  """MMoE (Multi-gate Mixture of Experts) is a multi-task learning (MTL) architecture. It introduces one gate per task to
  model how the tasks relate to each other and how much each task depends on the shared bottom parameters.
  Paper: https://www.kdd.org/kdd2018/accepted-papers/view/modeling-task-relationships-in-multi-task-learning-with-multi-gate-mixture-

  Args:
    num_tasks (:obj:`int`): number of tasks to train
    expert_output_dims (:obj:`List[int]`, `List[List[int]]`): output_dims of every expert MLP, given in one of two ways:
                                                              1) as `List[int]`, in which case all experts share the same structure;
                                                              2) as `List[List[int]]`, in which case every expert MLP may differ. The top Dense layer
                                                              is not adapted internally, so the caller must make sure that the top layer of every
                                                              expert has the same size
    expert_activations (:obj:`List[Any]`, `str`): activation of every expert, given either as str or as TF activations
    expert_initializers (:obj:`List[Any]`, `str`): initializer of W, either a str or a user-supplied list, defaults to glorot_uniform
    gate_type (:obj:`str`): how every gate is computed, one of (softmax, topk, noise_topk), defaults to softmax
    top_k (:obj:`int`): number of experts a (topk, noise_topk) gate keeps, defaults to 1
    num_experts (:obj:`int`): number of experts; by default it is derived from the other expert arguments
    kernel_regularizer (:obj:`tf.regularizer`): kernel regularizer
    use_weight_norm (:obj:`bool`): whether to enable kernel_norm, defaults to True
    use_learnable_weight_norm (:obj:`bool`): whether kernel_norm is trainable, defaults to True
    use_bias (:obj:`bool`): whether to use a bias, defaults to True
    bias_regularizer (:obj:`tf.regularizer`): bias regularizer
    enable_batch_normalization (:obj:`bool`): whether to enable batch normalization; when enabled, BatchNorm is applied to the input
                                              and to the output of every Dense layer (except the last one).
    batch_normalization_momentum (:obj:`float`): momentum of BatchNorm
    batch_normalization_renorm (:obj:`bool`): whether to use renorm (see https://arxiv.org/abs/1702.03275)
    batch_normalization_renorm_clipping (:obj:`bool`): clipping used by renorm, see TF `BatchNormalization`_
    batch_normalization_renorm_momentum (:obj:`float`): momentum used by renorm, see TF `BatchNormalization`_

  .. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

  """

  def __init__(self,
               num_tasks: int,
               expert_output_dims: Union[List[int], List[List[int]]],
               expert_activations: Union[str, List[Any]],
               expert_initializers: Union[str, List[Any]] = 'glorot_uniform',
               gate_type: str = 'softmax',
               top_k: int = 1,
               num_experts: Optional[int] = None,
               kernel_regularizer=None,
               use_weight_norm=True,
               use_learnable_weight_norm=True,
               use_bias=True,
               bias_regularizer=None,
               enable_batch_normalization=False,
               batch_normalization_momentum=0.99,
               batch_normalization_renorm=False,
               batch_normalization_renorm_clipping=None,
               batch_normalization_renorm_momentum=0.99,
               **kwargs):
    super(MMoE, self).__init__(**kwargs)
    assert gate_type in {'softmax', 'topk', 'noise_topk'}
    self._gate_type = gate_type

    # True when every expert has its own list of layer sizes.
    per_expert_dims = all(
        isinstance(dims, (list, tuple)) for dims in expert_output_dims)

    if num_experts is not None:
      self._num_experts = num_experts
    elif per_expert_dims:
      self._num_experts = len(expert_output_dims)
    elif isinstance(expert_activations, (list, tuple)):
      self._num_experts = len(expert_activations)
    elif isinstance(expert_initializers, (list, tuple)):
      self._num_experts = len(expert_initializers)
    else:
      raise ValueError('num_experts not set')

    if per_expert_dims:
      # `build` indexes one entry per expert, so fail here rather than with an
      # IndexError later on.
      if len(expert_output_dims) < self._num_experts:
        raise ValueError(
            'expert_output_dims describes {} experts but num_experts is {}'.
            format(len(expert_output_dims), self._num_experts))
      # The expert outputs are stacked, so their last sizes must agree.
      last_dim = expert_output_dims[0][-1]
      for dims in expert_output_dims:
        assert last_dim == dims[-1]
      self._expert_output_dims = expert_output_dims
    else:
      self._expert_output_dims = [expert_output_dims] * self._num_experts

    if isinstance(expert_activations, (tuple, list)):
      assert len(expert_activations) == self._num_experts
      self._expert_activations = [
          activations.get(act) for act in expert_activations
      ]
    else:
      self._expert_activations = [
          activations.get(expert_activations) for _ in range(self._num_experts)
      ]

    if isinstance(expert_initializers, (tuple, list)):
      assert len(expert_initializers) == self._num_experts
      self._expert_initializers = [
          initializers.get(init) for init in expert_initializers
      ]
    else:
      self._expert_initializers = [
          initializers.get(expert_initializers)
          for _ in range(self._num_experts)
      ]

    # A top-k gate needs 1 <= k <= num_experts: a larger k is rejected by
    # tf.nn.top_k, and k == 0 would zero every gate and normalize 0 / 0.
    if gate_type != 'softmax' and not 1 <= top_k <= self._num_experts:
      raise ValueError('top_k must be in [1, {}] for gate_type {!r}, got {}'.
                       format(self._num_experts, gate_type, top_k))
    self._top_k = top_k
    self._num_tasks = num_tasks
    self.use_weight_norm = use_weight_norm
    self.use_learnable_weight_norm = use_learnable_weight_norm
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.use_bias = use_bias
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.enable_batch_normalization = enable_batch_normalization
    self.batch_normalization_momentum = batch_normalization_momentum
    self.batch_normalization_renorm = batch_normalization_renorm
    self.batch_normalization_renorm_clipping = batch_normalization_renorm_clipping
    self.batch_normalization_renorm_momentum = batch_normalization_renorm_momentum

  def build(self, input_shape):
    """Creates one MLP per expert plus the gate (and gate-noise) weights."""
    self.experts = []
    for i in range(self._num_experts):
      mlp = MLP(name=f'expert_{i}',
                output_dims=self._expert_output_dims[i],
                activations=self._expert_activations[i],
                initializers=self._expert_initializers[i],
                kernel_regularizer=self.kernel_regularizer,
                use_weight_norm=self.use_weight_norm,
                use_learnable_weight_norm=self.use_learnable_weight_norm,
                use_bias=self.use_bias,
                bias_regularizer=self.bias_regularizer,
                enable_batch_normalization=self.enable_batch_normalization,
                batch_normalization_momentum=self.batch_normalization_momentum,
                batch_normalization_renorm=self.batch_normalization_renorm,
                batch_normalization_renorm_clipping=self.
                batch_normalization_renorm_clipping,
                batch_normalization_renorm_momentum=self.
                batch_normalization_renorm_momentum)
      self._trainable_weights.extend(mlp.trainable_weights)
      self._non_trainable_weights.extend(mlp.non_trainable_weights)
      self.experts.append(mlp)

    # input_shape is either a list of shapes, e.g.
    # [TensorShape([bz, dim1]), TensorShape([bz, dim2])], or one single shape
    # whose items are Dimension objects or plain ints. In every case the gate
    # input size is the last size of the last entry.
    if all(isinstance(shape, tf.TensorShape) for shape in input_shape):
      gate_input_dim = input_shape[-1].as_list()[-1]
    elif all(
        isinstance(shape, tf.compat.v1.Dimension) for shape in input_shape):
      gate_input_dim = input_shape[-1].value
    else:
      assert isinstance(input_shape[-1], int)
      gate_input_dim = input_shape[-1]

    # The gates of all tasks share one weight matrix; `calc_gate` reshapes
    # the product to (batch, num_tasks, num_experts).
    gate_shape = (gate_input_dim, self._num_experts * self._num_tasks)
    self._gate_weight = self.add_weight(name="gate_weight",
                                        shape=gate_shape,
                                        dtype=tf.float32,
                                        initializer=initializers.Zeros(),
                                        trainable=True)
    if self._gate_type == 'noise_topk':
      self._gate_noise = self.add_weight(
          name="gate_noise",
          shape=gate_shape,
          dtype=tf.float32,
          initializer=initializers.GlorotNormal(),
          trainable=True)
    else:
      self._gate_noise = None
    super(MMoE, self).build(input_shape)

  def calc_gate(self, gate_input: tf.Tensor):
    """Computes the gate values of every task.

    Returns:
      A (batch, num_experts, num_tasks) tensor whose values sum to 1 over the
      experts axis.
    """
    # (batch, num_tasks * num_experts)
    gate_logit = tf.matmul(gate_input, self._gate_weight)
    if self._gate_type == 'noise_topk':
      # Gaussian noise scaled by a learned, input-dependent positive factor.
      noise = tf.random.normal(shape=tf.shape(gate_logit))
      noise = noise * tf.nn.softplus(tf.matmul(gate_input, self._gate_noise))
      gate_logit = gate_logit + noise
    # (batch, num_tasks, num_experts)
    gate_logit = tf.reshape(gate_logit,
                            shape=(-1, self._num_tasks, self._num_experts))
    gates = tf.nn.softmax(gate_logit, axis=2)
    if self._gate_type in {'topk', 'noise_topk'}:
      # (batch, num_tasks, top_k)
      top_gates, _ = tf.nn.top_k(gates, self._top_k)
      # (batch, num_tasks, 1): the k-th largest gate value
      threshold = tf.reduce_min(top_gates, axis=2, keepdims=True)
      # (batch, num_tasks, num_experts); experts tied with the k-th value are
      # kept as well, so more than k gates may stay non-zero.
      gates = tf.where(gates >= threshold, gates,
                       tf.zeros_like(gates, dtype=gates.dtype))
      gates /= tf.reduce_sum(gates, axis=2, keepdims=True)  # normalize

    # (batch, num_experts, num_tasks)
    return tf.transpose(gates, perm=[0, 2, 1])

  def call(self, inputs, **kwargs):
    """Mixes the expert outputs with the per-task gates.

    Args:
      inputs: either one tensor, fed to both the experts and the gates, or an
        ``(expert_input, gate_input)`` pair.

    Returns:
      A list of ``num_tasks`` tensors of shape (batch, output_dim).
    """
    if isinstance(inputs, (list, tuple)):
      assert len(inputs) == 2
      expert_input, gate_input = inputs
    else:
      inputs = tf.convert_to_tensor(inputs)
      expert_input = inputs
      gate_input = inputs

    # (batch, output_dim, num_experts)
    expert_outputs = tf.stack([expert(expert_input) for expert in self.experts],
                              axis=2)
    # (batch, num_experts, num_tasks)
    gates = self.calc_gate(gate_input)
    if self._gate_type != 'softmax':
      # Layer loss: squared coefficient of variation of the expert importance.
      # (num_experts, num_tasks)
      importance = tf.reduce_sum(gates, axis=0)
      # (num_tasks, )
      mean, variance = tf.nn.moments(importance, [0])
      cv_square = variance / tf.square(mean)
      self.add_loss(cv_square)

    # (batch, output_dim, num_tasks)
    mmoe_output = tf.matmul(expert_outputs, gates)
    # (batch, output_dim) * num_tasks
    final_outputs = tf.unstack(mmoe_output, axis=2)
    return final_outputs

  def get_config(self):
    """Returns the layer config with activations/initializers/regularizers serialized."""
    config = {
        'num_tasks':
            self._num_tasks,
        'num_experts':
            self._num_experts,
        'expert_output_dims':
            self._expert_output_dims,
        "expert_activations": [
            ad_acts.serialize(act) for act in self._expert_activations
        ],
        "expert_initializers": [
            tf.initializers.serialize(init)
            for init in self._expert_initializers
        ],
        'gate_type':
            self._gate_type,
        'top_k':
            self._top_k,
        "use_weight_norm":
            self.use_weight_norm,
        "use_learnable_weight_norm":
            self.use_learnable_weight_norm,
        "enable_batch_normalization":
            self.enable_batch_normalization,
        "batch_normalization_momentum":
            self.batch_normalization_momentum,
        "use_bias":
            self.use_bias,
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'batch_normalization_renorm':
            self.batch_normalization_renorm,
        'batch_normalization_renorm_clipping':
            self.batch_normalization_renorm_clipping,
        'batch_normalization_renorm_momentum':
            self.batch_normalization_renorm_momentum
    }
    base_config = super(MMoE, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@tf.custom_gradient
def hard_concrete_ste(x):
  """Clips ``x`` to [0, 1] with a straight-through gradient.

  The forward pass is ``min(1, max(x, 0))``; the backward pass returns the
  incoming gradient unchanged, including where the clip is active.
  """
  out = tf.minimum(1.0, tf.maximum(x, 0.0))

  def grad(dy):
    return dy

  return out, grad


@monolith_export
@with_params
class SNR(Layer):
  """SNR (Sub-Network Routing) is a flexible parameter-sharing method proposed to address negative transfer, i.e. the poor
  training results that appear in multi-task learning (MTL) when the tasks are only weakly related.
  Paper: https://ojs.aaai.org/index.php/AAAI/article/view/3788

  Args:
    num_out_subnet (:obj:`int`): number of output sub-networks (experts)
    out_subnet_dim (:obj:`int`): output size of every sub-network (expert)
    snr_type (:obj:`str`): structure of the connections between sub-networks, one of ('trans', 'aver'), defaults to 'trans'
    zeta (:obj:`float`): upper bound used to stretch the range of the Concrete distribution
    gamma (:obj:`float`): lower bound used to stretch the range of the Concrete distribution
    beta (:obj:`float`): temperature of the Concrete distribution, controls how smooth it is
    use_ste: (:obj:`bool`): whether to use STE (Straight-Through Estimator), defaults to False
    mode (:obj:`str`): mode of the tf.estimator.Estimator, defaults to training
    initializer (:obj:`str`): initializer of W, used by the 'trans' structure, defaults to glorot_uniform
    regularizer (:obj:`tf.regularizer`): regularizer of W

  """

  def __init__(self,
               num_out_subnet: int,
               out_subnet_dim: int,
               snr_type: str = 'trans',
               zeta: float = 1.1,
               gamma: float = -0.1,
               beta: float = 0.5,
               use_ste: bool = False,
               mode: str = tf.estimator.ModeKeys.TRAIN,
               initializer='glorot_uniform',
               regularizer=None,
               **kwargs):
    assert snr_type in {'trans', 'aver'}
    # Initialise the Keras base class before attaching any attribute, like the
    # other layers of this package do.
    super(SNR, self).__init__(**kwargs)

    self._mode = mode
    self._num_out_subnet = num_out_subnet
    self._out_subnet_dim = out_subnet_dim
    # The input side is only known once `build` sees the input shapes.
    self._num_in_subnet = None
    self._in_subnet_dim = None
    self._snr_type = snr_type

    self._weight = None
    self._log_alpha = None
    self._beta = beta
    self._zeta = zeta
    self._gamma = gamma
    self._use_ste = use_ste

    self._initializer = initializers.get(initializer)
    self._regularizer = regularizers.get(regularizer)

  def build(self, input_shape):
    """Creates the routing logits, the L0 loss and the transform weights.

    Args:
      input_shape: list/tuple with one shape per input sub-network; all of
        them must have the same last size.

    Raises:
      ValueError: if ``-gamma / zeta`` is not positive, so that the L0 penalty
        cannot be computed.
    """
    assert isinstance(input_shape, (list, tuple))
    self._num_in_subnet = len(input_shape)
    in_subnet_dim = 0
    for i, shape in enumerate(input_shape):
      last_dim = shape[-1]
      if not isinstance(last_dim, int):
        last_dim = last_dim.value  # Dimension object
      if i == 0:
        in_subnet_dim = last_dim
      else:
        assert in_subnet_dim == last_dim
    self._in_subnet_dim = in_subnet_dim

    # One route per (input sub-network, output sub-network) pair.
    num_route = self._num_in_subnet * self._num_out_subnet
    block_size = self._in_subnet_dim * self._out_subnet_dim
    self._log_alpha = self.add_weight(name='log_alpha',
                                      shape=(num_route, 1),
                                      initializer=initializers.Zeros(),
                                      trainable=True)

    # The L0 penalty takes log(-gamma / zeta), which requires gamma and zeta
    # to be non-zero and of opposite sign.
    if self._zeta == 0 or -self._gamma / self._zeta <= 0:
      raise ValueError(
          'SNR requires -gamma / zeta > 0 (e.g. gamma < 0 < zeta), got '
          'gamma={}, zeta={}'.format(self._gamma, self._zeta))
    factor = self._beta * math.log(-self._gamma / self._zeta)
    l0_loss = tf.reduce_sum(tf.sigmoid(self._log_alpha - factor))
    self.add_loss(l0_loss)

    if self._snr_type == 'trans':
      # One trainable (in_dim x out_dim) block per route, stored flattened.
      self._weight = self.add_weight(name='weight',
                                     dtype=tf.float32,
                                     shape=(num_route, block_size),
                                     initializer=self._initializer,
                                     regularizer=self._regularizer,
                                     trainable=True)
    else:
      # 'aver': every route is a fixed identity block, so the sizes must match.
      assert self._snr_type == 'aver' and self._in_subnet_dim == self._out_subnet_dim
      self._weight = tf.tile(tf.reshape(tf.eye(self._in_subnet_dim),
                                        shape=(1, block_size)),
                             multiples=(num_route, 1))
    super(SNR, self).build(input_shape)

  def sample(self):
    """Returns one routing coefficient in [0, 1] per route, shape (num_route, 1).

    Outside PREDICT mode the coefficients are sampled with uniform noise and
    temperature ``beta``; in PREDICT mode they are ``sigmoid(log_alpha)``. In
    both cases the value is stretched to (gamma, zeta) and clipped to [0, 1].
    """
    if self._mode != tf.estimator.ModeKeys.PREDICT:
      num_route = self._num_in_subnet * self._num_out_subnet
      u = tf.random.uniform(shape=(num_route, 1), minval=0, maxval=1)
      s = tf.sigmoid((tf.math.log(u) - tf.math.log(1.0 - u) + self._log_alpha) /
                     self._beta)
    else:
      s = tf.sigmoid(self._log_alpha)

    s_ = s * (self._zeta - self._gamma) + self._gamma
    if self._use_ste:
      z = hard_concrete_ste(s_)
    else:
      z = tf.minimum(1.0, tf.maximum(s_, 0.0))
    return z

  def call(self, inputs, **kwargs):
    """Routes the input sub-networks to the output sub-networks.

    Args:
      inputs: list of tensors, one per input sub-network, concatenated on
        axis 1.

    Returns:
      A list of ``num_out_subnet`` tensors, split on axis 1.
    """
    z = self.sample()
    # Scale every route's block by its routing coefficient.
    weight = tf.multiply(self._weight, z)
    shape1 = (self._num_in_subnet, self._num_out_subnet, self._in_subnet_dim,
              self._out_subnet_dim)
    shape2 = (self._num_in_subnet * self._in_subnet_dim,
              self._num_out_subnet * self._out_subnet_dim)
    # Lay the per-route blocks out as one big block matrix so that a single
    # matmul on the concatenated inputs computes every output sub-network.
    weight = tf.reshape(
        tf.transpose(tf.reshape(weight, shape1), perm=[0, 2, 1, 3]), shape2)
    return tf.split(tf.matmul(tf.concat(inputs, axis=1), weight),
                    num_or_size_splits=self._num_out_subnet,
                    axis=1)

  def get_config(self):
    """Returns the layer config with the initializer/regularizer serialized."""
    config = {
        'num_out_subnet': self._num_out_subnet,
        'out_subnet_dim': self._out_subnet_dim,
        'snr_type': self._snr_type,
        'zeta': self._zeta,
        'gamma': self._gamma,
        'beta': self._beta,
        'use_ste': self._use_ste,
        'mode': self._mode,
        'initializer': initializers.serialize(self._initializer),
        'regularizer': regularizers.serialize(self._regularizer)
    }
    base_config = super(SNR, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
class MultiTaskTest(tf.test.TestCase):
  """Tests for the MMoE and SNR layers: construction, config round trip, forward pass."""

  def test_mmoe_instantiate(self):
    # Both construction paths (params template and direct constructor) must
    # yield an MMoE layer.
    mmoe_layer_template = MMoE.params()
    test_params0 = mmoe_layer_template.copy()
    test_params0.name = 'test_mmoe'
    test_params0.num_tasks = 2
    test_params0.num_experts = 3
    test_params0.expert_output_dims = [128, 64, 64]
    test_params0.expert_activations = 'relu'
    test_params0.expert_initializers = tf.keras.initializers.GlorotNormal()
    mmoe1 = test_params0.instantiate()
    self.assertIsInstance(mmoe1, MMoE)

    mmoe2 = MMoE(num_tasks=2,
                 num_experts=3,
                 expert_output_dims=[128, 64, 64],
                 expert_activations='relu',
                 expert_initializers=tf.keras.initializers.GlorotNormal())
    self.assertIsInstance(mmoe2, MMoE)

  def test_mmoe_serde(self):
    # A layer rebuilt from its own config keeps its task/expert counts.
    mmoe_layer_template = MMoE.params()
    test_params0 = mmoe_layer_template.copy()
    test_params0.name = 'test_mmoe'
    test_params0.num_tasks = 2
    test_params0.num_experts = 3
    test_params0.expert_output_dims = [128, 64, 64]
    test_params0.expert_activations = 'relu'
    test_params0.expert_initializers = tf.keras.initializers.GlorotNormal()
    mmoe1 = test_params0.instantiate()

    cfg = mmoe1.get_config()
    mmoe2 = MMoE.from_config(cfg)
    self.assertIsInstance(mmoe2, MMoE)
    cfg2 = mmoe2.get_config()
    self.assertEqual(cfg2['num_tasks'], 2)
    self.assertEqual(cfg2['num_experts'], 3)

  def test_mmoe_call(self):
    layer = MMoE(num_tasks=2,
                 num_experts=3,
                 gate_type='topk',
                 top_k=2,
                 expert_output_dims=[[128, 64, 64], [64, 64], [128, 64]],
                 expert_activations='relu',
                 expert_initializers=tf.keras.initializers.GlorotNormal())

    dense_data = tf.keras.backend.variable(np.ones((100, 128)))
    # mmoe_data = tf.keras.backend.variable(np.ones((100, 64)))
    # sum_out = tf.reduce_sum(layer([dense_data, mmoe_data]))
    outputs = layer(dense_data)
    sum_out = tf.reduce_sum(outputs)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      out_values, total = sess.run([outputs, sum_out])
    # One (batch, last expert size) output per task.
    self.assertEqual(len(out_values), 2)
    for value in out_values:
      self.assertEqual(value.shape, (100, 64))
    self.assertTrue(np.isfinite(total))

  def test_snr_instantiate(self):
    snr_layer_template = SNR.params()
    test_params0 = snr_layer_template.copy()
    test_params0.name = 'test_snr'
    test_params0.num_out_subnet = 3
    test_params0.out_subnet_dim = 128
    test_params0.use_ste = False
    snr1 = test_params0.instantiate()
    self.assertIsInstance(snr1, SNR)

    snr2 = SNR(num_out_subnet=3, out_subnet_dim=128, use_ste=False)
    self.assertIsInstance(snr2, SNR)

  def test_snr_serde(self):
    snr_layer_template = SNR.params()
    test_params0 = snr_layer_template.copy()
    test_params0.name = 'test_snr'
    test_params0.num_out_subnet = 3
    test_params0.out_subnet_dim = 128
    test_params0.use_ste = False
    snr1 = test_params0.instantiate()

    cfg = snr1.get_config()
    snr2 = SNR.from_config(cfg)
    self.assertIsInstance(snr2, SNR)
    cfg2 = snr2.get_config()
    self.assertEqual(cfg2['num_out_subnet'], 3)
    self.assertEqual(cfg2['out_subnet_dim'], 128)

  def test_snr_call(self):
    layer = SNR(num_out_subnet=3,
                out_subnet_dim=128,
                snr_type='aver',
                use_ste=False,
                mode=tf.estimator.ModeKeys.PREDICT)

    snr_data1 = tf.keras.backend.variable(np.ones((100, 128)))
    snr_data2 = tf.keras.backend.variable(np.ones((100, 128)))
    snr_data3 = tf.keras.backend.variable(np.ones((100, 128)))
    snr_data4 = tf.keras.backend.variable(np.ones((100, 128)))
    outputs = layer([snr_data1, snr_data2, snr_data3, snr_data4])
    sum_out = tf.reduce_sum(outputs)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      out_values, total = sess.run([outputs, sum_out])

    self.assertEqual(len(out_values), 3)
    for value in out_values:
      self.assertEqual(value.shape, (100, 128))
    # PREDICT mode is deterministic: log_alpha starts at 0, so every routing
    # coefficient is sigmoid(0) * (1.1 + 0.1) - 0.1 = 0.5. With identity
    # blocks and four all-ones inputs every output element is 4 * 0.5 = 2,
    # hence the total is 3 * 100 * 128 * 2.
    self.assertAllClose(total, 76800.0)
initializers.get(moving_mean_initializer) self.moving_variance_initializer = initializers.get( moving_variance_initializer) self.training_use_global_dist = training_use_global_dist self.global_dist_momentum = global_dist_momentum self.stop_grad_of_var_mean = stop_grad_of_var_mean self.mode = mode self.regularizer = regularizers.get(regularizer) self.input_spec = InputSpec(min_ndim=2) def build(self, input_shape): assert len(input_shape) >= 2 self.input_dim = check_dim(input_shape[-1]) self.moving_mean = self.add_weight(name='moving_mean', shape=[self.input_dim], dtype=self.dtype, initializer=self.moving_mean_initializer) self.moving_variance = self.add_weight( name='moving_variance', dtype=self.dtype, shape=[self.input_dim], initializer=self.moving_variance_initializer) if self.center: self.beta_offset = self.add_weight(name='beta_offset', dtype=self.dtype, shape=[self.input_dim], initializer=self.beta_initializer, regularizer=self.regularizer) else: self.beta_offset = tf.constant(0.0, dtype=tf.float32) if self.scale: self.gamma_scale = self.add_weight(name='gamma_scale', dtype=self.dtype, shape=[self.input_dim], initializer=self.gamma_initializer, regularizer=self.regularizer) else: self.gamma_scale = tf.constant(1.0, dtype=tf.float32) self.input_spec = InputSpec(min_ndim=2, axes={-1: self.input_dim}) super(BatchNorm, self).build(input_shape) def call(self, inputs, **kwargs): @tf.custom_gradient def replace_gradient(original_moving_average, gradient): def grad(dy): return gradient, None return tf.identity(original_moving_average), grad if self.mode == tf.estimator.ModeKeys.TRAIN: if self.stop_grad_of_var_mean: self.mean, self.variance = tf.nn.moments(tf.stop_gradient(inputs), axes=[0]) else: self.mean, self.variance = tf.nn.moments(inputs, axes=[0]) # replace moving average gradient by mean & variance in current minibatch moving_mean = replace_gradient(self.moving_mean, self.mean) moving_variance = replace_gradient(self.moving_variance, self.variance) 
moving_variance = tf.maximum(moving_variance, tf.constant(0, dtype=tf.float32)) add_layer_loss('{}_moving_mean'.format(self.name), tf.reduce_sum(moving_mean)) add_layer_loss('{}_moving_variance'.format(self.name), tf.reduce_sum(moving_variance)) if self.training_use_global_dist: mean = self.global_dist_momentum * moving_mean + \ (1.0 - self.global_dist_momentum) * self.mean variance = self.global_dist_momentum * moving_variance + \ (1.0 - self.global_dist_momentum) * self.variance else: mean, variance = self.mean, self.variance tf.compat.v1.summary.histogram(self.name + '/mean_train', mean) tf.compat.v1.summary.scalar(self.name + '/mean_train', tf.reduce_mean(mean)) tf.compat.v1.summary.histogram(self.name + '/var_train', variance) tf.compat.v1.summary.scalar(self.name + '/var_train', tf.reduce_mean(variance)) else: moving_variance = tf.maximum(self.moving_variance, tf.constant(0, dtype=tf.float32)) mean, variance = tf.stop_gradient( self.moving_mean), tf.stop_gradient(moving_variance) tf.compat.v1.summary.histogram(self.name + '/mean_test', mean) tf.compat.v1.summary.scalar(self.name + '/mean_test', tf.reduce_mean(mean)) tf.compat.v1.summary.histogram(self.name + '/var_test', variance) tf.compat.v1.summary.scalar(self.name + '/var_test', tf.reduce_mean(variance)) output = tf.nn.batch_normalization(inputs, mean, variance, self.beta_offset, self.gamma_scale, self.epsilon) return output def set_use_global_dist(self, training_use_global_dist): assert type(training_use_global_dist) is bool self.training_use_global_dist = training_use_global_dist def get_config(self): config = { 'momentum': self.momentum, 'epsilon': self.epsilon, 'center': self.center, 'scale': self.scale, 'beta_initializer': initializers.serialize(self.beta_initializer), 'gamma_initializer': initializers.serialize(self.gamma_initializer), 'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer), 'moving_variance_initializer': 
initializers.serialize(self.moving_variance_initializer), 'training_use_global_dist': self.training_use_global_dist, 'global_dist_momentum': self.global_dist_momentum, 'stop_grad_of_var_mean': self.stop_grad_of_var_mean, 'mode': self.mode, 'regularizer': regularizers.serialize(self.regularizer), } base_config = super(BatchNorm, self).get_config() return dict(list(base_config.items()) + list(config.items())) @monolith_export @with_params class LayerNorm(Layer): """与BatchNorm类似, 但是是在样本内做归一化. 与BatchNorm不同的是LayerNorm在训练与推理时, 使用同一套逻辑, 不用分别处理 Args: initializer (:obj:`tf.initializer`): gamma的初始化器 regularizer (:obj:`tf.regularizer`): beta/gamma的变量的正则化 """ def __init__(self, initializer, regularizer=None, **kwargs): super(LayerNorm, self).__init__(**kwargs) self.beta, self.gamma = None, None self.initializer = initializers.get(initializer) or initializers.Ones() self.regularizer = regularizers.get(regularizer) def build(self, input_shape): params_shape = [check_dim(input_shape[-1])] self.beta = self.add_weight(name='beta', dtype=tf.float32, shape=params_shape, initializer=initializers.Zeros(), regularizer=self.regularizer) self.gamma = self.add_weight(name='gamma', dtype=tf.float32, shape=params_shape, initializer=self.initializer, regularizer=self.regularizer) super(LayerNorm, self).build(input_shape) def call(self, inputs, **kwargs): mean, variance = tf.nn.moments(inputs, [-1], keepdims=True) output = tf.nn.batch_normalization(inputs, mean, variance, self.beta, self.gamma, variance_epsilon=1e-6) return output def get_config(self): config = { 'initializer': initializers.serialize(self.initializer), 'regularizer': regularizers.serialize(self.regularizer), } base_config = super(LayerNorm, self).get_config() return dict(list(base_config.items()) + list(config.items())) @monolith_export @with_params class GradNorm(Layer): """GradNorm提出通过将不同任务的梯度控制在一定的范围来进行多任务学习, 论文可参考 https://arxiv.org/abs/1711.02257 GradNorm是通过构造辅助loss实现, 辅助的构造过程如下: - 选择shared bottom的最顶层变量W, 然后计算每个head对它的梯度 
(如果顶层有多个W, 则分别计算梯度, 再concat起来) - 对上一步得到的梯度取L2范数, 得到gnorms, gnorms是一个n维向量, 长度与task的个数相同 - gnorms加权weight, 得到wgnorms, wgnorms平均, 得到avgnorm - gnorm_loss = scale * sum([(wgnorms - avgnorm) / (avgnorm + epsilon)]^loss_pow), relative_diff = True - gnorm_loss = scale * sum([wgnorms - avgnorm]^loss_pow), relative_diff = False - weighted_loss = sum(weight * losses) Args: loss_names (:obj:`str`): loss名称, 用于确定loss的个数, 写相关日志 scale (:obj:`float`): 缩放因子, 用于缩放 loss_pow (:obj:`float`): gnorm diff的指数因子 relative_diff (:obj:`bool`): gnorm diff的计算方式, 如果为True, 会计算相对值 epsilon (:obj:`float`): 一个非常小的常数, 防止除以0 """ def __init__(self, loss_names, scale=1.0, loss_pow=2.0, relative_diff=False, epsilon=1e-6, **kwargs): super(GradNorm, self).__init__(**kwargs) self.loss_names = loss_names self.scale = scale self.loss_pow = loss_pow self.relative_diff = relative_diff self.epsilon = epsilon def build(self, input_shape): n = len(self.loss_names) self.weight = self.add_weight(name='grad_norm_weights', shape=[n], dtype=tf.float32, initializer=tf.initializers.Zeros()) self._weights = tf.nn.softmax(self.weight) for i in range(n): tf.compat.v1.summary.scalar( 'gradnorm_weight/{}'.format(self.loss_names[i]), self._weights[i]) super(GradNorm, self).build(input_shape) def _get_norm(self, grad): return (tf.reduce_sum(tf.multiply(grad, grad)))**0.5 def get_weights(self): return self._weights def call(self, inputs, **kwargs): losses, shared_inputs = inputs if not isinstance(shared_inputs, list): shared_inputs = [shared_inputs] grads = [tf.gradients(loss, shared_inputs) for loss in losses] grads = [tf.concat(gs, axis=1) for gs in grads] gnorms = [self._get_norm(g) for g in grads] gnorms = tf.stop_gradient(tf.stack(gnorms, axis=0)) weights = self._weights n = len(self.loss_names) avgnorm = tf.reduce_sum(gnorms * weights) / n wgnorms = gnorms * weights grad_diff = tf.abs(wgnorms - avgnorm) if self.relative_diff: grad_diff = grad_diff / (avgnorm + self.epsilon) gnorm_loss = tf.reduce_sum(grad_diff**self.loss_pow) 
* self.scale weighted_loss = tf.reduce_sum( tf.stack(losses, axis=0) * tf.stop_gradient(weights)) for i in range(n): tf.compat.v1.summary.scalar( 'gradnorm_gnorm/{}'.format(self.loss_names[i]), gnorms[i]) tf.compat.v1.summary.scalar( 'gradnorm_wgnorm/{}'.format(self.loss_names[i]), wgnorms[i]) return gnorm_loss, weighted_loss def get_config(self): config = { 'loss_names': self.loss_names, 'scale': self.scale, 'loss_pow': self.loss_pow, 'relative_diff': self.relative_diff, 'epsilon': self.epsilon } base_config = super(GradNorm, self).get_config() return dict(list(base_config.items()) + list(config.items())) ================================================ FILE: monolith/native_training/layers/norms_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np
import tensorflow as tf

from monolith.native_training.layers.norms import LayerNorm, GradNorm


class NormTest(tf.test.TestCase):
  """Smoke tests for LayerNorm and GradNorm: construction, config round-trip, call."""

  def test_ln_instantiate(self):
    """LayerNorm can be built from its params template and directly."""
    layer_template = LayerNorm.params()
    test_params0 = layer_template.copy()
    test_params0.initializer = tf.keras.initializers.GlorotNormal()
    bn1 = test_params0.instantiate()
    print(bn1)

    bn2 = LayerNorm(initializer=tf.keras.initializers.HeUniform())
    print(bn2)

  def test_ln_serde(self):
    """LayerNorm can be rebuilt from the dict returned by get_config."""
    layer_template = LayerNorm.params()
    test_params0 = layer_template.copy()
    test_params0.initializer = tf.keras.initializers.GlorotNormal()
    bn1 = test_params0.instantiate()
    print(bn1)

    cfg = bn1.get_config()
    bn2 = LayerNorm.from_config(cfg)
    print(bn1, bn2)

  def test_ln_call(self):
    """LayerNorm runs in a session on a 100x100 all-ones input."""
    bn = LayerNorm(initializer=tf.keras.initializers.HeUniform())
    data = tf.keras.backend.variable(np.ones((100, 100)))
    sum_out = tf.reduce_sum(bn(data))
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(sum_out))

  def test_gn_instantiate(self):
    """GradNorm can be built from its params template and directly."""
    layer_template = GradNorm.params()
    test_params0 = layer_template.copy()
    test_params0.loss_names = ["abc", 'defg']
    bn1 = test_params0.instantiate()
    print(bn1)

    bn2 = GradNorm(loss_names=["abc", 'defg'], relative_diff=True)
    print(bn2)

  def test_gn_serde(self):
    """GradNorm can be rebuilt from the dict returned by get_config."""
    layer_template = GradNorm.params()
    test_params0 = layer_template.copy()
    test_params0.loss_names = ["abc", 'defg']
    bn1 = test_params0.instantiate()
    print(bn1)

    cfg = bn1.get_config()
    bn2 = GradNorm.from_config(cfg)
    print(bn1, bn2)


if __name__ == '__main__':
  # The tests use sessions, so eager execution is disabled first.
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

================================================
FILE: monolith/native_training/layers/ops/feature_insight_ops.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { REGISTER_OP("FeatureInsight") .Input("input: float") .Input("weight: float") .Output("output: float") .Attr("segment_sizes: list(int)") .SetShapeFn([](shape_inference::InferenceContext *ctx) { std::vector segment_sizes; TF_RETURN_IF_ERROR(ctx->GetAttr("segment_sizes", &segment_sizes)); auto batch_size = ctx->Dim(ctx->input(0), 0); shape_inference::DimensionHandle out_dims; TF_RETURN_IF_ERROR(ctx->Multiply(ctx->MakeDim(segment_sizes.size()), ctx->Dim(ctx->input(1), 1), &out_dims)); ctx->set_output(0, ctx->Matrix(batch_size, out_dims)); return Status::OK(); }); REGISTER_OP("FeatureInsightGrad") .Input("grad: float") .Input("input: float") .Input("weight: float") .Output("input_grad: float") .Output("weight_grad: float") .Attr("segment_sizes: list(int)") .Attr("K: int") .SetShapeFn([](shape_inference::InferenceContext *ctx) { ctx->set_output(0, ctx->input(1)); ctx->set_output(1, ctx->input(2)); return Status::OK(); }); } // namespace tensorflow ================================================ FILE: monolith/native_training/layers/ops/ffm_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { REGISTER_OP("FFM") .Input("left: float") .Input("right: float") .Output("output: float") .Attr("dim_size: int") .Attr("int_type: string") .SetShapeFn([](shape_inference::InferenceContext *ctx) { int dim_size; TF_RETURN_IF_ERROR(ctx->GetAttr("dim_size", &dim_size)); auto batch_size = ctx->Dim(ctx->input(0), 0); std::string int_type; TF_RETURN_IF_ERROR(ctx->GetAttr("int_type", &int_type)); shape_inference::DimensionHandle tmp_dims; ctx->Multiply(ctx->DimKnownRank(ctx->input(0), 1), ctx->DimKnownRank(ctx->input(1), 1), &tmp_dims); shape_inference::DimensionHandle out_dims; if (int_type == "dot") { ctx->Divide(tmp_dims, dim_size * dim_size, true, &out_dims); } else { ctx->Divide(tmp_dims, dim_size, true, &out_dims); } ctx->set_output(0, ctx->Matrix(batch_size, out_dims)); return Status::OK(); }); REGISTER_OP("FFMGrad") .Input("grad: float") .Input("left: float") .Input("right: float") .Output("left_grad: float") .Output("right_grad: float") .Attr("dim_size: int") .Attr("int_type: string") .SetShapeFn([](shape_inference::InferenceContext *ctx) { ctx->set_output(0, ctx->input(1)); ctx->set_output(1, ctx->input(2)); return Status::OK(); }); } // namespace tensorflow ================================================ FILE: monolith/native_training/layers/ops/fid_counter_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { REGISTER_OP("MonolithFidCounter") .Input("counter: float") .Output("output: float") .Attr("step: float") .Attr("counter_threshold: int") .SetShapeFn([](shape_inference::InferenceContext *ctx) { ctx->set_output(0, ctx->input(0)); return Status::OK(); }); } // namespace tensorflow ================================================ FILE: monolith/native_training/layers/ops/nas_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

// BernoulliGate: `alpha` -> (`sampled`, `proba`).
REGISTER_OP("BernoulliGate")
    .Input("alpha: float")
    .Output("sampled: float")
    .Output("proba: float")
    .Attr("ste_type: string")
    .Attr("use_logistic: bool")
    .Attr("temperature: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      // Every declared output needs a shape. `proba` is given alpha's shape,
      // mirroring DiscreteGate below.
      // NOTE(review): assumes the kernel emits `proba` with alpha's shape --
      // confirm against the BernoulliGate kernel.
      ctx->set_output(1, ctx->input(0));
      return Status::OK();
    });

// BernoulliGateGrad: output takes the shape of `grad`.
REGISTER_OP("BernoulliGateGrad")
    .Input("grad: float")
    .Input("alpha: float")
    .Input("proba: float")
    .Output("output: float")
    .Attr("ste_type: string")
    .Attr("use_logistic: bool")
    .Attr("temperature: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// DiscreteGate: `alpha` -> (`sampled`, `proba`), both with alpha's shape.
REGISTER_OP("DiscreteGate")
    .Input("alpha: float")
    .Output("sampled: float")
    .Output("proba: float")
    .Attr("is_one_hot: bool")
    .Attr("use_gumbel: bool")
    .Attr("temperature: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      ctx->set_output(1, ctx->input(0));
      return Status::OK();
    });

// DiscreteGateGrad: output takes the shape of `grad`.
REGISTER_OP("DiscreteGateGrad")
    .Input("grad: float")
    .Input("sampled: float")
    .Input("proba: float")
    .Output("output: float")
    .Attr("is_one_hot: bool")
    .Attr("temperature: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// DiscreteTruncatedGate: `alpha` -> (`sampled`, `proba`). `proba` has alpha's
// shape. With drop_first_dim, `sampled` is a vector of length alpha.dim(0) - 1;
// otherwise it has alpha's shape.
REGISTER_OP("DiscreteTruncatedGate")
    .Input("alpha: float")
    .Output("sampled: float")
    .Output("proba: float")
    .Attr("threshold: float")
    .Attr("drop_first_dim: bool")
    .Attr("use_gumbel: bool")
    .Attr("temperature: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      bool drop_first_dim;
      TF_RETURN_IF_ERROR(ctx->GetAttr("drop_first_dim", &drop_first_dim));
      if (drop_first_dim) {
        // Guard the dim-0 lookup: a rank-0 alpha becomes an error Status
        // instead of an out-of-range dimension access.
        shape_inference::ShapeHandle alpha;
        TF_RETURN_IF_ERROR(ctx->WithRankAtLeast(ctx->input(0), 1, &alpha));
        shape_inference::DimensionHandle alpha_dim = ctx->Dim(alpha, 0);
        shape_inference::DimensionHandle sampled_dim;
        // Subtract fails when the result would be negative (alpha_dim == 0);
        // propagate that rather than building a vector from an unset handle.
        TF_RETURN_IF_ERROR(ctx->Subtract(alpha_dim, 1, &sampled_dim));
        ctx->set_output(0, ctx->Vector(sampled_dim));
      } else {
        ctx->set_output(0, ctx->input(0));
      }
      ctx->set_output(1, ctx->input(0));
      return Status::OK();
    });

// DiscreteTruncatedGateGrad: output takes the shape of `proba`.
REGISTER_OP("DiscreteTruncatedGateGrad")
    .Input("grad: float")
    .Input("sampled: float")
    .Input("proba: float")
    .Output("output: float")
    .Attr("threshold: float")
    .Attr("drop_first_dim: bool")
    .Attr("temperature: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(2));
      return Status::OK();
    });

// NasArchWeight: output takes the shape of `arch_weight`.
REGISTER_OP("NasArchWeight")
    .Input("arch_weight: float")
    .Input("global_step: float")
    .Output("arch_output: float")
    .Attr("update_rate: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

// BinaryMaskToSlotWeight: output is a vector of unknown length.
REGISTER_OP("BinaryMaskToSlotWeight")
    .Input("weight: float")
    .Input("mask: int32")
    .Input("shape: int32")
    .Output("output: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->Vector(ctx->UnknownDim()));
      return Status::OK();
    });

// BinaryMaskToSlotWeightGrad: output is a vector of unknown length.
REGISTER_OP("BinaryMaskToSlotWeightGrad")
    .Input("grad: float")
    .Input("mask: int32")
    .Input("shape: int32")
    .Output("output: float")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->Vector(ctx->UnknownDim()));
      return Status::OK();
    });

// SegmentDiscreteGate: output takes the shape of `proba`.
REGISTER_OP("SegmentDiscreteGate")
    .Input("proba: float")
    .Input("segment_sizes: int32")
    .Output("output: float")
    .Attr("use_gumbel: bool")
    .SetShapeFn([](shape_inference::InferenceContext *ctx) {
      ctx->set_output(0, ctx->input(0));
      return Status::OK();
    });

}  // namespace tensorflow

================================================
FILE: monolith/native_training/layers/pooling.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow.keras.layers import Layer
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops

from monolith.native_training.utils import with_params, check_list
from monolith.native_training.monolith_export import monolith_export


@monolith_export
class Pooling(Layer):
  """Base class for pooling layers

  Args:
    kwargs (:obj:`dict`): Other arguments, see `TF Layer`_ for details

  .. _TF Layer: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer

  """

  def __init__(self, **kwargs):
    super(Pooling, self).__init__(**kwargs)

  def pool(self, vec_list):
    """Combines a list of tensors into one; implemented by subclasses."""
    raise NotImplementedError

  def call(self, vec_list, **kwargs):
    """Pools ``vec_list``; a single-element list is returned unpooled."""
    # NOTE(review): presumably validates that the list is non-empty -- confirm
    # against check_list.
    check_list(vec_list, lambda x: x > 0)
    if len(vec_list) == 1:
      return vec_list[0]
    return self.pool(vec_list)


@monolith_export
@with_params
class SumPooling(Pooling):
  """Sum pooling

  Args:
    kwargs (:obj:`dict`): Other arguments, see `TF Layer`_ for details

  .. _TF Layer: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer

  """

  def __init__(self, **kwargs):
    super(SumPooling, self).__init__(**kwargs)

  def pool(self, vec_list):
    """Returns the element-wise sum of the tensors."""
    return math_ops.add_n(vec_list)


@monolith_export
@with_params
class AvgPooling(Pooling):
  """Average pooling

  Args:
    kwargs (:obj:`dict`): Other arguments, see `TF Layer`_ for details

  .. _TF Layer: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer

  """

  def __init__(self, **kwargs):
    super(AvgPooling, self).__init__(**kwargs)

  def pool(self, vec_list):
    """Returns the element-wise sum divided by the number of tensors."""
    return math_ops.add_n(vec_list) / len(vec_list)


@monolith_export
@with_params
class MaxPooling(Pooling):
  """Max pooling

  Args:
    kwargs (:obj:`dict`): Other arguments, see `TF Layer`_ for details

  .. _TF Layer: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer

  """

  def __init__(self, **kwargs):
    super(MaxPooling, self).__init__(**kwargs)

  def pool(self, vec_list):
    """Stacks the tensors on a new axis 0 and takes the maximum over it."""
    return math_ops.reduce_max(array_ops.stack(vec_list), axis=0)

================================================
FILE: monolith/native_training/layers/pooling_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import tensorflow as tf

from monolith.native_training.layers.pooling import *


class PoolingTest(tf.test.TestCase):
  """Smoke tests for Sum/Max/Avg pooling: construction, config round-trip, call."""

  def test_sp_instantiate(self):
    """SumPooling can be built from its params template and directly."""
    layer_template = SumPooling.params()
    test_params0 = layer_template.copy()
    ins1 = test_params0.instantiate()
    print(ins1)

    ins2 = SumPooling()
    print(ins2)

  def test_sp_serde(self):
    """SumPooling can be rebuilt from the dict returned by get_config."""
    layer_template = SumPooling.params()
    test_params0 = layer_template.copy()
    ins1 = test_params0.instantiate()
    print(ins1)

    cfg = ins1.get_config()
    ins2 = SumPooling.from_config(cfg)
    print(ins1, ins2)

  def test_sp_call(self):
    """SumPooling runs in a session on five random 100x10 inputs."""
    layer_template = SumPooling.params()
    test_params0 = layer_template.copy()
    test_params0.name = 'test_dense0'
    layer = test_params0.instantiate()

    data = [
        tf.keras.backend.variable(np.random.uniform(size=(100, 10)))
        for _ in range(5)
    ]
    sum_out = tf.reduce_sum(layer(data))
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(sum_out))

  def test_mp_instantiate(self):
    """MaxPooling can be built from its params template and directly."""
    layer_template = MaxPooling.params()
    test_params0 = layer_template.copy()
    ins1 = test_params0.instantiate()
    print(ins1)

    ins2 = MaxPooling()
    print(ins2)

  def test_mp_serde(self):
    """MaxPooling can be rebuilt from the dict returned by get_config."""
    layer_template = MaxPooling.params()
    test_params0 = layer_template.copy()
    ins1 = test_params0.instantiate()
    print(ins1)

    cfg = ins1.get_config()
    ins2 = MaxPooling.from_config(cfg)
    print(ins1, ins2)

  def test_mp_call(self):
    """MaxPooling runs in a session on five random 100x10 inputs."""
    layer_template = MaxPooling.params()
    test_params0 = layer_template.copy()
    test_params0.name = 'test_dense0'
    layer = test_params0.instantiate()

    data = [
        tf.keras.backend.variable(np.random.uniform(size=(100, 10)))
        for _ in range(5)
    ]
    sum_out = tf.reduce_sum(layer(data))
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(sum_out))

  def test_ap_instantiate(self):
    """AvgPooling can be built from its params template and directly."""
    layer_template = AvgPooling.params()
    test_params0 = layer_template.copy()
    ins1 = test_params0.instantiate()
    print(ins1)

    ins2 = AvgPooling()
    print(ins2)

  def test_ap_serde(self):
    """AvgPooling can be rebuilt from the dict returned by get_config."""
    layer_template = AvgPooling.params()
    test_params0 = layer_template.copy()
    ins1 = test_params0.instantiate()
    print(ins1)

    cfg = ins1.get_config()
    ins2 = AvgPooling.from_config(cfg)
    print(ins1, ins2)

  def test_ap_call(self):
    """AvgPooling runs in a session on five random 100x10 inputs."""
    layer_template = AvgPooling.params()
    test_params0 = layer_template.copy()
    test_params0.name = 'test_dense0'
    layer = test_params0.instantiate()

    data = [
        tf.keras.backend.variable(np.random.uniform(size=(100, 10)))
        for _ in range(5)
    ]
    sum_out = tf.reduce_sum(layer(data))
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      print(sess.run(sum_out))


if __name__ == '__main__':
  # The tests use sessions, so eager execution is disabled first.
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

================================================
FILE: monolith/native_training/layers/sparse_nas.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from glob import glob import os, re from absl import logging, flags import numpy as np from typing import List import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.keras.layers import Layer from tensorflow.python.keras import initializers from tensorflow.python.keras.engine.input_spec import InputSpec from monolith.native_training.utils import with_params from monolith.native_training.monolith_export import monolith_export from monolith.native_training.layers.utils import check_dim, dim_size from monolith.native_training.data.feature_list import FeatureList from monolith.native_training.summary.utils import SummaryType ================================================ FILE: monolith/native_training/layers/sparse_nas_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import numpy as np import tensorflow as tf if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/layers/utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import tensorflow as tf from tensorflow.python.ops import array_ops from monolith.native_training.monolith_export import monolith_export @monolith_export class MergeType: CONCAT = 'concat' STACK = 'stack' NONE = None @monolith_export class DCNType: Vector = "vector" Matrix = "matrix" Mixed = "mixed" def check_dim(dim): if dim is None: return -1 elif isinstance(dim, int): return dim elif isinstance(dim, tf.compat.v1.Dimension): return dim.value else: raise Exception(f'dim {dim} is error') def dim_size(inputs, axis: int): shape = inputs.get_shape().as_list() assert len(shape) > axis dim = check_dim(shape[axis]) if dim == -1: return array_ops.shape(inputs)[axis] else: return dim @monolith_export def merge_tensor_list(tensor_list, merge_type: str = 'concat', num_feature: int = None, axis: int = 1, keep_list: bool = False): """将Tensor列表合并 Args: tensor_list (:obj:`List[tf.Tensor]`): 输入的Tensor列表 merge_type (:obj:`str`): 合并类型, 支持stack/concat两种, 如果设为None, 则不做任何处理 num_feature (:obj:`int`): 特征个数 axis (:obj:`int`): merge延哪个轴进行 keep_list (:obj:`bool`): 输出结果是否保持list """ if isinstance(tensor_list, 
tf.Tensor): tensor_list = [tensor_list] assert merge_type in {'stack', 'concat', None} if len(tensor_list) == 1: shapes = [check_dim(dim) for dim in tensor_list[0].get_shape().as_list()] if len(shapes) == 3: (batch_size, num_feat, emb_size) = shapes if merge_type == MergeType.STACK: output = tensor_list if keep_list else tensor_list[0] elif merge_type == MergeType.CONCAT: tensor_list[0] = tf.reshape(tensor_list[0], shape=(batch_size, num_feat * emb_size)) output = tensor_list if keep_list else tensor_list[0] else: output = tf.unstack(tensor_list[0], axis=axis) elif len(shapes) == 2 and num_feature is not None and num_feature > 1: (batch_size, emb_size) = shapes emb_size = int(emb_size / num_feature) if merge_type == MergeType.STACK: tensor_list[0] = tf.reshape(tensor_list[0], shape=(batch_size, num_feature, emb_size)) output = tensor_list if keep_list else tensor_list[0] elif merge_type == MergeType.CONCAT: output = tensor_list if keep_list else tensor_list[0] else: tensor_list[0] = tf.reshape(tensor_list[0], shape=(batch_size, num_feature, emb_size)) output = tf.unstack(tensor_list[0], axis=axis) elif len(shapes) == 2: output = tensor_list if keep_list else tensor_list[0] else: raise Exception("shape error: ({})".format(', '.join(map(str, shapes)))) elif merge_type == 'stack': stacked = tf.stack(tensor_list, axis=axis) output = [stacked] if keep_list else stacked elif merge_type == 'concat': concated = tf.concat(tensor_list, axis=axis) output = [concated] if keep_list else concated else: output = tensor_list return output EPSILON = np.finfo(tf.float32.as_numpy_dtype).tiny # Copy from: https://github.com/ermongroup/subsets/blob/master/subsets/sample_subsets.py # Reparameterizable Subset Sampling via Continuous Relaxations (https://arxiv.org/pdf/1901.10517.pdf) # [code](https://github.com/ermongroup/subsets) def gumbel_keys(w): # sample some gumbels uniform = tf.random_uniform( tf.shape(w), minval=EPSILON, maxval=1.0) z = -tf.log(-tf.log(uniform)) w = w + z return w 
def continuous_topk(w, k, t, separate=False): khot_list = [] onehot_approx = tf.zeros_like(w, dtype=tf.float32) for i in range(k): khot_mask = tf.maximum(1.0 - onehot_approx, EPSILON) w += tf.log(khot_mask) onehot_approx = tf.nn.softmax(w / t, axis=-1) khot_list.append(onehot_approx) if separate: return khot_list else: return tf.reduce_sum(khot_list, 0) def sample_subset(w, k, t=0.1): ''' Args: w (Tensor): Float Tensor of weights for each element. In gumbel mode these are interpreted as log probabilities k (int): number of elements in the subset sample t (float): temperature of the softmax ''' w = gumbel_keys(w) return continuous_topk(w, k, t) ================================================ FILE: monolith/native_training/learning_rate_functions.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import abc import tensorflow as tf class LearningRateFunction(): """The learning rate function base class. You can use a learning rate function to modulate how the learning rate of your optimizer changes over time. A `LearningRateFunction` instance can be passed in as the `learning_rate` argument of any dense optimizer or as `learning_rate_fn` argument for adding feature slice of embedding table. To implement your own function object, you should implement the `__call__` method. """ @abc.abstractmethod def __call__(self): raise NotImplementedError("Learning rate function must override __call__") # Used to check whether two LearningRateFunctions have the same feature. def __str__(self): return "LearningRateFunction(\"%s\",Params:%s)" % ( self.__class__.__name__, ",".join([ "%s=%s" % (key, self.__dict__[key]) for key in sorted(self.__dict__) ])) class PolynomialDecay(LearningRateFunction): """A LearningRateFunction that uses an polynomial decay schedule. This function applies a polynomial decay function to a provided `initial_learning_rate` to reach an `end_learning_rate` in the given `decay_steps`. """ def __init__(self, initial_learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False, name=None): """Applies polynomial decay to the learning rate. Args: initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. The initial learning rate. decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must be positive. See the decay computation above. end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number. The minimal end learning rate. power: A scalar `float32` or `float64` `Tensor` or a Python number. The power of the polynomial. Defaults to linear, 1.0. 
cycle: A boolean, whether or not it should cycle beyond decay_steps. name: String. Optional name of the operation. Defaults to 'PolynomialDecay'. Returns: A scalar `Tensor` of the same type as `initial_learning_rate`. The decayed learning rate. """ super(PolynomialDecay, self).__init__() self.initial_learning_rate = initial_learning_rate self.decay_steps = decay_steps self.end_learning_rate = end_learning_rate self.power = power self.cycle = cycle self.name = name def __call__(self): global_step = tf.compat.v1.train.get_or_create_global_step() return tf.compat.v1.train.polynomial_decay( learning_rate=self.initial_learning_rate, global_step=global_step, decay_steps=self.decay_steps, end_learning_rate=self.end_learning_rate, power=self.power, cycle=self.cycle) ================================================ FILE: monolith/native_training/learning_rate_functions_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np import tensorflow as tf from monolith.native_training import learning_rate_functions class PolynomialDecayTest(tf.test.TestCase): def test_basic(self): with tf.compat.v1.Session() as sess: global_step = tf.compat.v1.train.get_or_create_global_step() self.evaluate(tf.compat.v1.global_variables_initializer()) self.evaluate(tf.compat.v1.assign_add(global_step, 1)) learning_rate_fn = learning_rate_functions.PolynomialDecay( initial_learning_rate=0.01, decay_steps=10, end_learning_rate=0.11) learning_rate = self.evaluate(learning_rate_fn()) self.assertAllClose(learning_rate, 0.02, 1e-6) self.evaluate(tf.compat.v1.assign_add(global_step, 1)) learning_rate = self.evaluate(learning_rate_fn()) self.assertAllClose(learning_rate, 0.03, 1e-6) learning_rate_fn2 = learning_rate_functions.PolynomialDecay( initial_learning_rate=0.01, decay_steps=10, end_learning_rate=0.11) self.assertEqual(str(learning_rate_fn), str(learning_rate_fn2)) def test_dense_optimizer(self): with tf.compat.v1.Session() as sess: global_step = tf.compat.v1.train.get_or_create_global_step() learning_rate_fn = learning_rate_functions.PolynomialDecay( initial_learning_rate=3.0, decay_steps=10, end_learning_rate=11.0) var0 = tf.Variable([1.0, 2.0], dtype=tf.float32) var1 = tf.Variable([3.0, 4.0], dtype=tf.float32) grads0 = tf.constant([0.1, 0.1], dtype=tf.float32) grads1 = tf.constant([0.01, 0.01], dtype=tf.float32) ada_opt = tf.compat.v1.train.AdagradOptimizer( learning_rate_fn, initial_accumulator_value=0.1) ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) self.evaluate(tf.compat.v1.global_variables_initializer()) # Fetch params to validate initial values v0_val, v1_val = self.evaluate([var0, var1]) self.assertAllClose([1.0, 2.0], v0_val) self.assertAllClose([3.0, 4.0], v1_val) # Run 3 steps of adagrad for _ in range(3): self.evaluate(ada_update) # Validate updated params v0_val, v1_val = self.evaluate([var0, var1]) self.assertAllCloseAccordingToType( 
np.array([-1.6026098728179932, -0.6026098728179932]), v0_val) self.assertAllCloseAccordingToType( np.array([2.715679168701172, 3.715679168701172]), v1_val) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/logging_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, List, Callable, Tuple import tensorflow as tf from absl import flags from monolith.native_training.runtime.ops import gen_monolith_ops flags.DEFINE_integer( "monolith_default_machine_info_mem_limit", 1 << 62, "The default value for mem_limit in machine info. 
(Bytes)") FLAGS = flags.FLAGS logging_ops = gen_monolith_ops def tensors_timestamp( tensors: List[tf.Tensor]) -> Tuple[List[tf.Tensor], tf.Tensor]: """Gets the timestamp when the tensors are ready.""" return logging_ops.monolith_tensors_timestamp(tensors) def emit_timer(key: str, value: tf.Tensor, tags: Dict[str, str] = None) -> tf.Operation: tags = tags or {} tag_str = "|".join([f"{k}={v}" for k, v in tags.items()]) return logging_ops.monolith_metric_v2(value, key=key, tags=tag_str) def machine_info(mem_limit=None, shared_name=None) -> tf.Tensor: """Returns a MachineInfo tensor which contains a MachineInfo resource.""" if mem_limit is None: mem_limit = FLAGS.monolith_default_machine_info_mem_limit return logging_ops.monolith_machine_info(mem_limit=mem_limit, name=shared_name, shared_name=shared_name) def check_machine_health(machine_info_tensor: tf.Tensor) -> tf.Tensor: """Returns a scalar string tensor, which is serialized version of MachineHealthResult.""" return logging_ops.monolith_check_machine_health(machine_info_tensor) ================================================ FILE: monolith/native_training/logging_ops_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import tensorflow as tf from absl import flags from monolith.native_training import logging_ops from monolith.native_training.runtime.ops import logging_ops_pb2 FLAGS = flags.FLAGS class LoggingOpsTest(tf.test.TestCase): def test_tensors_timestamp(self): tensor = [tf.constant(0)] tensor, ts = logging_ops.tensors_timestamp(tensor) tensor, new_ts = logging_ops.tensors_timestamp(tensor) with self.session() as sess: new_ts_value, ts_value = sess.run([new_ts, ts]) self.assertGreaterEqual(new_ts_value, ts_value) def test_emit_timer(self): op = logging_ops.emit_timer("test", 0.0) self.evaluate(op) def test_machine_health(self): FLAGS.monolith_default_machine_info_mem_limit = 1 << 62 info = logging_ops.machine_info() self.assertEqual(self.evaluate(logging_ops.check_machine_health(info)), b"") def test_machine_health_oom(self): FLAGS.monolith_default_machine_info_mem_limit = 0 info = logging_ops.machine_info() serialized_result = self.evaluate(logging_ops.check_machine_health(info)) result = logging_ops_pb2.MachineHealthResult() result.ParseFromString(serialized_result) self.assertEqual(result.status, logging_ops_pb2.MachineHealthResult.OUT_OF_MEMORY) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/losses/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") package( default_visibility = ["//visibility:public"], ) py_library( name = "batch_softmax_loss", srcs = ["batch_softmax_loss.py"], srcs_version = "PY3", deps = [ "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_test( name = "batch_softmax_loss_test", srcs = ["batch_softmax_loss_test.py"], srcs_version = "PY3", deps = [ ":batch_softmax_loss", ], ) py_binary( name = "inbatch_auc_loss", srcs = ["inbatch_auc_loss.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "inbatch_auc_loss_test", 
srcs = ["inbatch_auc_loss_test.py"], deps = [ ":inbatch_auc_loss", ], ) py_binary( name = "ltr_losses", srcs = ["ltr_losses.py"], deps = [ "//monolith:utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_library( name = "losses", srcs = [], srcs_version = "PY3", deps = [ ":batch_softmax_loss", ":inbatch_auc_loss", ":ltr_losses", ], ) ================================================ FILE: monolith/native_training/losses/batch_softmax_loss.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import tensorflow as tf def batch_softmax_loss(query: tf.Tensor, item: tf.Tensor, item_step_interval: tf.Tensor, r: tf.Tensor, normalize: bool = True, temperature: float = 1.0) -> tf.Tensor: """ Batch Softmax Loss Args: query (:obj:`tf.Tensor`): query 向量, shape=(batch_size, k) item (:obj:`tf.Tensor`): item 向量, shape=(batch_size, k) item_step_interval (:obj:`tf.Tensor`): item 出现的平均 global step 间隔, shape=(batch_size,) r (:obj:`tf.Tensor`): query 对 item 感兴趣程度权重 normalize (:obj:`bool`): 是否对 query/item 向量归一化 temperature (:obj:`float`): hyper-parameter tuned to maximize retrieval metrics such as recall or precision """ if temperature <= 0: raise ValueError( "temperature should be positive, while got {}".format(temperature)) if normalize: query = tf.linalg.l2_normalize(query, axis=1) item = tf.linalg.l2_normalize(item, axis=1) # (batch_size, batch_size) similarity = tf.matmul(query, item, transpose_b=True) / temperature # The first looked-up item_step_interval is filled by zeros item_step_interval = tf.math.maximum(item_step_interval, tf.constant([1.0], dtype=tf.float32)) item_frequency = 1 / item_step_interval similarity = tf.math.exp(similarity - tf.math.log(item_frequency)) loss = -tf.reduce_sum( tf.multiply( r, tf.math.log( tf.linalg.tensor_diag_part(similarity) / tf.reduce_sum(similarity, axis=1)))) return loss ================================================ FILE: monolith/native_training/losses/batch_softmax_loss_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import numpy as np import tensorflow as tf from monolith.native_training.losses.batch_softmax_loss import batch_softmax_loss class BatchSoftmaxLossTest(tf.test.TestCase): def test_batch_softmax_loss(self): batch_size, dim = 4, 3 query = tf.constant(np.random.random([batch_size, dim]), dtype=tf.float32) item = tf.constant(np.random.random([batch_size, dim]), dtype=tf.float32) item_step_interval = tf.constant( [np.random.randint(1, 10) for _ in range(batch_size)], dtype=tf.float32) r = tf.ones((batch_size,), dtype=tf.float32) loss = batch_softmax_loss(query, item, item_step_interval, r) self.assertAllClose([loss], [6.5931373]) if __name__ == '__main__': tf.test.main() ================================================ FILE: monolith/native_training/losses/inbatch_auc_loss.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os
from typing import List, Tuple, Optional, NamedTuple

import tensorflow as tf

from monolith.native_training.runtime.ops import gen_monolith_ops

# Alias for the generated custom-op module that provides the loss kernels.
inbatch_auc_loss_ops = gen_monolith_ops


def inbatch_auc_loss(label: tf.Tensor,
                     logit: tf.Tensor,
                     neg_weight=1.0) -> tf.Tensor:
  """Thin wrapper around the custom `inbatch_auc_loss` op.

  The accompanying unit test compares the result (with `neg_weight=1.0`)
  against the sum, over every (positive, negative) example pair in the batch,
  of `log(sigmoid(logit_pos - logit_neg))`, where an example counts as
  positive when its label is > 0.

  Args:
    label: Per-example labels.
    logit: Per-example logits.
    neg_weight: Forwarded unchanged to the op as its `neg_weight` attribute.
      NOTE(review): exact semantics live in the op kernel, which is not
      visible here.

  Returns:
    The tensor produced by the op.
  """
  return inbatch_auc_loss_ops.inbatch_auc_loss(label=label,
                                               logit=logit,
                                               neg_weight=neg_weight)


@tf.RegisterGradient(op_type='InbatchAucLoss')
def _inbatch_auc_loss_grad(op: tf.Operation, grad: tf.Tensor):
  """Gradient function registered for the `InbatchAucLoss` op.

  Args:
    op: The forward op; inputs 0 and 1 are the label and logit tensors.
    grad: Incoming gradient with respect to the op's output.

  Returns:
    A pair `(None, logit_grad)`: no gradient is propagated to the label
    input, and the logit gradient is computed by the `inbatch_auc_loss_grad`
    op.
  """
  label, logit = op.inputs[0], op.inputs[1]
  # Reuse the attribute the forward op was built with.
  neg_weight = op.get_attr(name='neg_weight')
  logit_grad = inbatch_auc_loss_ops.inbatch_auc_loss_grad(
      label=label, logit=logit, grad=grad, neg_weight=neg_weight)
  return None, logit_grad


================================================
FILE: monolith/native_training/losses/inbatch_auc_loss_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os from math import log, exp import tensorflow as tf from monolith.native_training.losses import inbatch_auc_loss class InbatchAucLossTest(tf.test.TestCase): def test_inbatch_auc_loss(self): label = [1, 0, 0, 1] logit = [0.5, -0.2, -0.4, 0.8] loss = inbatch_auc_loss.inbatch_auc_loss(label=label, logit=logit) loss_truth = 0 pos, neg = [], [] for i, l in enumerate(label): if l > 0: pos.append(i) else: neg.append(i) for i in pos: for j in neg: diff = logit[i] - logit[j] loss_truth += log(1 / (1 + exp(-diff))) self.assertAlmostEqual(loss, tf.constant(loss_truth), delta=0.000001) def test_inbatch_auc_loss_grad(self): label = [1, 0, 0, 1] logit = [0.5, -0.2, -0.4, 0.8] logit_grad = inbatch_auc_loss.inbatch_auc_loss_ops.inbatch_auc_loss_grad( label=label, logit=logit, grad=2, neg_weight=1.0) pos, neg = [], [] for i, l in enumerate(label): if l > 0: pos.append(i) else: neg.append(i) logit_grad_truth = [0] * len(logit) for i in pos: for j in neg: diff = logit[i] - logit[j] grad_ij = 1 - 1 / (1 + exp(-diff)) logit_grad_truth[i] += grad_ij logit_grad_truth[j] -= grad_ij logit_grad_truth = [2 * x for x in logit_grad_truth] self.assertAllClose(logit_grad, tf.constant(logit_grad_truth)) if __name__ == "__main__": # tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/losses/ltr_losses.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import abc import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses as core_losses # The smallest probability that is used to derive smallest logit for invalid or # padding entries. _EPSILON = 1e-10 def label_valid_fn(labels): """Returns a boolean `Tensor` for label validity.""" labels = ops.convert_to_tensor(labels) return math_ops.greater_equal(labels, 0.) def sort_by_scores(scores, features_list, topn=None): """Sorts example features according to per-example scores. Args: scores: A `Tensor` of shape [batch_size, list_size] representing the per-example scores. features_list: A list of `Tensor`s with the same shape as scores to be sorted. topn: An integer as the cutoff of examples in the sorted list. 
Returns: A list of `Tensor`s as the list of sorted features by `scores`. """ scores = ops.convert_to_tensor(scores) scores.get_shape().assert_has_rank(2) batch_size, list_size = array_ops.unstack(array_ops.shape(scores)) if topn is None: topn = list_size topn = math_ops.minimum(topn, list_size) _, indices = nn_ops.top_k(scores, topn, sorted=True) list_offsets = array_ops.expand_dims( math_ops.range(batch_size) * list_size, 1) # The shape of `indices` is [batch_size, topn] and the shape of # `list_offsets` is [batch_size, 1]. Broadcasting is used here. gather_indices = array_ops.reshape(indices + list_offsets, [-1]) output_shape = array_ops.stack([batch_size, topn]) # Each feature is first flattened to a 1-D vector and then gathered by the # indices from sorted scores and then re-shaped. return [ array_ops.reshape( array_ops.gather(array_ops.reshape(feature, [-1]), gather_indices), output_shape) for feature in features_list ] def organize_valid_indices(is_valid, shuffle=True, seed=None): """Organizes indices in such a way that valid items appear first. Args: is_valid: A boolen `Tensor` for entry validity with shape [batch_size, list_size]. shuffle: A boolean indicating whether valid items should be shuffled. seed: An int for random seed at the op level. It works together with the seed at global graph level together to determine the random number generation. See `tf.set_random_seed`. Returns: A tensor of indices with shape [batch_size, list_size, 2]. The returned tensor can be used with `tf.gather_nd` and `tf.scatter_nd` to compose a new [batch_size, list_size] tensor. The values in the last dimension are the indices for an element in the input tensor. 
""" is_valid = ops.convert_to_tensor(is_valid) is_valid.get_shape().assert_has_rank(2) output_shape = array_ops.shape(is_valid) if shuffle: values = random_ops.random_uniform(output_shape, seed=seed) else: values = (array_ops.ones_like(is_valid, dtypes.float32) * array_ops.reverse( math_ops.to_float(math_ops.range(output_shape[1])), [-1])) rand = array_ops.where(is_valid, values, array_ops.ones(output_shape) * -1e-6) # shape(indices) = [batch_size, list_size] _, indices = nn_ops.top_k(rand, output_shape[1], sorted=True) # shape(batch_ids) = [batch_size, list_size] batch_ids = array_ops.ones_like(indices) * array_ops.expand_dims( math_ops.range(output_shape[0]), 1) return array_ops.concat( [ array_ops.expand_dims(batch_ids, 2), #[[0,...0], [1, ..., 1]] array_ops.expand_dims(indices, 2) ], # shuffle之后的indices,dim0=batch dim1=shuffle后的index,例如[0, 1, 2, 3] 变为[2,3,0,1],为list的长度 axis=2) def shuffle_valid_indices(is_valid, seed=None): """Returns a shuffle of indices with valid ones on top.""" return organize_valid_indices(is_valid, shuffle=True, seed=seed) def reshape_first_ndims(tensor, first_ndims, new_shape): """Reshapes the first n dims of the input `tensor` to `new shape`. Args: tensor: The input `Tensor`. first_ndims: A int denoting the first n dims. new_shape: A list of int representing the new shape. Returns: A reshaped `Tensor`. """ assert tensor.get_shape().ndims is None or tensor.get_shape( ).ndims >= first_ndims, ( 'Tensor shape is less than {} dims.'.format(first_ndims)) new_shape = array_ops.concat( [new_shape, array_ops.shape(tensor)[first_ndims:]], 0) if isinstance(tensor, sparse_tensor.SparseTensor): return sparse_ops.sparse_reshape(tensor, new_shape) return array_ops.reshape(tensor, new_shape) def approx_ranks(logits, alpha=10.): r"""Computes approximate ranks given a list of logits. Given a list of logits, the rank of an item in the list is simply one plus the total number of items with a larger logit. 
In other words, rank_i = 1 + \sum_{j \neq i} I_{s_j > s_i}, where "I" is the indicator function. The indicator function can be approximated by a generalized sigmoid: I_{s_j < s_i} \approx 1/(1 + exp(-\alpha * (s_j - s_i))). This function approximates the rank of an item using this sigmoid approximation to the indicator function. This technique is at the core of "A general approximation framework for direct optimization of information retrieval measures" by Qin et al. Args: logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. alpha: Exponent of the generalized sigmoid function. Returns: A `Tensor` of ranks with the same shape as logits. """ list_size = array_ops.shape(logits)[1] x = array_ops.tile(array_ops.expand_dims(logits, 2), [1, 1, list_size]) y = array_ops.tile(array_ops.expand_dims(logits, 1), [1, list_size, 1]) pairs = math_ops.sigmoid(alpha * (y - x)) return math_ops.reduce_sum(pairs, -1) + .5 def inverse_max_dcg(labels, gain_fn=lambda labels: math_ops.pow(2.0, labels) - 1., rank_discount_fn=lambda rank: 1. / math_ops.log1p(rank), topn=None): """Computes the inverse of max DCG. Args: labels: A `Tensor` with shape [batch_size, list_size]. Each value is the graded relevance of the corresponding item. gain_fn: A gain function. By default this is set to: 2^label - 1. rank_discount_fn: A discount function. By default this is set to: 1/log(1+rank). topn: An integer as the cutoff of examples in the sorted list. Returns: A `Tensor` with shape [batch_size, 1]. """ ideal_sorted_labels, = sort_by_scores(labels, [labels], topn=topn) rank = math_ops.range(array_ops.shape(ideal_sorted_labels)[1]) + 1 discounted_gain = gain_fn(ideal_sorted_labels) * rank_discount_fn( math_ops.to_float(rank)) discounted_gain = math_ops.reduce_sum(discounted_gain, 1, keepdims=True) return array_ops.where(math_ops.greater(discounted_gain, 0.), 1. 
/ discounted_gain, array_ops.zeros_like(discounted_gain)) def get_batch_idx_size(logits, labels, rank_id, name_prefix): batch_size = tf.shape(logits)[0] rank_key, rank_idx, count = tf.unique_with_counts(rank_id) max_count = tf.reduce_max(count) unique_rank_id_num = tf.shape(count)[0] rank_idx_tile = tf.tile(tf.expand_dims(rank_idx, 0), [unique_rank_id_num, 1]) range_count_tile = tf.tile(tf.expand_dims(tf.range(unique_rank_id_num), 1), [1, batch_size]) list_id_mask = tf.cast(tf.equal(rank_idx_tile, range_count_tile), tf.int32) cum_mask = tf.cumsum(list_id_mask, axis=1, exclusive=True) masked_list_id = list_id_mask * cum_mask col_cor = tf.reduce_sum(masked_list_id, axis=0) row_cor = rank_idx batch_idx = tf.concat( [tf.expand_dims(row_cor, 1), tf.expand_dims(col_cor, 1)], axis=1) output_shape = [tf.shape(count)[0], max_count] logits_idx = tf.scatter_nd(updates=logits, indices=batch_idx, shape=output_shape, name=name_prefix + "logits_idx") label_idx = tf.scatter_nd(updates=labels, indices=batch_idx, shape=output_shape, name=name_prefix + "label_idx") - 1e-6 mask_idx = tf.scatter_nd(updates=tf.ones_like(logits), indices=batch_idx, shape=[batch_size, batch_size], name=name_prefix + 'mask_idx') unique_idx = tf.argmax(list_id_mask, axis=1) return logits_idx, label_idx, mask_idx, unique_idx class RankingLossKey(object): """Ranking loss key strings.""" # Names for the ranking based loss functions. 
PAIRWISE_HINGE_LOSS = 'pairwise_hinge_loss' PAIRWISE_LOGISTIC_LOSS = 'pairwise_logistic_loss' PAIRWISE_SOFT_ZERO_ONE_LOSS = 'pairwise_soft_zero_one_loss' SOFTMAX_LOSS = 'softmax_loss' SIGMOID_CROSS_ENTROPY_LOSS = 'sigmoid_cross_entropy_loss' MEAN_SQUARED_LOSS = 'mean_squared_loss' LIST_MLE_LOSS = 'list_mle_loss' APPROX_NDCG_LOSS = 'approx_ndcg_loss' def make_loss_fn(loss_keys, loss_weights=None, weights_feature_name=None, lambda_weight=None, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None, seed=None, extra_args=None): """Makes a loss function using a single loss or multiple losses. Args: loss_keys: A string or list of strings representing loss keys defined in `RankingLossKey`. Listed loss functions will be combined in a weighted manner, with weights specified by `loss_weights`. If `loss_weights` is None, default weight of 1 will be used. loss_weights: List of weights, same length as `loss_keys`. Used when merging losses to calculate the weighted sum of losses. If `None`, all losses are weighted equally with weight being 1. weights_feature_name: A string specifying the name of the weights feature in `features` dict. lambda_weight: A `_LambdaWeight` object created by factory methods like `create_ndcg_lambda_weight()`. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. seed: A randomization seed used in computation of some loss functions such as ListMLE and pListMLE. extra_args: A string-keyed dictionary that contains any other loss-specific arguments. Returns: A function _loss_fn(). See `_loss_fn()` for its signature. Raises: ValueError: If `reduction` is invalid. ValueError: If `loss_keys` is None or empty. ValueError: If `loss_keys` and `loss_weights` have different sizes. 
""" if (reduction not in core_losses.Reduction.all() or reduction == core_losses.Reduction.NONE): raise ValueError('Invalid reduction: {}'.format(reduction)) if not loss_keys: raise ValueError('loss_keys cannot be None or empty.') if loss_weights: if len(loss_keys) != len(loss_weights): raise ValueError('loss_keys and loss_weights must have the same size.') if not isinstance(loss_keys, list): loss_keys = [loss_keys] def _loss_fn(labels, logits, features): """Computes a single loss or weighted combination of losses. Args: labels: A `Tensor` of the same shape as `logits` representing relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. features: Dict of Tensors of shape [batch_size, list_size, ...] for per-example features and shape [batch_size, ...] for non-example context features. Returns: An op for a single loss or weighted combination of multiple losses. Raises: ValueError: If `loss_keys` is invalid. """ weights = features[weights_feature_name] if weights_feature_name else None loss_kwargs = { 'labels': labels, 'logits': logits, 'weights': weights, 'reduction': reduction, 'name': name, } if extra_args is not None: loss_kwargs.update(extra_args) loss_kwargs_with_lambda_weight = loss_kwargs.copy() loss_kwargs_with_lambda_weight['lambda_weight'] = lambda_weight loss_kwargs_with_lambda_weight_and_seed = ( loss_kwargs_with_lambda_weight.copy()) loss_kwargs_with_lambda_weight_and_seed['seed'] = seed key_to_fn = { RankingLossKey.PAIRWISE_HINGE_LOSS: (_pairwise_hinge_loss, loss_kwargs_with_lambda_weight), RankingLossKey.PAIRWISE_LOGISTIC_LOSS: (_pairwise_logistic_loss, loss_kwargs_with_lambda_weight), RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS: (_pairwise_soft_zero_one_loss, loss_kwargs_with_lambda_weight), RankingLossKey.SOFTMAX_LOSS: (_softmax_loss, loss_kwargs_with_lambda_weight), RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS: (_sigmoid_cross_entropy_loss, loss_kwargs), 
RankingLossKey.MEAN_SQUARED_LOSS: (_mean_squared_loss, loss_kwargs), RankingLossKey.LIST_MLE_LOSS: (_list_mle_loss, loss_kwargs_with_lambda_weight_and_seed), RankingLossKey.APPROX_NDCG_LOSS: (_approx_ndcg_loss, loss_kwargs), } # Obtain the list of loss ops. loss_ops = [] for loss_key in loss_keys: if loss_key not in key_to_fn: raise ValueError('Invalid loss_key: {}.'.format(loss_key)) loss_fn, kwargs = key_to_fn[loss_key] loss_ops.append(loss_fn(**kwargs)) # Compute weighted combination of losses. if loss_weights: weighted_losses = [] for loss_op, loss_weight in zip(loss_ops, loss_weights): weighted_losses.append(math_ops.multiply(loss_op, loss_weight)) else: weighted_losses = loss_ops return math_ops.add_n(weighted_losses) return _loss_fn def create_ndcg_lambda_weight(topn=None, smooth_fraction=0.): """Creates _LambdaWeight for NDCG metric.""" return DCGLambdaWeight( topn, gain_fn=lambda labels: math_ops.pow(2.0, labels) - 1., rank_discount_fn=lambda rank: 1. / math_ops.log1p(rank), normalized=True, smooth_fraction=smooth_fraction) def create_reciprocal_rank_lambda_weight(topn=None, smooth_fraction=0.): """Creates _LambdaWeight for MRR-like metric.""" return DCGLambdaWeight(topn, gain_fn=lambda labels: labels, rank_discount_fn=lambda rank: 1. / rank, normalized=True, smooth_fraction=smooth_fraction) def create_p_list_mle_lambda_weight(list_size): """Creates _LambdaWeight based on Position-Aware ListMLE paper. Produces a weight based on the formulation presented in the "Position-Aware ListMLE" paper (Lan et al.) and available using create_p_list_mle_lambda_weight() factory function above. Args: list_size: Size of the input list. Returns: A _LambdaWeight for Position-Aware ListMLE. """ return ListMLELambdaWeight( rank_discount_fn=lambda rank: math_ops.pow(2., list_size - rank) - 1.) class _LambdaWeight(object): """Interface for ranking metric optimization. 
This class wraps weights used in the LambdaLoss framework for ranking metric optimization (https://ai.google/research/pubs/pub47258). Such an interface is to be instantiated by concrete lambda weight models. The instance is used together with standard loss such as logistic loss and softmax loss. """ __metaclass__ = abc.ABCMeta def _get_valid_pairs_and_clean_labels(self, sorted_labels): """Returns a boolean Tensor for valid pairs and cleaned labels.""" sorted_labels = ops.convert_to_tensor(sorted_labels) sorted_labels.get_shape().assert_has_rank(2) is_label_valid = label_valid_fn(sorted_labels) valid_pairs = math_ops.logical_and(array_ops.expand_dims(is_label_valid, 2), array_ops.expand_dims(is_label_valid, 1)) sorted_labels = array_ops.where(is_label_valid, sorted_labels, array_ops.zeros_like(sorted_labels)) return valid_pairs, sorted_labels @abc.abstractmethod def pair_weights(self, sorted_labels): """Returns the weight adjustment `Tensor` for example pairs. Args: sorted_labels: A dense `Tensor` of labels with shape [batch_size, list_size] that are sorted by logits. Returns: A `Tensor` that can weight example pairs. """ raise NotImplementedError('Calling an abstract method.') def individual_weights(self, sorted_labels): """Returns the weight `Tensor` for individual examples. Args: sorted_labels: A dense `Tensor` of labels with shape [batch_size, list_size] that are sorted by logits. Returns: A `Tensor` that can weight individual examples. """ return sorted_labels class IdentityLambdaWeight(_LambdaWeight): def __init__(self,): pass def pair_weights(self, sorted_labels): return 1.0 class DCGLambdaWeight(_LambdaWeight): """LambdaWeight for Discounted Cumulative Gain metric.""" def __init__(self, topn=None, gain_fn=lambda label: label, rank_discount_fn=lambda rank: 1. / rank, normalized=False, smooth_fraction=0.): """Constructor. Ranks are 1-based, not 0-based. 
Given rank i and j, there are two types of pair weights: u = |rank_discount_fn(|i-j|) - rank_discount_fn(|i-j| + 1)| v = |rank_discount_fn(i) - rank_discount_fn(j)| where u is the newly introduced one in LambdaLoss paper (https://ai.google/research/pubs/pub47258) and v is the original one in the LambdaMART paper "From RankNet to LambdaRank to LambdaMART: An Overview". The final pair weight contribution of ranks is (1-smooth_fraction) * u + smooth_fraction * v. Args: topn: (int) The topn for the DCG metric. gain_fn: (function) Tranforms labels. rank_discount_fn: (function) The rank discount function. normalized: (bool) If True, normalize weight by the max DCG. smooth_fraction: (float) parameter to control the contribution from LambdaMART. """ self._topn = topn self._gain_fn = gain_fn self._rank_discount_fn = rank_discount_fn self._normalized = normalized assert 0. <= smooth_fraction and smooth_fraction <= 1., ( 'smooth_fraction %s should be in range [0, 1].' % smooth_fraction) self._smooth_fraction = smooth_fraction def pair_weights(self, sorted_labels): """See `_LambdaWeight`.""" with ops.name_scope(None, 'dcg_lambda_weight', (sorted_labels,)): valid_pair, sorted_labels = self._get_valid_pairs_and_clean_labels( sorted_labels) gain = self._gain_fn(sorted_labels) if self._normalized: gain *= inverse_max_dcg(sorted_labels, gain_fn=self._gain_fn, rank_discount_fn=self._rank_discount_fn, topn=self._topn) pair_gain = array_ops.expand_dims(gain, 2) - array_ops.expand_dims( gain, 1) pair_gain *= math_ops.to_float(valid_pair) list_size = array_ops.shape(sorted_labels)[1] topn = self._topn or list_size rank = math_ops.range(list_size) + 1 def _discount_for_relative_rank_diff(): """Rank-based discount in the LambdaLoss paper.""" # The LambdaLoss is not well defined when topn is active and topn < # list_size. We cap the rank of examples to topn + 1 so that the rank # differene is capped to topn. This is just a convenient upperbound # when topn is active. 
We need to revisit this. capped_rank = array_ops.where(math_ops.greater(rank, topn), array_ops.ones_like(rank) * (topn + 1), rank) rank_diff = math_ops.to_float( math_ops.abs( array_ops.expand_dims(capped_rank, 1) - array_ops.expand_dims(capped_rank, 0))) pair_discount = array_ops.where( math_ops.greater(rank_diff, 0), math_ops.abs( self._rank_discount_fn(rank_diff) - self._rank_discount_fn(rank_diff + 1)), array_ops.zeros_like(rank_diff)) return pair_discount def _discount_for_absolute_rank(): """Standard discount in the LambdaMART paper.""" # When the rank discount is (1 / rank) for example, the discount is # |1 / r_i - 1 / r_j|. When i or j > topn, the discount becomes 0. rank_discount = array_ops.where( math_ops.greater(rank, topn), array_ops.zeros_like(math_ops.to_float(rank)), self._rank_discount_fn(math_ops.to_float(rank))) pair_discount = math_ops.abs( array_ops.expand_dims(rank_discount, 1) - array_ops.expand_dims(rank_discount, 0)) return pair_discount u = _discount_for_relative_rank_diff() v = _discount_for_absolute_rank() pair_discount = (1. 
- self._smooth_fraction) * u + self._smooth_fraction * v pair_weight = math_ops.abs(pair_gain) * pair_discount if self._topn is None: return pair_weight pair_mask = math_ops.logical_or( array_ops.expand_dims(math_ops.less_equal(rank, self._topn), 1), array_ops.expand_dims(math_ops.less_equal(rank, self._topn), 0)) return pair_weight * math_ops.to_float(pair_mask) def individual_weights(self, sorted_labels): """See `_LambdaWeight`.""" with ops.name_scope(None, 'dcg_lambda_weight', (sorted_labels,)): sorted_labels = ops.convert_to_tensor(sorted_labels) sorted_labels = array_ops.where(label_valid_fn(sorted_labels), sorted_labels, array_ops.zeros_like(sorted_labels)) gain = self._gain_fn(sorted_labels) if self._normalized: gain *= inverse_max_dcg(sorted_labels, gain_fn=self._gain_fn, rank_discount_fn=self._rank_discount_fn, topn=self._topn) rank_discount = self._rank_discount_fn( math_ops.to_float( math_ops.range(array_ops.shape(sorted_labels)[1]) + 1)) return gain * rank_discount class PrecisionLambdaWeight(_LambdaWeight): """LambdaWeight for Precision metric.""" def __init__(self, topn, positive_fn=lambda label: math_ops.greater_equal(label, 1.0)): """Constructor. Args: topn: (int) The K in Precision@K metric. positive_fn: (function): A function on `Tensor` that output boolean True for positive examples. The rest are negative examples. """ self._topn = topn self._positive_fn = positive_fn def pair_weights(self, sorted_labels): """See `_LambdaWeight`. The current implementation here is that for any pairs of documents i and j, we set the weight to be 1 if - i and j have different labels. - i <= topn and j > topn or i > topn and j <= topn. This is exactly the same as the original LambdaRank method. The weight is the gain of swapping a pair of documents. Args: sorted_labels: A dense `Tensor` of labels with shape [batch_size, list_size] that are sorted by logits. Returns: A `Tensor` that can weight example pairs. 
""" with ops.name_scope(None, 'precision_lambda_weight', (sorted_labels,)): valid_pair, sorted_labels = self._get_valid_pairs_and_clean_labels( sorted_labels) binary_labels = math_ops.to_float(self._positive_fn(sorted_labels)) label_diff = math_ops.abs( array_ops.expand_dims(binary_labels, 2) - array_ops.expand_dims(binary_labels, 1)) label_diff *= math_ops.to_float(valid_pair) # i <= topn and j > topn or i > topn and j <= topn, i.e., xor(i <= topn, j # <= topn). list_size = array_ops.shape(sorted_labels)[1] rank = math_ops.range(list_size) + 1 rank_mask = math_ops.logical_xor( array_ops.expand_dims(math_ops.less_equal(rank, self._topn), 1), array_ops.expand_dims(math_ops.less_equal(rank, self._topn), 0)) return label_diff * math_ops.to_float(rank_mask) class ListMLELambdaWeight(_LambdaWeight): """LambdaWeight for ListMLE cost function.""" def __init__(self, rank_discount_fn): """Constructor. Ranks are 1-based, not 0-based. Args: rank_discount_fn: (function) The rank discount function. """ self._rank_discount_fn = rank_discount_fn def pair_weights(self, sorted_labels): """See `_LambdaWeight`.""" return sorted_labels def individual_weights(self, sorted_labels): """See `_LambdaWeight`.""" with ops.name_scope(None, 'p_list_mle_lambda_weight', (sorted_labels,)): sorted_labels = ops.convert_to_tensor(sorted_labels) rank_discount = self._rank_discount_fn( math_ops.to_float( math_ops.range(array_ops.shape(sorted_labels)[1]) + 1)) return array_ops.ones_like(sorted_labels) * rank_discount def _sort_and_normalize(labels, logits, weights=None): """Sorts `labels` and `logits` and normalize `weights`. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1], or a `Tensor` with the same shape as `labels`. Returns: A tuple of (sorted_labels, sorted_logits, sorted_weights). 
""" labels = ops.convert_to_tensor(labels) logits = ops.convert_to_tensor(logits) logits.get_shape().assert_has_rank(2) logits.get_shape().assert_is_compatible_with(labels.get_shape()) weights = 1.0 if weights is None else ops.convert_to_tensor(weights) weights = array_ops.ones_like(labels) * weights _, topn = array_ops.unstack(array_ops.shape(logits)) # Only sort entries with valid labels that are >= 0. scores = array_ops.where( math_ops.greater_equal(labels, 0.), logits, -1e-6 * array_ops.ones_like(logits) + math_ops.reduce_min(logits, axis=1, keepdims=True)) sorted_labels, sorted_logits, sorted_weights = sort_by_scores( scores, [labels, logits, weights], topn=topn) return sorted_labels, sorted_logits, sorted_weights def _pairwise_comparison(sorted_labels, sorted_logits, sorted_weights, lambda_weight=None): r"""Returns pairwise comparison `Tensor`s. Given a list of n items, the labels of graded relevance l_i and the logits s_i, we sort the items in a list based on s_i and obtain ranks r_i. We form n^2 pairs of items. For each pair, we have the following: / | 1 if l_i > l_j * `pairwise_labels` = | | 0 if l_i <= l_j \ * `pairwise_logits` = s_i - s_j / | 0 if l_i <= l_j, * `pairwise_weights` = | |l_i - l_j| if lambda_weight is None, | lambda_weight otherwise. \ The `sorted_weights` is item-wise and is applied non-symmetrically to update pairwise_weights as pairwise_weights(i, j) = w_i * pairwise_weights(i, j). This effectively applies to all pairs with l_i > l_j. Note that it is actually symmetric when `sorted_weights` are constant per list, i.e., listwise weights. Args: sorted_labels: A `Tensor` with shape [batch_size, list_size] of labels sorted. sorted_logits: A `Tensor` with shape [batch_size, list_size] of logits sorted. sorted_weights: A `Tensor` with shape [batch_size, list_size] of item-wise weights sorted. lambda_weight: A `_LambdaWeight` object. 
Returns: A tuple of (pairwise_labels, pairwise_logits, pairwise_weights) with each having the shape [batch_size, list_size, list_size]. """ # Compute the difference for all pairs in a list. The output is a Tensor with # shape [batch_size, list_size, list_size] where the entry [-1, i, j] stores # the information for pair (i, j). pairwise_label_diff = array_ops.expand_dims( sorted_labels, 2) - array_ops.expand_dims(sorted_labels, 1) pairwise_logits = array_ops.expand_dims( sorted_logits, 2) - array_ops.expand_dims(sorted_logits, 1) pairwise_labels = math_ops.to_float(math_ops.greater(pairwise_label_diff, 0)) is_label_valid = label_valid_fn(sorted_labels) valid_pair = math_ops.logical_and(array_ops.expand_dims(is_label_valid, 2), array_ops.expand_dims(is_label_valid, 1)) # Only keep the case when l_i > l_j. pairwise_weights = pairwise_labels * math_ops.to_float(valid_pair) # Apply the item-wise weights along l_i. pairwise_weights *= tf.cast(array_ops.expand_dims(sorted_weights, 2), tf.float32) if lambda_weight is not None: pairwise_weights *= lambda_weight.pair_weights(sorted_labels) else: pairwise_weights *= math_ops.abs(pairwise_label_diff) pairwise_weights = array_ops.stop_gradient(pairwise_weights, name='weights_stop_gradient') return pairwise_labels, pairwise_logits, pairwise_weights def _pairwise_loss(loss_fn, labels, logits, weights=None, lambda_weight=None, lambda_scale=True, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS): """Template to compute pairwise loss. Args: loss_fn: A function that computes loss from the pairwise logits with l_i > l_j. labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. lambda_weight: A `_LambdaWeight` object. 
reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Returns: An op for the pairwise loss. """ sorted_labels, sorted_logits, sorted_weights = _sort_and_normalize( labels, logits, weights) _, pairwise_logits, pairwise_weights = _pairwise_comparison( sorted_labels, sorted_logits, sorted_weights, lambda_weight) if lambda_weight is not None and lambda_scale: # For LambdaLoss with relative rank difference, the scale of loss becomes # much smaller when applying LambdaWeight. This affects the training can # make the optimal learning rate become much larger. We use a heuristic to # scale it up to the same magnitude as standard pairwise loss. pairwise_weights *= math_ops.to_float(array_ops.shape(sorted_labels)[1]) return core_losses.compute_weighted_loss(loss_fn(pairwise_logits), weights=pairwise_weights, reduction=reduction) def _pairwise_hinge_loss(labels, logits, weights=None, lambda_weight=None, lambda_scale=True, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None): """Computes the pairwise hinge loss for a list. The hinge loss is defined as Hinge(l_i > l_j) = max(0, 1 - (s_i - s_j)). So a correctly ordered pair has 0 loss if (s_i - s_j >= 1). Otherwise the loss increases linearly with s_i - s_j. When the list_size is 2, this reduces to the standard hinge loss. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. lambda_weight: A `_LambdaWeight` object. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. Returns: An op for the pairwise hinge loss. 
""" def _loss(logits): """The loss of pairwise logits with l_i > l_j.""" # TODO(xuanhui, pointer-team): Consider pass params object into the loss and # put a margin here. return nn_ops.relu(1. - logits) with ops.name_scope(name, 'pairwise_hinge_loss', (labels, logits, weights)): return _pairwise_loss(_loss, labels, logits, weights, lambda_weight, lambda_scale, reduction=reduction) def _pairwise_logistic_loss( labels, logits, weights=None, lambda_weight=None, lambda_scale=True, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None): """Computes the pairwise logistic loss for a list. The preference probability of each pair is computed as the sigmoid function: P(l_i > l_j) = 1 / (1 + exp(s_j - s_i)) and the logistic loss is log(P(l_i > l_j)) if l_i > l_j and 0 otherwise. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. lambda_weight: A `_LambdaWeight` object. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. Returns: An op for the pairwise logistic loss. """ def _loss(logits): """The loss of pairwise logits with l_i > l_j.""" # The following is the same as log(1 + exp(-pairwise_logits)). return nn_ops.relu(-logits) + math_ops.log1p( math_ops.exp(-math_ops.abs(logits))) with ops.name_scope(name, 'pairwise_logistic_loss', (labels, logits, weights)): return _pairwise_loss(_loss, labels, logits, weights, lambda_weight, lambda_scale, reduction=reduction) def _pairwise_soft_zero_one_loss( labels, logits, weights=None, lambda_weight=None, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None): """Computes the pairwise soft zero-one loss. 
Note this is different from sigmoid cross entropy in that soft zero-one loss is a smooth but non-convex approximation of zero-one loss. The preference probability of each pair is computed as the sigmoid function: P(l_i > l_j) = 1 / (1 + exp(s_j - s_i)). Then 1 - P(l_i > l_j) is directly used as the loss. So a correctly ordered pair has a loss close to 0, while an incorrectly ordered pair has a loss bounded by 1. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. lambda_weight: A `_LambdaWeight` object. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. Returns: An op for the pairwise soft zero one loss. """ def _loss(logits): """The loss of pairwise logits with l_i > l_j.""" return array_ops.where(math_ops.greater(logits, 0), 1. - math_ops.sigmoid(logits), math_ops.sigmoid(-logits)) with ops.name_scope(name, 'pairwise_soft_zero_one_loss', (labels, logits, weights)): return _pairwise_loss(_loss, labels, logits, weights, lambda_weight, reduction=reduction) def _softmax_loss(labels, logits, weights=None, lambda_weight=None, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None): """Computes the softmax cross entropy for a list. Given the labels l_i and the logits s_i, we sort the examples and obtain ranks r_i. The standard softmax loss doesn't need r_i and is defined as -sum_i l_i * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))). The `lambda_weight` re-weight examples based on l_i and r_i. -sum_i w(l_i, r_i) * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).abc See 'individual_weights' in 'DCGLambdaWeight' for how w(l_i, r_i) is computed. 
Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. lambda_weight: A `DCGLambdaWeight` instance. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. Returns: An op for the softmax cross entropy as a loss. """ with ops.name_scope(name, 'softmax_loss', (labels, logits, weights)): sorted_labels, sorted_logits, sorted_weights = _sort_and_normalize( labels, logits, weights) is_label_valid = label_valid_fn(sorted_labels) # Reset the invalid labels to 0 and reset the invalid logits to a logit with # ~= 0 contribution in softmax. sorted_labels = array_ops.where(is_label_valid, sorted_labels, array_ops.zeros_like(sorted_labels)) sorted_logits = array_ops.where( is_label_valid, sorted_logits, math_ops.log(_EPSILON) * array_ops.ones_like(sorted_logits)) if lambda_weight is not None and isinstance(lambda_weight, DCGLambdaWeight): sorted_labels = lambda_weight.individual_weights(sorted_labels) sorted_labels *= sorted_weights label_sum = math_ops.reduce_sum(sorted_labels, 1, keepdims=True) nonzero_mask = math_ops.greater(array_ops.reshape(label_sum, [-1]), 0.0) label_sum, sorted_labels, sorted_logits = [ array_ops.boolean_mask(x, nonzero_mask) for x in [label_sum, sorted_labels, sorted_logits] ] return core_losses.softmax_cross_entropy(sorted_labels / label_sum, sorted_logits, weights=array_ops.reshape( label_sum, [-1]), reduction=reduction) def _sigmoid_cross_entropy_loss( labels, logits, weights=None, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None): """Computes the sigmoid_cross_entropy loss for a list. 
Given the labels of graded relevance l_i and the logits s_i, we calculate the sigmoid cross entropy for each ith position and aggregate the per position losses. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. Returns: An op for the sigmoid cross entropy as a loss. """ with ops.name_scope(name, 'sigmoid_cross_entropy_loss', (labels, logits, weights)): is_label_valid = array_ops.reshape(label_valid_fn(labels), [-1]) weights = 1.0 if weights is None else ops.convert_to_tensor(weights) weights = array_ops.ones_like(labels) * weights label_vector, logit_vector, weight_vector = [ array_ops.boolean_mask(array_ops.reshape(x, [-1]), is_label_valid) for x in [labels, logits, weights] ] return core_losses.sigmoid_cross_entropy(label_vector, logit_vector, weights=weight_vector, reduction=reduction) def _mean_squared_loss(labels, logits, weights=None, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None): """Computes the mean squared loss for a list. Given the labels of graded relevance l_i and the logits s_i, we calculate the squared error for each ith position and aggregate the per position losses. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. reduction: One of `tf.losses.Reduction` except `NONE`. 
Describes how to reduce training loss over batch. name: A string used as the name for this loss. Returns: An op for the mean squared error as a loss. """ with ops.name_scope(name, 'mean_squared_loss', (labels, logits, weights)): is_label_valid = array_ops.reshape(label_valid_fn(labels), [-1]) weights = 1.0 if weights is None else ops.convert_to_tensor(weights) weights = array_ops.ones_like(labels) * weights label_vector, logit_vector, weight_vector = [ array_ops.boolean_mask(array_ops.reshape(x, [-1]), is_label_valid) for x in [labels, logits, weights] ] return core_losses.mean_squared_error(label_vector, logit_vector, weights=weight_vector, reduction=reduction) def _list_mle_loss(labels, logits, weights=None, lambda_weight=None, reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS, name=None, seed=None): """Computes the ListMLE loss [Xia et al. 2008] for a list. Given the labels of graded relevance l_i and the logits s_i, we calculate the ListMLE loss for the given list. The `lambda_weight` re-weights examples based on l_i and r_i. The recommended weighting scheme is the formulation presented in the "Position-Aware ListMLE" paper (Lan et al.) and available using create_p_list_mle_lambda_weight() factory function above. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. lambda_weight: A `DCGLambdaWeight` instance. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. seed: A randomization seed used when shuffling ground truth permutations. Returns: An op for the ListMLE loss. 
""" with ops.name_scope(name, 'list_mle_loss', (labels, logits, weights)): is_label_valid = label_valid_fn(labels) # Reset the invalid labels to 0 and reset the invalid logits to a logit with # ~= 0 contribution. labels = array_ops.where(is_label_valid, labels, array_ops.zeros_like(labels)) logits = array_ops.where( is_label_valid, logits, math_ops.log(_EPSILON) * array_ops.ones_like(logits)) weights = 1.0 if weights is None else ops.convert_to_tensor(weights) weights = array_ops.squeeze(weights) # Shuffle labels and logits to add randomness to sort. shuffled_indices = shuffle_valid_indices(is_label_valid, seed) shuffled_labels = array_ops.gather_nd(labels, shuffled_indices) shuffled_logits = array_ops.gather_nd(logits, shuffled_indices) sorted_labels, sorted_logits = sort_by_scores( shuffled_labels, [shuffled_labels, shuffled_logits]) raw_max = math_ops.reduce_max(sorted_logits, axis=1, keepdims=True) sorted_logits = sorted_logits - raw_max sums = math_ops.cumsum(math_ops.exp(sorted_logits), axis=1, reverse=True) sums = math_ops.log(sums) - sorted_logits if lambda_weight is not None and isinstance(lambda_weight, ListMLELambdaWeight): sums *= lambda_weight.individual_weights(sorted_labels) negative_log_likelihood = math_ops.reduce_sum(sums, 1) return core_losses.compute_weighted_loss(negative_log_likelihood, weights=weights, reduction=reduction) def _approx_ndcg_loss(labels, logits, weights=None, reduction=core_losses.Reduction.SUM, name=None, alpha=10.): """Computes ApproxNDCG loss. ApproxNDCG ["A general approximation framework for direct optimization of information retrieval measures" by Qin et al.] is a smooth approximation to NDCG. Args: labels: A `Tensor` of the same shape as `logits` representing graded relevance. logits: A `Tensor` with shape [batch_size, list_size]. Each value is the ranking score of the corresponding item. 
weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise weights, or a `Tensor` with shape [batch_size, list_size] for item-wise weights. If None, the weight of a list in the mini-batch is set to the sum of the labels of the items in that list. reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. name: A string used as the name for this loss. alpha: The exponent in the generalized sigmoid function. Returns: An op for the ApproxNDCG loss. """ with ops.name_scope(name, 'approx_ndcg_loss', (labels, logits, weights)): is_label_valid = label_valid_fn(labels) labels = array_ops.where(is_label_valid, labels, array_ops.zeros_like(labels)) logits = array_ops.where( is_label_valid, logits, -1e3 * array_ops.ones_like(logits) + math_ops.reduce_min(logits, axis=-1, keepdims=True)) label_sum = math_ops.reduce_sum(labels, 1, keepdims=True) if weights is None: weights = array_ops.ones_like(label_sum) weights = array_ops.squeeze(weights) nonzero_mask = math_ops.greater(array_ops.reshape(label_sum, [-1]), 0.0) labels, logits, weights = [ array_ops.boolean_mask(x, nonzero_mask) for x in [labels, logits, weights] ] gains = math_ops.pow(2., math_ops.to_float(labels)) - 1. ranks = approx_ranks(logits, alpha=alpha) discounts = 1. 
/ math_ops.log1p(ranks) dcg = math_ops.reduce_sum(gains * discounts, -1) cost = -dcg * array_ops.squeeze(inverse_max_dcg(labels)) return core_losses.compute_weighted_loss(cost, weights=weights, reduction=reduction) ================================================ FILE: monolith/native_training/metric/BUILD ================================================ load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary") load("@rules_python//python:defs.bzl", "py_library", "py_test") load("@pip_deps//:requirements.bzl", "requirement") package( default_visibility = [ "//monolith/native_training:__subpackages__", "//monolith/sail:__subpackages__", ], ) py_library( name = "deep_insight_ops", srcs = ["deep_insight_ops.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "deep_insight_ops_test", srcs = ["deep_insight_ops_test.py"], deps = [":deep_insight_ops"], ) py_library( name = "cli", srcs = ["cli.py"], visibility = ["//visibility:public"], ) py_library( name = "exit_hook", srcs = ["exit_hook.py"], deps = [ ":cli", "//monolith/native_training:native_task_context", ], ) py_library( name = "metric_hook", srcs = [ "kafka_utils.py", "metric_hook.py", ], deps = [ ":cli", ":exit_hook", "//monolith/native_training:utils", "//monolith/native_training/alert", "@org_tensorflow//tensorflow:tensorflow_py", requirement("kafka-python"), ], ) py_test( name = "metric_hook_test", srcs = ["metric_hook_test.py"], deps = [ ":metric_hook", ], ) py_library( name = "utils", srcs = ["utils.py"], deps = [ ":deep_insight_ops", ], ) py_test( name = "utils_test", srcs = ["utils_test.py"], deps = [ ":utils", ], ) ================================================ FILE: monolith/native_training/metric/cli.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import threading from absl import logging class Client: def __init__(self, *args, **kwargs): pass def __getattr__(self, name): def method(*args, **kwargs): pass return method def get_cli(*args, **kwargs): return Client() ================================================ FILE: monolith/native_training/metric/deep_insight_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import socket from typing import List import tensorflow as tf from tensorflow.python.framework import ops from monolith.native_training.runtime.ops import gen_monolith_ops deep_insight_ops = gen_monolith_ops _FEATURE_REQ_TIME = "req_time" _SAMPLE_RATE = "sample_rate" _UID = "uid" def deep_insight_client(enable_metrics_counter: bool=False, is_fake: bool = False, dump_filename=None, container: str=socket.gethostname()) \ -> tf.Tensor: """ Create a deep insight client Args: enable_metrics_counter - whether enable metrics counter for using deepinsight. container - Use host name as the container name. 
def write_deep_insight(deep_insight_client_tensor: tf.Tensor,
                       uids: tf.Tensor,
                       req_times: tf.Tensor,
                       labels: tf.Tensor,
                       preds: tf.Tensor,
                       sample_rates: tf.Tensor,
                       model_name: str,
                       target: str = "ctr_head",
                       sample_ratio: float = 0.01,
                       return_msgs: bool = False,
                       use_zero_train_time=False) -> tf.Tensor:
  """Build the op that reports single-target metrics to deep insight.

  Every argument is forwarded unchanged, by keyword, to the
  `monolith_write_deep_insight` op. Per the original notes, the op builds a
  JSON deep insight message and sends it to a databus channel through a
  non-blocking API.

  Args:
    deep_insight_client_tensor: client handle passed to the op as
      `deep_insight_client_handle`.
    uids: a 1-D int64 tensor.
    req_times: a 1-D int64 tensor.
    labels: a 1-D float tensor.
    preds: a 1-D float tensor.
    sample_rates: a 1-D float tensor.
    model_name: model name of string type.
    target: target of string type.
    sample_ratio: sample ratio of float type.
    return_msgs: whether to return the messages sent to deep insight, for
      debugging.
    use_zero_train_time: True makes the op use training time 0 (only used in
      tests); False makes it use the real training time.

  Returns:
    1-D string tensor.
  """
  op_kwargs = dict(
      deep_insight_client_handle=deep_insight_client_tensor,
      uids=uids,
      req_times=req_times,
      labels=labels,
      preds=preds,
      sample_rates=sample_rates,
      model_name=model_name,
      target=target,
      sample_ratio=sample_ratio,
      return_msgs=return_msgs,
      use_zero_train_time=use_zero_train_time,
  )
  return deep_insight_ops.monolith_write_deep_insight(**op_kwargs)
Internal it includes parse and build JSON format deep insight message. And send to databus channel using unblock API. Args: deep_insight_client_tensor: MonolithCreateDeepInsightClient req_times: 1-D int64 tensor, shape = (batch_size,) labels: 2-D float tensor, shape = (num_targets, batch_size) preds: 2-D float tensor, shape = (num_targets, batch_size) sample_rates: 2-D float tensor, shape = (num_targets, batch_size) extra_fields_values: List of 1-D tensors, each shape = (batch_size,) extra_fields_keys: List of strings. model_name: model name of string type. targets: List of target names. sample_ratio: sample ratio of float type. return_msgs: whether return the msg sent to deepinsight for debugging. use_zero_train_time: Use True if you want to use training time (0) in deepinsight. this is actually used only in test. Use false if you want to use real training time to write to deepinsight. Returns: 1-D string tensor. """ return deep_insight_ops.monolith_write_deep_insight_v2( deep_insight_client_handle=deep_insight_client_tensor, req_times=req_times, labels=labels, preds=preds, sample_rates=sample_rates, extra_fields_values=extra_fields_values, extra_fields_keys=extra_fields_keys, model_name=model_name, targets=targets, sample_ratio=sample_ratio, return_msgs=return_msgs, use_zero_train_time=use_zero_train_time) ================================================ FILE: monolith/native_training/metric/deep_insight_ops_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import time from absl import logging import json import tensorflow as tf import monolith.native_training.metric.deep_insight_ops as ops class DeepInsightOpsTest(tf.test.TestCase): def dummy_test(self): pass if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/metric/exit_hook.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@atexit.register
def exit_hook():
  """Interpreter-exit handler that reports exits caused by a handled signal.

  The task context and metric client are always looked up. The `exit_hook`
  counter is emitted only when `sig_handler` has stored a signal number in the
  module-level `sig_no`.
  """
  task_ctx = native_task_context.get()
  metric_client = cli.get_cli(utils.get_metric_prefix())

  # Workers and parameter servers keep their index in different fields.
  if task_ctx.server_type == 'worker':
    task_index = task_ctx.worker_index
  else:
    task_index = task_ctx.ps_index

  counter_tags = dict(
      server_type=task_ctx.server_type,
      index=str(task_index),
      sig=str(sig_no),
  )

  if sig_no is None:
    # No signal was recorded, so there is nothing to report.
    return
  metric_client.emit_counter("exit_hook", 1, counter_tags)
def close(self):
  """Stop the polling thread, flush queued messages and close the producer.

  Shutdown is best effort: failures are logged as warnings and never raised.
  The Kafka producer is closed even when joining the thread or flushing the
  queue fails, so an error in those steps can no longer leave the producer
  open.
  """
  try:
    logging.info('set stopped')
    with self._lock:
      self._has_stopped = True
    logging.info('wait for thread join')
    self._thread.join()
    logging.info('flush queue')
    self._flush()
  except Exception as e:
    logging.warning(str(e))
  finally:
    # Runs on both the success and the failure path above.
    try:
      logging.info('close kafka producer')
      self._producer.close(timeout=1)
    except Exception as e:
      logging.warning(str(e))
class ThroughputMetricHook(tf.estimator.SessionRunHook):
  """Emits step-throughput metrics at a fixed wall-clock interval.

  On the first run it records the current global step and, when
  `start_time_secs` is given, emits the time elapsed since then as
  `run_start_elapsed_time.all`. After each run where at least
  `run_every_n_secs` seconds have passed since the last emission, it emits the
  number of steps taken (`run_steps.all`) and the average seconds per step
  (`run_steps_elapsed_time.all`).
  """

  def __init__(self,
               model_name,
               start_time_secs,
               cluster_type="stable",
               run_every_n_secs=30):
    """Store the configuration and register a training alert rule.

    Args:
      model_name: value of the `model_name` metric tag.
      start_time_secs: start timestamp in seconds, or None to skip the
        start-latency metric.
      cluster_type: value of the `cluster_type` metric tag.
      run_every_n_secs: minimum number of seconds between two emissions.
    """
    self._model_name = model_name
    self._start_time_secs = start_time_secs
    self._cluster_type = cluster_type
    self._run_every_n_secs = run_every_n_secs
    self._is_first_step = True
    self._mcli = cli.get_cli(utils.get_metric_prefix())
    am = alert_manager.get_default_alert_manager()
    if am:
      proto = alert_pb2.AlertProto()
      proto.training_alert.prefix = utils.get_metric_prefix()
      am.add_rules(proto)

  def begin(self):
    self._global_step_tensor = tf.compat.v1.train.get_global_step()

  def before_run(self, run_context):
    """Initialise the measurement window on the first run; fetch global step."""
    if self._is_first_step is True:
      self._emit_step = run_context.session.run(self._global_step_tensor)
      self._emit_time = int(time.time())
      if self._start_time_secs is not None:
        tags = {
            "model_name": self._model_name,
            "cluster_type": self._cluster_type
        }
        run_start_elapsed_time = self._emit_time - self._start_time_secs
        logging.info("Run start took {}s.".format(run_start_elapsed_time))
        self._mcli.emit_timer("run_start_elapsed_time.all",
                              run_start_elapsed_time, tags)
      self._is_first_step = False
    return session_run_hook.SessionRunArgs({
        "global_step": self._global_step_tensor,
    })

  def after_run(self, run_context, run_values):
    """Emit throughput metrics once the reporting interval has elapsed."""
    end_time = int(time.time())
    elapsed_time = end_time - self._emit_time
    if elapsed_time >= self._run_every_n_secs:
      global_step = run_values.results["global_step"]
      step_interval = global_step - self._emit_step
      tags = {
          "model_name": self._model_name,
          "cluster_type": self._cluster_type
      }
      self._mcli.emit_counter("run_steps.all", step_interval, tags)
      # The global step may not have advanced during the window. A per-step
      # time is undefined then, so skip the timer instead of dividing by zero.
      if step_interval > 0:
        self._mcli.emit_timer("run_steps_elapsed_time.all",
                              elapsed_time / step_interval, tags)
      self._emit_step = global_step
      self._emit_time = end_time
""" def __init__(self, loss_tensor): self._loss_tensor = loss_tensor self._mcli = cli.get_cli(utils.get_metric_prefix()) def before_run(self, run_context): return tf.estimator.SessionRunArgs(self._loss_tensor) def after_run(self, run_context, run_value): self._mcli.emit_store("step_loss", run_value.results) class CustomMetricHook(tf.estimator.SessionRunHook): """ Log group of customed metircs for a batch. """ def __init__(self, metric_tensors): for name in metric_tensors: tensor = metric_tensors[name] if len(tensor.shape.dims) > 0: raise ValueError("The metric tensor should be a scalar!") if tensor.dtype.base_dtype not in (tf.float32, tf.int32): raise ValueError( "The dtype of a metric tensor should be either tf.float or tf.int32!" ) if len(metric_tensors) == 0: raise ValueError("At least one metric tensor should be offered!") self._metric_tensors = metric_tensors self._mcli = cli.get_cli(utils.get_metric_prefix()) def before_run(self, run_context): return tf.estimator.SessionRunArgs(self._metric_tensors) def after_run(self, run_context, run_value): metric_values = run_value.results for name in metric_values: self._mcli.emit_store(name, float(metric_values[name])) class Tf2ProfilerHook(tf.estimator.SessionRunHook): """ Using TF2 profiler in esitmator """ def __init__(self, logdir: str, init_step_range: Tuple[int, int], save_steps: int = None, save_secs: int = None, options: tf.profiler.experimental.ProfilerOptions = None): """Only one of save_steps and save_secs should be provided.""" self._logdir = logdir self._options = options self._start_step, self._end_step = init_step_range if self._start_step is not None and (self._end_step is None or self._end_step <= self._start_step): raise ValueError("End step invalid, start_step: {}, end_step: {}".format(self._start_step, self._end_step)) self._default_delta = 10 self._delta = self._end_step - self._start_step if self._end_step is not None else self._default_delta if save_steps is not None and save_steps <= self._delta: 
def begin(self):
  """Start a profiler server on port 6666 when sync training is enabled.

  Start-up is best effort: ordinary failures (including a failing flag
  lookup) are logged as a warning and otherwise ignored. KeyboardInterrupt
  and SystemExit are no longer swallowed.
  """
  try:
    # if enable_sync_training, there is no tf.distribute.Server
    # we need start profiler server
    if FLAGS.enable_sync_training:
      tf.profiler.experimental.server.start(6666)
  except Exception:
    logging.warning("cannot start profiler server at 6666")
class ByteCCLTelemetryHook(tf.estimator.SessionRunHook):
  """Session hook that periodically logs BytePS communication telemetry."""

  def __init__(self, interval: int):
    """Remember the reporting interval.

    Args:
      interval: number of global steps that must pass after the last report
        before the next one is produced.
    """
    self._last_step = 0
    self._interval = interval
    logging.info(f"Created ByteCCL telemetry hook, interval={interval}")

  def begin(self):
    step_read = training_util._get_or_create_global_step_read()
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use ByteCCLTelemetryHook")

  def before_run(self, run_context):
    return tf.estimator.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values: tf.estimator.SessionRunValues):
    current_step = run_values.results
    if current_step <= self._last_step + self._interval:
      # Still inside the current interval.
      return
    self._log_telemetry()
    self._last_step = current_step

  def end(self, sess):
    pass

  def _log_telemetry(self):
    """On rank 0, log a sample of the telemetry entries reported by BytePS."""
    import byteps.tensorflow as bps
    if bps.rank() != 0:
      return

    # Every alltoall entry is kept. A PushPull entry is kept only while fewer
    # than three entries of any kind have been collected.
    samples = []
    for name, mean, stdev, count in bps.get_telemetry():
      name = str(name)
      if 'alltoall' not in name.lower():
        if 'PushPull' not in name or len(samples) >= 3:
          continue
      samples.append(
          f'name: {name} mean(ms): {mean:.2f} stdev(ms): {stdev:.2f} count: {count}'
      )

    if samples:
      logging.info(f'Communication telemetry: {samples} ...')
def default_parse_fn(obj: Any) -> Any:
  """Decode JSON text and pass every other value through unchanged.

  Args:
    obj: any value; `str` and `bytes` are treated as JSON documents.

  Returns:
    The decoded JSON value for `str`/`bytes` input, otherwise `obj` itself
    (including None).
  """
  is_json_text = isinstance(obj, (str, bytes))
  return json.loads(obj) if is_json_text else obj
def vepfs_key_fn(obj, worker_id: int, base_name: str) -> str:
  """Build the output directory for one parsed metric record.

  The layout is `<base_name>/<model_name>/<req_time>/worker_<worker_id>`.

  Args:
    obj: mapping holding the record; reads `model_name` and `__REQ_TIME__`
      (falling back to `req_time`).
    worker_id: index used for the `worker_<id>` path component.
    base_name: root directory of the layout.

  Returns:
    The joined path. A missing or empty `model_name` is replaced by the
    literal `'model_name'`.

  Raises:
    TypeError: from `os.path.join` when the record has no request time.
  """
  model_name = obj.get('model_name') or 'model_name'
  date = obj.get('__REQ_TIME__') or obj.get('req_time')
  # The request time may arrive as a number or as bytes rather than text;
  # os.path.join only accepts path components of one string type.
  if isinstance(date, bytes):
    date = date.decode()
  elif date is not None and not isinstance(date, str):
    date = str(date)
  return os.path.join(base_name, model_name, date, f'worker_{worker_id}')
class FileMetricHook(tf.estimator.SessionRunHook):
  """Session hook that writes fetched deep insight messages to files.

  The hook is a process-wide singleton: every construction returns the same
  object, and `__init__` re-initialises its state. Messages fetched in
  `after_run` are queued and written by a background thread. That thread
  parses each message, derives a directory key from it, and appends it to the
  `WriteOnlyFileAndStat` for that key.
  """
  __instance = None

  def __new__(cls, *args, **kwargs):
    if cls.__instance is None:
      cls.__instance = super().__new__(cls)
    return cls.__instance

  def __init__(self,
               deep_insight_op=None,
               *,
               worker_id: int = None,
               parse_fn: Callable[[Any], Any] = None,
               key_fn: Callable[[Any, int, str], str] = None,
               layout_fn: Callable[[Any], str] = None,
               batch_size: int = 1024,
               partition_size: int = None,
               base_name: str = '/vepfs/jaguar_deepinsight_results',
               file_ext: str = 'txt'):
    """Configure the hook.

    Args:
      deep_insight_op: op to fetch on every run. When None, the first entry of
        the `deep_insight_op` graph collection is used, if there is one.
      worker_id: passed to `key_fn` as the worker index.
      parse_fn: turns a fetched message into a record; defaults to
        `default_parse_fn`.
      key_fn: maps `(record, worker_id, base_name)` to an output directory;
        defaults to `vepfs_key_fn`. Previously a missing `key_fn` made the
        writer thread fail on its first record.
      layout_fn: renders a record as one output line; defaults to
        `default_layout_fn`.
      batch_size: forwarded to `WriteOnlyFileAndStat`.
      partition_size: forwarded to `WriteOnlyFileAndStat`.
      base_name: root directory passed to `key_fn`.
      file_ext: forwarded to `WriteOnlyFileAndStat`.
    """
    if deep_insight_op is None:
      collection = tf.compat.v1.get_collection(key='deep_insight_op')
      if collection:
        if isinstance(collection, (list, tuple)):
          deep_insight_op = collection[0]
        else:
          deep_insight_op = collection

    self._worker_id = worker_id
    self._key_fn = key_fn or vepfs_key_fn
    self._layout_fn = layout_fn or default_layout_fn
    self._parse_fn = parse_fn or default_parse_fn
    self._batch_size = batch_size
    self._partition_size = partition_size
    self._base_name = base_name
    self._file_ext = file_ext

    self._queue = Queue()
    self._files = {}  # output directory key -> WriteOnlyFileAndStat
    self._stopped = False
    self._metric_tensors = {'deep_insight_op': deep_insight_op}
    self._thread = None

  def before_run(self, run_context):
    return tf.estimator.SessionRunArgs(self._metric_tensors)

  def after_run(self, run_context, run_value):
    """Queue the fetched messages; start the writer thread on first use."""
    if self._thread is None:
      self._thread = Thread(target=self._send)
      self._thread.start()

    metric_values = run_value.results
    msgs = metric_values.get('deep_insight_op')
    if msgs is not None:
      if isinstance(msgs, (list, tuple, np.ndarray)):
        for msg in msgs:
          if msg:
            self._queue.put(msg)
      else:
        self._queue.put(msgs)

  def end(self, session):
    """Wait for the queue to drain, stop the writer and close all files."""
    logging.info('end FileMetricHook: empty the queue ...')
    while not self._queue.empty():
      # Only the writer thread drains the queue. If it is not running, waiting
      # would never finish.
      if self._thread is None or not self._thread.is_alive():
        logging.warning(
            'end FileMetricHook: writer thread is not running, '
            'remaining records are dropped')
        break
      time.sleep(1)

    logging.info('end FileMetricHook: queue is empty, begin to stop thread ...')
    self._stopped = True
    if self._thread is not None:
      self._thread.join()
      self._thread = None

    logging.info(
        'end FileMetricHook: thread stopped, begin to close open file ...')
    for fs in self._files.values():
      fs.close()
    logging.info('end FileMetricHook: all done! ')

  def _send(self):
    """Writer-thread loop: parse, route and write queued messages."""
    last_check_time = time.time()
    while not self._stopped:
      try:
        item = self._queue.get(timeout=1)
      except Empty:
        continue

      try:
        item = self._parse_fn(item)
        key = self._key_fn(item, self._worker_id, self._base_name)
        file_and_stat = self._files.get(key)
        if file_and_stat is None:
          file_and_stat = WriteOnlyFileAndStat(
              key,
              layout_fn=self._layout_fn,
              batch_size=self._batch_size,
              partition_size=self._partition_size,
              file_ext=self._file_ext)
          self._files[key] = file_and_stat
        file_and_stat.write(item)
      except Exception as e:
        # One malformed or unwritable record must not end the thread. Logging
        # is rate limited because failures can repeat for every record.
        logging.log_every_n_seconds(
            logging.WARNING, 'FileMetricHook: drop one record, error: %s', 60,
            e)
        continue

      # Every 10 minutes, close and forget files that report themselves as no
      # longer available.
      if time.time() - last_check_time > 600:
        to_remove = set()
        for key, fs in self._files.items():
          if not fs.is_available():
            fs.close()
            to_remove.add(key)
        for key in to_remove:
          del self._files[key]
        last_check_time = time.time()
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import time import json from datetime import datetime, timedelta from random import choice, randint import tensorflow as tf from monolith.native_training.metric import metric_hook class Tf2ProfilerHookTest(tf.test.TestCase): def setUp(self): super().setUp() self.logdir = os.path.join(os.environ["TEST_TMPDIR"], self._testMethodName) self.filepattern = os.path.join(self.logdir, "plugins/profile/*") self.graph = tf.Graph() with self.graph.as_default(): self.global_step = tf.compat.v1.train.get_or_create_global_step() self.train_op = tf.compat.v1.assign_add(self.global_step, 1) def _count_files(self): return len(tf.io.gfile.glob(self.filepattern)) def test_steps(self): with self.graph.as_default(): hook = metric_hook.Tf2ProfilerHook(self.logdir, init_step_range=[0, 10], save_steps=50) with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess: sess.run(self.train_op) self.assertEqual(self._count_files(), 1) def test_multiple_steps_1(self): with self.graph.as_default(): hook = metric_hook.Tf2ProfilerHook(self.logdir, init_step_range=[0, 10], save_steps=30) with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess: for _ in range(30): sess.run(self.train_op) # Since profiler directory is named by seconds, we need to make sure # two dumps are in the different folder. 
class FileMetricHookTest(tf.test.TestCase):
  """End-to-end checks for FileMetricHook and the vepfs key/layout helpers.

  `test_after_run` feeds 64 records for each of the last 8 days into the
  shared hook. `tearDownClass` then stops the hook and verifies the files it
  produced, so those checks only hold when `test_after_run` has run.
  """

  @classmethod
  def setUpClass(cls):
    cls.model_name = 'test_model'
    # Prefer the test runner's scratch directory, as Tf2ProfilerHookTest does,
    # so artifacts do not land in (and persist under) the home directory. The
    # previous HOME-based location remains the fallback.
    tmp_root = os.environ.get('TEST_TMPDIR')
    if tmp_root:
      cls.base_name = os.path.join(tmp_root, 'file_metric_hook')
    else:
      cls.base_name = f'{os.environ.get("HOME")}/tmp/file_metric_hook'
    cls.hook = metric_hook.FileMetricHook(worker_id=0,
                                          key_fn=metric_hook.vepfs_key_fn,
                                          layout_fn=metric_hook.vepfs_layout_fn,
                                          batch_size=8,
                                          partition_size=32,
                                          base_name=cls.base_name)

  @classmethod
  def tearDownClass(cls):
    cls.hook.end(None)
    date_dir = tf.io.gfile.listdir(path=f'{cls.base_name}/{cls.model_name}')
    for i in range(7, -1, -1):
      date = datetime.today() - timedelta(days=i)
      date_str = date.strftime('%Y%m%d')
      assert date_str in date_dir
      path = f'{cls.base_name}/{cls.model_name}/{date_str}/worker_0/'
      data_dir = tf.io.gfile.listdir(path=path)
      # 64 records per day with partition_size=32 give two files of 32 lines.
      assert len(data_dir) == 2
      for df in data_dir:
        fname = f'{path}{df}'
        with tf.io.gfile.GFile(fname, 'r') as stream:
          assert len(stream.readlines()) == 32

  def test_vepfs_key_fn(self):
    data = {
        'model_name': 'test_model',
        'req_time': '20220927',
        'userid': '1854',
        'predict': {
            'feed_comment': 0.5,
            'click_comment': 0.2,
            'feed_share': 0.2
        },
        'label': {
            'feed_comment': 0,
            'click_comment': 1,
            'feed_share': 0
        }
    }
    self.assertEqual(
        metric_hook.vepfs_key_fn(data, worker_id=0, base_name=self.base_name),
        f'{self.base_name}/test_model/20220927/worker_0')

  def test_vepfs_layout_fn(self):
    data = {
        'model_name': 'test_model',
        'req_time': '20220927',
        'userid': '1854',
        'predict': {
            'feed_comment': 0.5,
            'click_comment': 0.2,
            'feed_share': 0.2
        },
        'label': {
            'feed_comment': 0,
            'click_comment': 1,
            'feed_share': 0
        }
    }
    self.assertEqual(
        metric_hook.vepfs_layout_fn(data),
        '20220927;gid;1854;{"feed_comment": 0.5, "click_comment": 0.2, "feed_share": 0.2};{"feed_comment": 0, "click_comment": 1, "feed_share": 0}'
    )

  def test_after_run(self):
    run_context = None
    head_names = ['feed_comment', 'click_comment', 'feed_share']
    predicts = [0.01, 0.1, 0.2, 0.5, 0.9, 0.99]
    labels = [0, 1]

    class RunValue(object):
      # Minimal stand-in exposing the `results` attribute the hook reads.

      def __init__(self, rv):
        self.results = {'deep_insight_op': [json.dumps(rv)]}

    for i in range(7, -1, -1):
      date = datetime.today() - timedelta(days=i)
      date_str = date.strftime('%Y%m%d')
      for _ in range(64):
        run_value = {
            'model_name': self.model_name,
            'req_time': date_str,
            'feedid': str(randint(1, 4096)),
            'userid': str(randint(1, 4096)),
            'predict': {name: choice(predicts) for name in head_names},
            'label': {name: choice(labels) for name in head_names},
        }
        self.hook.after_run(run_context, run_value=RunValue(run_value))
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
from typing import Dict, List

import tensorflow as tf

from monolith.native_training.metric import deep_insight_ops


def write_deep_insight(features: Dict[str, tf.Tensor],
                       sample_ratio: float,
                       model_name: str,
                       labels: tf.Tensor = None,
                       preds: tf.Tensor = None,
                       target: str = None,
                       targets: List[str] = None,
                       labels_list: List[tf.Tensor] = None,
                       preds_list: List[tf.Tensor] = None,
                       sample_rates_list: List[tf.Tensor] = None,
                       extra_fields_keys: List[str] = None,
                       enable_deep_insight_metrics=True,
                       enable_kafka_metrics=False,
                       dump_filename=None) -> tf.Tensor:
  """
  Writes the data into deepinsight
  Requires 'uid', 'req_time', and 'sample_rate' in features.
  sample_ratio is deepinsight sample ratio, set value like 0.01.

  If targets is non-empty, MonolithWriteDeepInsightV2 will be used, enabling:
  - Multi-target sent as one message;
  - Dump extra fields.
  When using MonolithWriteDeepInsightV2, labels/preds/sample_rates should be
  shape (num_targets, batch_size). sample_rates is optional.
  Extra fields specified in extra_fields_keys must be present in features,
  and must have batch_size numbers of values.

  Args:
    features: feature dict; 'req_time' is mandatory for anything to be written.
    sample_ratio: forwarded to the deep insight op as `sample_ratio`.
    model_name: forwarded to the deep insight op as `model_name`.
    labels, preds, target: single-target inputs, used only when `targets` is
      empty or None.
    targets, labels_list, preds_list, sample_rates_list: multi-target inputs,
      used only when `targets` is non-empty. When `sample_rates_list` is empty,
      features['sample_rate'] is repeated once per target.
    extra_fields_keys: names of additional features to dump (multi-target path
      only). 'uid' is always added. The argument is never modified.
    enable_deep_insight_metrics, enable_kafka_metrics, dump_filename:
      forwarded to / used to configure `deep_insight_ops.deep_insight_client`.

  Returns:
    The op returned by the underlying deep insight writer, or `tf.no_op()` when
    'req_time' is missing from `features`.

  Raises:
    Exception: if `sample_rates_list` is non-empty but not a list or tuple.
  """
  if 'req_time' not in features:
    logging.info("Disabling deep_insight because req_time is absent")
    return tf.no_op()

  # A "fake" client is requested whenever messages must be returned instead of
  # (only) sent: kafka metrics enabled or a non-empty dump file given.
  is_fake = enable_kafka_metrics or (dump_filename is not None and
                                     len(dump_filename) > 0)
  deep_insight_client = deep_insight_ops.deep_insight_client(
      enable_deep_insight_metrics, is_fake, dump_filename=dump_filename)
  req_times = tf.reshape(features["req_time"], [-1])

  if not targets:
    uids = tf.reshape(features["uid"], [-1])
    sample_rates = tf.reshape(features["sample_rate"], [-1])
    deep_insight_op = deep_insight_ops.write_deep_insight(
        deep_insight_client_tensor=deep_insight_client,
        uids=uids,
        req_times=req_times,
        labels=labels,
        preds=preds,
        sample_rates=sample_rates,
        model_name=model_name,
        target=target,
        sample_ratio=sample_ratio,
        return_msgs=is_fake)
  else:
    # Flatten every per-target tensor to rank 1 before stacking so the result
    # is (num_targets, batch_size).
    labels = tf.stack([
        label if label.shape.rank == 1 else tf.reshape(label, (-1,))
        for label in labels_list
        if label is not None
    ])
    preds = tf.stack([
        pred if pred.shape.rank == 1 else tf.reshape(pred, (-1,))
        for pred in preds_list
        if pred is not None
    ])
    if not sample_rates_list:
      sample_rates_list = [tf.reshape(features["sample_rate"], [-1])
                          ] * len(targets)
    elif isinstance(sample_rates_list, (tuple, list)):
      sample_rates_list = [
          sample_rate if sample_rate.shape.rank == 1 else tf.reshape(
              sample_rate, (-1,))
          for sample_rate in sample_rates_list
          if sample_rate is not None
      ]
    else:
      raise Exception("sample_rates_list error!")
    sample_rates = tf.stack(sample_rates_list)

    # Work on a private copy: appending to the argument would mutate the
    # caller's list (and, with a shared default list, leak "uid" and any other
    # state into every later call).
    extra_fields_keys = list(extra_fields_keys or [])
    if "uid" not in extra_fields_keys:
      extra_fields_keys.append("uid")
    extra_fields_values = []
    for key in extra_fields_keys:
      extra_fields_values.append(tf.reshape(features[key], [-1]))
    deep_insight_op = deep_insight_ops.write_deep_insight_v2(
        deep_insight_client_tensor=deep_insight_client,
        req_times=req_times,
        labels=labels,
        preds=preds,
        sample_rates=sample_rates,
        model_name=model_name,
        sample_ratio=sample_ratio,
        extra_fields_values=extra_fields_values,
        extra_fields_keys=extra_fields_keys,
        targets=targets,
        return_msgs=is_fake)
  return deep_insight_op
================================================ FILE: monolith/native_training/metric/utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock import tensorflow as tf from monolith.native_training.metric import utils class DeepInsightTest(tf.test.TestCase): @mock.patch( "monolith.native_training.metric.deep_insight_ops.write_deep_insight") def test_basic(self, deep_insight_op): def fake_call(uids, **kwargs): del kwargs with self.session() as sess: uids = sess.run(uids) self.assertAllEqual(uids, [1, 2, 3]) deep_insight_op.side_effect = fake_call features = { "uid": tf.constant([1, 2, 3], dtype=tf.int64), "req_time": tf.constant([1, 2, 3], dtype=tf.int64), "sample_rate": tf.constant([0.5, 0.5, 0.5], dtype=tf.float32), } labels = tf.constant([1.0, 0.0, 1.0], dtype=tf.float32) preds = tf.constant([0.9, 0.2, 0.8], dtype=tf.float32) model_name = "test_model" target = "target" utils.write_deep_insight(features, 0.01, labels, preds, model_name, target) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/mlp_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import sys import time import socket from typing import List from absl import logging, flags import signal from subprocess import Popen, PIPE import tensorflow as tf import tensorflow.python.data.experimental.service as dsvc from tensorflow_estimator.python.estimator.util import _DatasetInitializerHook from monolith.native_training.distribution_utils import get_mpi_rank, \ get_mpi_size, get_mpi_local_size, enable_sync_training, get_device_str from monolith.native_training.model_export.export_context import \ is_exporting, is_dry_run_or_exporting FLAGS = flags.FLAGS from monolith.native_training import yarn_runtime def check_port(host: str, port: int, timeout: float = 1) -> bool: is_ipv6 = ':' in host.strip('[]') skt = socket.socket(socket.AF_INET6 if is_ipv6 else socket.AF_INET, socket.SOCK_STREAM) start = time.time() skt.settimeout(timeout) while True: try: skt.connect((host, int(port))) return True except socket.timeout as e: return False except socket.error as e: now = time.time() remaining = timeout - int(now - start) if remaining > 0: skt.settimeout(remaining) continue else: return False class MLPEnv(object): def __init__(self): self._mlp_env = { k: v for k, v in os.environ.items() if k.startswith('MLP_') or k.startswith('MPI_') } self.framework = self._get('MLP_FRAMEWORK') self.ssh_port = self._get('MLP_SSH_PORT') self.log_path = self._get('MLP_LOG_PATH') self.debug_port = self._get('MLP_DEBUG_PORT') self.entrypoint_dir = self._get('MLP_ENTRYPOINT_DIR') self.task_cmd = self._get('MLP_TASK_CMD') self.role = self._get('MLP_ROLE', "").upper() 
self.all_roles = { k.split('_')[1]: int(self._get(k, 0)) for k in self._mlp_env if k.endswith('_NUM') and len(k.split('_')) == 3 } if self.enable_mpi: self.index = get_mpi_rank() self.all_roles['WORKER'] = get_mpi_size() self.port = int(self._get('MLP_PORT', 0)) + self.index logging.info( f'total process is {get_mpi_size()}, this is {get_mpi_rank()}, port is {self.port}' ) else: self.index = int(self._get('MLP_ROLE_INDEX', 0)) self.port = int(self._get('MLP_PORT', 0)) logging.info(f'enable_mpi is False, index {self.index}, port {self.port}') if len(self._mlp_env) > 0 and len(self.all_roles) > 0: self.avaiable = True else: self.avaiable = False self.cpu = int(self._get('MLP_CPU', 0)) self.gpu = int(self._get('MLP_GPU', 0)) self.gpu_type = self._get('MLP_GPU_TYPE', "") self.mem = int(self._get('MLP_MEM', 0)) self.host = yarn_runtime.get_local_host() #self._get('MLP_HOST') self._has_started_profiler = False @property def enable_mpi(self): return 'OMPI_COMM_WORLD_RANK' in os.environ and self.role == "WORKER" def _get(self, name: str, default=None): value = self._mlp_env.get(name) if value: return value.strip().strip('"').strip("'") else: return default def num_replicas(self, role: str = None): role = (role or self.role).upper() if self.enable_mpi and role == 'WORKER': return get_mpi_size() key = f'MLP_{role}_NUM' logging.info(f"{key}, mlp_env: {self._mlp_env}") return int(self._get(key, 0)) def get_all_host(self, role: str = None, is_primary: bool = True) -> List[str]: role = (role or self.role).upper() if is_primary: key = f'MLP_{role}_ALL_PRIMARY_HOSTS' else: key = f'MLP_{role}_ALL_HOSTS' return self._get(key) def get_all_addrs(self, role: str = None, is_primary: bool = True) -> List[str]: role = (role or self.role).upper() if is_primary: key = f'MLP_{role}_ALL_PRIMARY_ADDRS' else: key = f'MLP_{role}_ALL_ADDRS' addrs = self._get(key) if addrs: return addrs.split(',') else: return [] def get_host(self, role: str = None, index: int = None, is_primary: bool = True) -> 
str: role = (role or self.role).upper() if self.enable_mpi and role == 'WORKER': index = (self.index if index is None else index) // get_mpi_local_size() elif role == self.role: index = self.index if index is None else index else: index = 0 if index is None else index if is_primary: key = f'MLP_{role}_{index}_PRIMARY_HOST' else: key = f'MLP_{role}_{index}_HOST' return self._get(key) def get_addr(self, role: str = None, index: int = None, is_primary: bool = True) -> str: role = (role or self.role).upper() if role == self.role: index = self.index if index is None else index else: index = 0 if index is None else index host = self.get_host(role, index, is_primary) if self.enable_mpi and role == 'WORKER': key = f'MLP_{role}_0_PORT' port = self._get(key) if port is not None: port = str(int(port) + index) else: key = f'MLP_{role}_{index}_PORT' port = self._get(key) if host and port: return f'{host}:{port}' else: return None def get_port(self, role: str = None, index: int = None) -> int: role = (role or self.role).upper() if self.enable_mpi and role == 'WORKER': index = self.index if index is None else index key = f'MLP_{role}_0_PORT' return self._get(key, 2222) + index else: index = 0 if index is None else index key = f'MLP_{role}_{index}_PORT' return self._get(key, 2222) def dispatcher_target(self, role: str = None) -> str: addr = self.dispatcher_addr(role) if addr: return f'grpc://{addr}' else: return 'grpc://localhost:5050' def dispatcher_addr(self, role: str = None) -> str: role = (role or 'dispatcher').upper() return self.get_addr(role) def wait(self, role: str = None, index: int = 0, timeout: int = -1, use_ssh: bool = True): host = self.get_host(role, index, True) port = self.ssh_port if use_ssh else self.get_port(role, index) if host: current = 0 while True: if timeout > 0 and current >= timeout: logging.info(f'wait {host}:{port} timeout!') break if check_port(host, port): return else: time.sleep(5) current += 5 else: logging.info('host is None') def join(self, 
role: str = 'worker', index: int = 0, use_ssh: bool = True): self.wait(role, index, use_ssh=use_ssh) host = self.get_host(role, index, True) port = self.ssh_port if use_ssh else self.get_port(role, index) if host: while True: if not check_port(host, port, timeout=60): break #return else: time.sleep(10) else: logging.info('host is None') if self._has_started_profiler: try: tf.profiler.experimental.stop() except Exception as e: logging.info(f'experimental stop error: {e}') logging.info('profiler stopped!') logging.info(f'current role: {self.role}:{self.index} exit') os._exit(0) @property def queue_device(self) -> str: if 'PS' in self.all_roles: return "/job:ps/task:0/device:CPU:0" elif 'WORKER' in self.all_roles: return "/job:worker/task:0/device:CPU:0" else: return "/device:CPU:0" def start_profiler(self, port=6666): logging.info(f'start_profiler at {self.host}:{port}') if self.enable_mpi: port += self.index tf.profiler.experimental.server.start(port) self._has_started_profiler = True def profiler_trace(self, role: str = 'dsworker', index: int = -1, host_tracer_level=2, python_tracer_level=0, device_tracer_level=1, delay_ms=10000): logdir = self._get("TENSORBOARD_LOG_PATH", "/tensorboard_logs/") options = tf.profiler.experimental.ProfilerOptions( host_tracer_level=host_tracer_level, python_tracer_level=python_tracer_level, device_tracer_level=device_tracer_level, delay_ms=delay_ms) if index < 0: all_addrs = self.get_all_addrs(role) service_addr = ','.join(map(lambda addr: f'grpc://{addr}', all_addrs)) else: service_addr = f'grpc://{self.get_addr(role, index)}' tf.profiler.experimental.client.trace(service_addr=service_addr, logdir=logdir, duration_ms=delay_ms, options=options) def add_mpi_exception_hook(): if 'OMPI_COMM_WORLD_RANK' not in os.environ: return logging.info("add_mpi_exception_hook") # Global error handler def global_except_hook(exctype, value, traceback): try: sys.stderr.write( "\n*****************************************************\n") 
sys.stderr.write("Uncaught exception was detected on rank {}. \n".format( int(os.environ.get('OMPI_COMM_WORLD_RANK', -1)))) from traceback import print_exception print_exception(exctype, value, traceback) sys.stderr.write( "*****************************************************\n\n\n") sys.stderr.write("\n") sys.stderr.write("Calling MPI_Abort() to shut down MPI processes...\n") sys.stderr.flush() finally: try: import mpi4py.MPI mpi4py.MPI.COMM_WORLD.Abort(1) except Exception as e: sys.stderr.write( "*****************************************************\n") sys.stderr.write( "Sorry, we failed to stop MPI, this process will hang.\n") sys.stderr.write( "*****************************************************\n") sys.stderr.flush() raise e sys.excepthook = global_except_hook EXTRA_DSWORKERS = [] def mlp_pass(dispatcher_role: str = 'dispatcher', dsworker_role: str = 'dsworker', worker_role: str = 'worker', ps_role: str = 'ps'): dispatcher_role = None if dispatcher_role is None else dispatcher_role.upper() dsworker_role = None if dsworker_role is None else dsworker_role.upper() worker_role = None if worker_role is None else worker_role.upper() pa_role = None if ps_role is None else ps_role.upper() if FLAGS.dataset_use_dataservice: _DatasetInitializerHook.begin = begin _DatasetInitializerHook.after_create_session = after_create_session mlp_env = MLPEnv() if mlp_env.avaiable: logging.info('MLP is available') logging.info('mlp_env_host: %s, mlp_env_port: %s', mlp_env.host, mlp_env.port) if mlp_env.role == dispatcher_role: if dispatcher_role: dispatcher = dsvc.DispatchServer( dsvc.DispatcherConfig(port=mlp_env.port)) logging.info('Dispatcher started...') mlp_env.join() elif mlp_env.role == dsworker_role: if dsworker_role: logging.info('Waiting for dispatcher start...') assert dispatcher_role is not None mlp_env.wait(dispatcher_role, use_ssh=False) logging.info('Dispatcher started, dsworker begin to start...') dispatcher_address = mlp_env.dispatcher_addr(role=dispatcher_role) 
worker = dsvc.WorkerServer( dsvc.WorkerConfig(dispatcher_address=dispatcher_address, worker_address=f'{mlp_env.host}:{mlp_env.port}', port=mlp_env.port)) logging.info('Dsworker started....') mlp_env.start_profiler() mlp_env.join() elif mlp_env.role == worker_role: if FLAGS.dataset_use_dataservice: if dispatcher_role: logging.info("wait dispatcher start ...") mlp_env.wait(dispatcher_role, use_ssh=False) FLAGS.data_service_dispatcher = mlp_env.dispatcher_target() if dsworker_role: logging.info("dispatcher started, wait ds worker start ...") for idx in range(mlp_env.num_replicas(role=dsworker_role)): mlp_env.wait(dsworker_role, index=idx, use_ssh=False) logging.info(f'dsworker {idx} started! ') # Extra dsworkers on GPU worker dispatcher_address = mlp_env.dispatcher_addr(role=dispatcher_role) logging.info("dispatcher_address: %s", dispatcher_address) global EXTRA_DSWORKERS for i in range(FLAGS.num_extra_dsworker_on_gpu_worker): # base port number + gpu worker port offset + number generated by (gpu worker index, extra dswroker index) port = (mlp_env.port - mlp_env.index) + get_mpi_size( ) + mlp_env.index * FLAGS.num_extra_dsworker_on_gpu_worker + i logging.info('extra_port: %s', port) worker = dsvc.WorkerServer( dsvc.WorkerConfig(dispatcher_address=dispatcher_address, worker_address=f'{mlp_env.host}:{port}', port=port)) EXTRA_DSWORKERS.append(worker) logging.info('Start %s extra dsworkers on GPU worker %s', len(EXTRA_DSWORKERS), mlp_env.index) logging.info( f"worker {mlp_env.index} start at {mlp_env.host}:{mlp_env.port}") logging.info(f'{mlp_env.all_roles}') # if ps_role is None or ps_role not in mlp_env.all_roles: # mlp_env.start_profiler() def begin(self): self._initializer = self._iterator.initializer self._broadcast_dataset_id = None if not is_dry_run_or_exporting(): self._rank = -1 if enable_sync_training(): try: enable_bps = int(os.getenv("MONOLITH_WITH_BYTEPS", "0")) if enable_bps: import byteps.tensorflow as hvd else: import horovod.tensorflow as hvd dataset_ids 
= tf.compat.v1.get_collection(key='registed_dataset_id') if dataset_ids is not None and len(dataset_ids) > 0: dataset_id = dataset_ids[0] if dataset_id is not None: self._rank = hvd.rank() #with tf.device(None), tf.device(get_device_str(True)): self._broadcast_dataset_id = [ dataset_id, hvd.broadcast(tensor=dataset_id, root_rank=0, name="broadcast_dataset_id") ] graph.clear_collection(name='registed_dataset_id') except Exception as e: logging.info(f'import byteps/horovod error: {e}') def after_create_session(self, session, coord): del coord if self._broadcast_dataset_id is not None and not is_dry_run_or_exporting(): dataset_id, bc_dataset_id = session.run(self._broadcast_dataset_id) logging.info( f'dataset_id is {dataset_id}, bc_dataset_id is {bc_dataset_id}, rank {self._rank}' ) self._broadcast_dataset_id = None session.run(self._initializer) ================================================ FILE: monolith/native_training/model.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from datetime import datetime
import numpy as np
from typing import Callable, Dict, List

from absl import logging
import tensorflow as tf

from monolith.core import model_registry
from monolith.core.base_model_params import SingleTaskModelParams
from monolith.native_training import entry, feature
from monolith.native_training.input import (generate_ffm_example, slot_to_key)
import monolith.native_training.metric.deep_insight_ops as deep_insight_ops
from monolith.native_training.native_task import NativeTask

# Number of feature slots created by the model.
_NUM_SLOTS = 6
# (user slot, item slot, embedding dim) triples; one dot-product term each.
_FFM_SLOT = ((0, 3, 16), (0, 4, 16), (1, 5, 16), (2, 3, 16), (2, 5, 16))
# Passed to generate_ffm_example; its length also fixes the parsed slot keys.
_VOCAB_SIZES = [5, 5, 5, 5, 5, 5]
# Size of the synthetic dataset built in input_fn.
_NUM_EXAMPLES = 64


def _parse_example(example: str) -> Dict[str, tf.Tensor]:
  """Parses serialized examples into a feature dict.

  'label' is parsed as a scalar float32 and every slot key as a variable-length
  int64 feature; sparse results are converted to RaggedTensor.
  """

  def _get_feature_map():
    feature_map = {}
    feature_map["label"] = tf.io.FixedLenFeature([], dtype=tf.float32)
    for i in range(len(_VOCAB_SIZES)):
      feature_map[slot_to_key(i)] = tf.io.VarLenFeature(dtype=tf.int64)
    return feature_map

  features = tf.io.parse_example(example, _get_feature_map())
  for k, v in features.items():
    if isinstance(v, tf.sparse.SparseTensor):
      features[k] = tf.RaggedTensor.from_sparse(v)
  return features


class TestFFMModel(NativeTask):
  """Small FFM-style task trained on synthetic examples."""

  def __init__(self, params):
    super().__init__(params)
    self.p = params

  def create_input_fn(self, mode):
    """Returns an input_fn yielding parsed batches; `mode` is not used."""

    def input_fn():
      # This keeps the training data stability so we can resume training from the
      # checkpoint.
      np.random.seed(0)
      examples = [
          generate_ffm_example(_VOCAB_SIZES) for i in range(_NUM_EXAMPLES)
      ]
      dataset = tf.data.Dataset.from_tensor_slices(examples)
      dataset = dataset.batch(self.p.train.per_replica_batch_size,
                              drop_remainder=True)
      dataset = dataset.map(_parse_example,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
      dataset = dataset.cache().repeat()
      dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
      return dataset

    return input_fn

  def create_model_fn(self):
    """Returns the estimator model_fn: sigmoid(sum of slot dots + sum of biases)."""

    def model_fn(features: Dict, mode: tf.estimator.ModeKeys,
                 config: tf.estimator.RunConfig) -> tf.estimator.EstimatorSpec:
      del config
      global_step = tf.compat.v1.train.get_or_create_global_step()

      slots = {}
      fc = {}
      bias_list = []
      for i in range(_NUM_SLOTS):
        slots.update({
            i:
                self.ctx.create_feature_slot(
                    feature.FeatureSlotConfig(
                        name=str(i),
                        has_bias=i <= (_NUM_SLOTS // 2),
                        bias_optimizer=entry.FtrlOptimizer(
                            learning_rate=0.1,
                            initial_accumulator_value=1e-6,
                            beta=1.0),
                        default_vec_optimizer=entry.SgdOptimizer(
                            learning_rate=0.1)))
        })
        fc.update({i: feature.FeatureColumnV1(slots[i], slot_to_key(i))})
        # Same condition as has_bias above: only those slots contribute a bias.
        if i <= (_NUM_SLOTS // 2):
          bias_list.append(fc[i].embedding_lookup(slots[i].get_bias_slice()))

      bias_input = tf.concat(bias_list,
                             axis=1,
                             name='concatenate_tensor_from_{}_bias'.format(
                                 len(bias_list)))
      sum_bias = tf.reduce_sum(bias_input, axis=1)

      # One elementwise-product-and-sum term per (user, item, dim) triple.
      dot_res = []
      for user, item, dim in _FFM_SLOT:
        user_vec = fc[user].embedding_lookup(slots[user].add_feature_slice(dim))
        item_vec = fc[item].embedding_lookup(slots[item].add_feature_slice(dim))
        dot_res.append(tf.reduce_sum(tf.multiply(user_vec, item_vec), 1))
      ffm_out = tf.add_n(dot_res) + sum_bias
      pred = tf.nn.sigmoid(ffm_out)

      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=pred)

      loss = tf.reduce_sum(
          tf.losses.binary_crossentropy(features["label"], pred))

      # Write deep insight
      if self.p.metrics.enable_deep_insight and self.p.metrics.deep_insight_sample_ratio > 0:
        deep_insight_client = deep_insight_ops.deep_insight_client(False)
        now = datetime.now()
        model_name = self.p.metrics.deep_insight_name
        # uids, req_times and sample_rates are constant placeholders of length
        # per_replica_batch_size (0, graph-build time, and 0.1 respectively).
        uids = tf.cast(tf.fill([self.p.train.per_replica_batch_size], 0),
                       dtype=tf.int64)
        req_times = tf.cast(tf.fill([self.p.train.per_replica_batch_size],
                                    int(datetime.timestamp(now))),
                            dtype=tf.int64)
        sample_rates = tf.fill([self.p.train.per_replica_batch_size], 0.1)
        target = "ctr_head"
        deep_insight_op = deep_insight_ops.write_deep_insight(
            deep_insight_client_tensor=deep_insight_client,
            uids=uids,
            req_times=req_times,
            labels=features["label"],
            preds=pred,
            sample_rates=sample_rates,
            model_name=model_name,
            target=target,
            sample_ratio=0.01)
        logging.info("model_name: {}, target: {}.".format(model_name, target))
      else:
        deep_insight_op = tf.no_op()

      update_global_step = tf.compat.v1.assign_add(global_step, 1)
      all_embeddings = [v.get_all_embeddings_concat() for v in fc.values()]
      emb_grads = tf.gradients(loss, all_embeddings)
      # The train op is created under a dependency on the global-step increment.
      with tf.control_dependencies([update_global_step]):
        train_op = tf.group(
            self.ctx.apply_embedding_gradients(zip(emb_grads, all_embeddings)),
            deep_insight_op)

      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    return model_fn

  def create_serving_input_receiver_fn(self):
    """Returns a receiver fn taking serialized examples under key 'instances'."""

    def serving_input_receiver_fn():
      receiver_tensors = {}
      instances_placeholder = tf.compat.v1.placeholder(dtype=tf.string,
                                                       shape=(None,))
      receiver_tensors["instances"] = instances_placeholder
      parsed_results = _parse_example(instances_placeholder)
      return tf.estimator.export.ServingInputReceiver(parsed_results,
                                                      receiver_tensors)

    return serving_input_receiver_fn


@model_registry.RegisterSingleTaskModel
class FFMParams(SingleTaskModelParams):
  """Registered params for TestFFMModel with a per-replica batch size of 64."""

  def task(self):
    p = TestFFMModel.params()
    p.train.per_replica_batch_size = 64
    return p


================================================
FILE: monolith/native_training/model_comp_test.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

# These variables are set before the monolith / horovod modules below are
# imported.
os.environ["MONOLITH_WITH_HOROVOD"] = "True"
os.environ["HOROVOD_AUTOTUNE"] = "1"
os.environ["HOROVOD_CYCLE_TIME"] = "0.1"
os.environ["MONOLITH_SYNC_EMPTY_RANK0_PS_SHARD"] = "0"
os.environ["MONOLITH_WITH_ALLREDUCE_FUSION"] = "one"
os.environ['MONOLITH_ROOT_LOG_INTERVAL'] = "10"

import time
import tensorflow as tf

tf.compat.v1.set_random_seed(42)
import getpass
from tensorflow.python.framework import test_util
from monolith.native_training import cpu_training
from monolith.native_training.native_model import MonolithModel
from monolith.native_training.estimator import Estimator, EstimatorSpec
from monolith.native_training.data.training_instance.python.parser_utils import advanced_parse
from monolith.native_training.entry import (AdagradOptimizer, Fp16Compressor,
                                            ZerosInitializer)
from monolith.native_training import layers

import horovod.tensorflow as hvd

# Keyword arguments forwarded to lookup_embedding_slice in model_fn.
deep = {
    "initializer":
        ZerosInitializer(),
    "optimizer":
        AdagradOptimizer(learning_rate=0.05,
                         weight_decay_factor=0.0,
                         initial_accumulator_value=0.1,
                         warmup_steps=0),
    "compressor":
        Fp16Compressor()
}

num_features = 17
batch_size = 455
emb_dim = 15
feature_names = [f'feature{i}' for i in range(num_features)]
# Upper bound (exclusive) of generated ids; the TF reference table has
# fid_max_val + 1 rows.
fid_max_val = 100000


def lookup_tf_embedding(features, f_name, dim):
  """Plain-TF reference lookup: per-row sum of embeddings of a ragged feature.

  Rebuilds the ragged feature from the 'tf_<name>_p1' (values) and
  'tf_<name>_p2' (row splits) entries written by decomp_func, looks the ids up
  in a zero-initialized tf.Variable, and segment-sums by row.
  """
  f = tf.RaggedTensor.from_row_splits(features[f'tf_{f_name}_p1'],
                                      features[f'tf_{f_name}_p2'],
                                      validate=False)
  embeddings = tf.nn.embedding_lookup(
      params=tf.Variable(initial_value=tf.zeros(shape=(fid_max_val + 1, dim))),
      ids=f.values)
  return tf.math.segment_sum(embeddings, f.value_rowids())


class EmbeddingUpdateTask(MonolithModel):
  """A test task that will compare TF and monolith embedding update."""

  def __init__(self, params=None):
    super(EmbeddingUpdateTask, self).__init__(params)
    self.train.max_steps = 50
    self.train.per_replica_batch_size = batch_size

  def input_fn(self, mode):
    """Builds an endless dataset of random ragged features; `mode` is unused."""

    def decomp_func(features):
      # note: this is a workaround to pass fids to model_fn, since all instances of RaggedTensor are gone
      for i in range(num_features):
        features[f'tf_feature{i}_p1'] = features[f'feature{i}'].values
        features[f'tf_feature{i}_p2'] = features[f'feature{i}'].row_splits
      return advanced_parse(features)

    @tf.function
    def input_tensors():
      features = {}
      for i in range(num_features):
        # A random-length (1..24) vector of random int64 ids per feature.
        features[f'feature{i}'] = tf.random.uniform((tf.random.uniform(
            (), 1, 25, dtype=tf.int32),),
                                                    0,
                                                    fid_max_val,
                                                    dtype=tf.int64)
      features['label'] = tf.cast(tf.random.uniform((), 0, 2, dtype=tf.int32),
                                  tf.float32)
      return features

    return tf.data.experimental.Counter(0, 1).map(lambda _: input_tensors(), tf.data.AUTOTUNE).\
      apply(tf.data.experimental.dense_to_ragged_batch(batch_size, True)).\
      map(decomp_func, tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

  def model_fn(self, features, mode):
    """Builds two identical MLP heads, one on monolith and one on TF embeddings.

    An assert op checks that both embedding concatenations are equal before
    the heads are evaluated.
    """
    # NOTE(review): the extent of this device scope and of the
    # control_dependencies scope below was reconstructed from a
    # whitespace-collapsed source - confirm against the repository copy.
    with tf.device("/device:GPU:0"):
      for f_name in feature_names:
        self.create_embedding_feature_column(f_name, occurrence_threshold=1)

      tf_embeddings = [
          lookup_tf_embedding(features, f_name, emb_dim)
          for f_name in feature_names
      ]
      embeddings = self.lookup_embedding_slice(features=feature_names,
                                               slice_name='vec',
                                               slice_dim=emb_dim,
                                               **deep)
      # if mode == tf.estimator.ModeKeys.PREDICT:
      #   return tf.estimator.EstimatorSpec(mode, predictions=tf.constant(0))
      embed_concat = tf.concat(embeddings, axis=1)
      tf_embed_concat = tf.concat(tf_embeddings, axis=1)
      assert_op = tf.compat.v1.assert_equal(embed_concat, tf_embed_concat)
      with tf.compat.v1.control_dependencies([assert_op]):
        # pred = layers.MLP(
        #   output_dims=[64, 32, 1], activations='relu')(embed_concat)
        pred = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dense(1)
        ])(embed_concat)
        tf_pred = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dense(1)
        ])(tf_embed_concat)
      label = features['label']
      loss = tf.reduce_mean(tf.losses.mean_squared_error(pred, label))
      tf_loss = tf.reduce_mean(tf.losses.mean_squared_error(tf_pred, label))
      optimizer = tf.compat.v1.train.AdagradOptimizer(0.05)
      # with tf.device("/device:CPU:0"):
      #   loss = tf.compat.v1.Print(loss, [loss], 'monolith loss')
      #   tf_loss = tf.compat.v1.Print(tf_loss, [tf_loss], 'tf loss')

    return EstimatorSpec(
        loss=loss + tf_loss,
        pred=[pred, tf_pred],
        label=[label, label],
        head_name=['monolith', 'tf'],
        classification=[False, False],
        optimizer=optimizer
    )

  def serving_input_receiver_fn(self):
    # Not needed by this test.
    pass


class CpuSyncTrainTest(tf.test.TestCase):
  """Runs EmbeddingUpdateTask under sync training in several configurations."""

  def _create_config(self, gpu, multi_hash_table):
    """Returns a sync-training config with no PS and one worker per hvd rank.

    `gpu` sets enable_gpu_training; `multi_hash_table` sets
    use_native_multi_hash_table. model_dir is time-stamped per call.
    """
    return cpu_training.DistributedCpuTrainingConfig(
        # save_checkpoints_steps=10000,
        num_ps=0,
        num_workers=hvd.size(),
        model_dir=f'/tmp/{getpass.getuser()}/monolith_test/{int(time.time())}',
        reorder_fids_in_data_pipeline=True,
        embedding_prefetch_capacity=0,
        enable_sync_training=True,
        enable_gpu_training=gpu,
        enable_realtime_training=False,
        use_native_multi_hash_table=multi_hash_table,
        index=hvd.rank(),
    )

  def test_embedding_update(self):
    hvd.init()
    p = EmbeddingUpdateTask.params().instantiate()
    # CPU: without, then with, the native multi hash table.
    config = self._create_config(False, False)
    cpu_training.distributed_sync_train(config, p)
    config = self._create_config(False, True)
    cpu_training.distributed_sync_train(config, p)
    # GPU variants only when a CUDA device is available.
    if test_util.is_gpu_available(cuda_only=True):
      config = self._create_config(True, False)
      cpu_training.distributed_sync_train(config, p)
      config = self._create_config(True, True)
      cpu_training.distributed_sync_train(config, p)


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()


================================================
FILE: monolith/native_training/model_dump/BUILD
================================================
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")

package(default_visibility = ["//visibility:public"])

py_proto_library(
    name = "monolith_model_py_proto",
    srcs = ["monolith_model.proto"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//monolith/native_training/runtime/hash_table/compressor:float_compressor_py_proto",
        "//monolith/native_training/runtime/hash_table/initializer:initializer_config_py_proto",
        "//monolith/native_training/runtime/hash_table/optimizer:optimizer_py_proto",
    ],
)

py_library(
    name = "graph_utils",
    srcs = ["graph_utils.py"],
    deps = [
        "//idl:line_id_py_proto",
        "//monolith/native_training:utils",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

py_library(
    name = "dump_utils",
    srcs = ["dump_utils.py"],
    deps = [
        ":graph_utils",
        ":monolith_model_py_proto",
        "//monolith/native_training/data:feature_list",
        "//monolith/native_training/data:item_pool_hook",
        "//monolith/native_training/data:parsers_py",
    ],
)


================================================
FILE: monolith/native_training/model_dump/dump_utils.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. from absl import logging import tensorflow as tf import copy import pickle from io import BytesIO from inspect import signature, Parameter from typing import Union, Dict, List, Optional, Any, Set from tensorflow.python.data.ops import dataset_ops from tensorflow_estimator.python.estimator import util from tensorflow.core.framework import variable_pb2 from tensorflow.python.ops.variables import Variable from tensorflow.python.framework import ops SaveSliceInfo = Variable.SaveSliceInfo from monolith.native_training import entry from monolith.native_training.model_dump.monolith_model_pb2 import ProtoModel, ModelDump, \ HashTableConfig, FeatureSliceDim, Combiner, FeatureCombiner from monolith.native_training.model_dump.graph_utils import DatasetInitHook, GraphDefHelper, \ DRY_RUN, _node_name from monolith.native_training.data.utils import get_slot_feature_name from monolith.native_training.embedding_combiners import ReduceMean, ReduceSum, FirstN from monolith.native_training.runtime.hash_table import embedding_hash_table_pb2 from monolith.native_training.data.parsers import get_default_parser_ctx from idl.matrix.proto.example_pb2 import OutConfig from monolith.native_training.data.datasets import POOL_KEY from monolith.native_training.model_export.export_context import get_current_export_ctx from monolith.native_training.data.item_pool_hook import ItemPoolSaveRestoreHook from monolith.native_training.data.feature_list import get_feature_name_and_slot class DumpUtils(object): _instance = None def __new__(cls, *agrs, **kwds): if cls._instance is None: cls._instance = object.__new__(cls) return cls._instance def __init__(self, enable: bool = False): if not hasattr(self, "enable"): self.enable = enable self._params: Set[str] = set() self._run_config = None self._user_params = [] self.train, self.train_graph = None, None self.infer, self.infer_graph = None, None 
self._ps_sub_model, self._dense_sub_model = {}, {} self._table_configs: List[HashTableConfig] = [] self._feature_slice_dims: List[FeatureSliceDim] = [] self._feature_combiners: List[FeatureCombiner] = [] def add_config(self, run_config: str): self._run_config = run_config def add_user_params(self, user_params: List): self._user_params = user_params @property def model_dump(self) -> ProtoModel: graph = tf.compat.v1.get_default_graph() if hasattr(graph, 'monolith_model_dump'): monolith_model_dump = graph.monolith_model_dump return monolith_model_dump else: setattr(graph, 'monolith_model_dump', ProtoModel()) return graph.monolith_model_dump def update_kwargs_with_default(self, func, kwargs): params = signature(func).parameters for key in kwargs: if (kwargs[key] is not None) or (key not in params) or (params[key].default == Parameter.empty): continue kwargs[key] = params[key].default def record_feature(self, func): def wraper(*args, **kwargs): self.update_kwargs_with_default(func, kwargs) if self.need_record: proto = self.model_dump.features.add() if args: params = signature(func).parameters.values() for p, value in zip(params, args): if p.name == 'self': continue try: if p.name == 'feature_name' and isinstance(value, int): feature_name, _ = get_feature_name_and_slot(value) setattr(proto, p.name, feature_name) else: setattr(proto, p.name, value) except Exception as e: logging.warning(f"{p.name} is not in proto, {e}") for key, value in kwargs.items(): try: if key == 'feature_name' and isinstance(value, int): feature_name, _ = get_feature_name_and_slot(value) setattr(proto, key, feature_name) elif value is not None: setattr(proto, key, value) except Exception as e: logging.warning(f"{key} is not in proto, func {func}, {e}") return func(*args, **kwargs) return wraper def record_slice(self, func): def warper(*args, **kwargs): self.update_kwargs_with_default(func, kwargs) if self.need_record: if kwargs.get('learning_rate_fn', None) is not None: raise Exception('for safety 
purpose learning_rate_fn is not allowed') proto = self.model_dump.emb_slices.add() if args: params = signature(func).parameters.values() for p, value in zip(params, args): if p.name == 'self': continue try: if value is not None: if p.name == 'features': proto.features = repr(value) elif p.name in {'initializer', 'optimizer', 'compressor'}: getattr(proto, p.name).CopyFrom(value.as_proto()) else: setattr(proto, p.name, value) except Exception as e: logging.warning(f"{p.name} is not in proto, {e}") for key, value in kwargs.items(): try: if value is not None: if key == 'features': proto.features = repr(value) elif key in {'initializer', 'optimizer', 'compressor'}: getattr(proto, key).CopyFrom(value.as_proto()) else: setattr(proto, key, value) except Exception as e: logging.warning(f"{key} is not in proto, func {func}, {e}") results = func(*args, **kwargs) if isinstance(results, (list, tuple)): for res in results: proto.output_tensor_names.append(res.name) else: proto.output_tensor_names.append(results.name) return results else: return func(*args, **kwargs) return warper def record_receiver(self, func): def warper(*args, **kwargs): self.update_kwargs_with_default(func, kwargs) receiver = func(*args, **kwargs) if self.need_record: proto = self.model_dump.serving_input_receiver_fn proto.parser_type = get_default_parser_ctx().parser_type for k, ts in receiver.features.items(): if isinstance(ts, tf.RaggedTensor): proto.features[k] = repr({ "values": ts.values.name, "row_splits": ts.row_splits.name, "is_ragged": True, }) else: if len(ts.shape) > 0: last_dim = ts.shape[-1] if not isinstance(last_dim, int): if hasattr(last_dim, 'value'): last_dim = last_dim.value else: last_dim = 0 proto.features[k] = repr({ "name": ts.name, "is_ragged": False, "dtype": ts.dtype, "last_dim": last_dim }) for k, ts in receiver.receiver_tensors.items(): proto.receiver_name[k] = ts.name return receiver return warper def record_params(self, model): if not self.need_record: return skip_attrs = { 
'_abc_impl', '_ctx', '_export_outputs', '_global_step', '_losses', '_occurrence_threshold', '_private_children', 'children', 'ctx', 'fc_dict', 'fs_dict', 'losses', 'slice_dict', '_layout_dict', '_training_hooks' } for attr_name in dir(model): if attr_name.startswith('__') or attr_name in skip_attrs: continue attr = getattr(model, attr_name) if callable(attr): continue self._params.add(attr_name) def get_params_bytes(self, model) -> bytes: if not self.need_record: return params = {'_layout_dict'} | self._params model_params = {} for attr_name in params: attr = getattr(model, attr_name) if attr_name == 'p': params = copy.deepcopy(model.p) params.cls = None model_params['p'] = params elif attr_name == '_layout_dict': if attr: model_params[attr_name] = { name: out_cfg.SerializeToString() for name, out_cfg in attr.items() } else: model_params[attr_name] = attr else: model_params[attr_name] = attr f = BytesIO() pickle.dump(model_params, f) return f.getvalue() @classmethod def add_signature(cls, proto_model, graph: tf.Graph): export_ctx = get_current_export_ctx() if export_ctx: for signature in export_ctx.signatures(graph): signature_proto = proto_model.signature.add() signature_proto.name = signature.name for ip_key, value in signature.inputs.items(): signature_proto.inputs[ip_key] = value.name for op_key, value in signature.outputs.items(): signature_proto.outputs[op_key] = value.name @classmethod def restore_signature(cls, proto_model, graph: tf.Graph): export_ctx = get_current_export_ctx() if export_ctx: for signature in proto_model.signature: name = signature.name inputs = { ip_key: graph.get_tensor_by_name(value) for ip_key, value in signature.inputs.items() } outputs = { op_key: graph.get_tensor_by_name(value) for op_key, value in signature.outputs.items() } export_ctx.add_signature(graph, name, inputs, outputs) def add_model_fn(self, model, mode: str, features: Dict[str, tf.Tensor], label: Union[tf.Tensor, List[tf.Tensor], Dict[str, tf.Tensor]], loss: 
Optional[tf.Tensor], pred: Union[tf.Tensor, List[tf.Tensor], Dict[str, tf.Tensor]], head_name: Union[str, List[str]], is_classification: Union[bool, List[bool]]): if not self.need_record: return model_dump = self.model_dump model_dump.params = self.get_params_bytes(model) model_fn = model_dump.model_fn if label is not None: if isinstance(label, (tuple, list)): model_fn.label.extend(['' if t is None else t.name for t in label]) elif isinstance(label, dict): for value in label.values(): model_fn.label.append('' if value is None else value.name) else: model_fn.label.append(label.name) if loss is not None: model_fn.loss = loss.name if isinstance(pred, (tuple, list)): model_fn.predict.extend([t.name for t in pred if t is not None]) elif isinstance(pred, dict): for value in pred.values(): model_fn.predict.append(value.name) else: model_fn.predict.append(pred.name) if head_name: if isinstance(head_name, str): model_fn.head_name.append(head_name) else: assert isinstance(head_name, (list, tuple)) model_fn.head_name.extend(head_name) else: if label is not None and isinstance(label, dict): model_fn.head_name.extend(list(label.keys())) if isinstance(pred, dict): model_fn.head_name.extend(list(pred.keys())) if is_classification is not None: logging.info("dumped is_classification {}".format(is_classification)) if isinstance(is_classification, bool): model_fn.classification.append(is_classification) else: assert isinstance(is_classification, list) model_fn.classification.extend(is_classification) summaries = [x.op.name for x in ops.get_collection(ops.GraphKeys.SUMMARIES)] if len(summaries) > 0: logging.info("dumped user summaries {}".format(summaries)) model_fn.summary.extend(summaries) regged_features = {fc.feature_name for fc in self.model_dump.features} for name, ts in features.items(): if name not in regged_features and not isinstance(ts, tf.RaggedTensor): model_fn.non_ragged_features[name] = ts.name graph = tf.compat.v1.get_default_graph() extra_losses = 
model_fn.extra_losses for ts in getattr(graph, '__losses', []): extra_losses.append(ts.name) export_outputs = getattr(graph, '__export_outputs', {}) if export_outputs: for name, predict_output in export_outputs.items(): extra_output = self.model_dump.extra_output.add() extra_output.signature_name = name outputs = predict_output.outputs if isinstance(outputs, dict): for key, ts in outputs.items(): extra_output.fetch_dict[key] = ts.name else: extra_output.fetch_dict[ts.name] = ts.name training_hooks = getattr(graph, '__training_hooks', []) if training_hooks: if len(training_hooks) == 1 and isinstance(training_hooks[0], ItemPoolSaveRestoreHook): pass else: raise Exception('For safety purpose, customer hooks is not allowed!') self.add_signature(model_dump, graph) variables = tf.compat.v1.all_variables() if variables: for v in variables: save_slice_info = v._get_save_slice_info() if save_slice_info is not None: save_slice_info_bytes = save_slice_info.to_proto().SerializeToString() model_dump.save_slice_info[_node_name(v.name)] = save_slice_info_bytes if hasattr(graph, 'monolith_model_dump'): graph_def = graph.as_graph_def() if mode == tf.estimator.ModeKeys.TRAIN: self.train = graph.monolith_model_dump self.train_graph = copy.deepcopy(graph_def) else: self.infer = graph.monolith_model_dump self.infer_graph = copy.deepcopy(graph_def) def add_input_fn(self, results: Dict[str, Union[tf.Tensor, tf.RaggedTensor]]): if not self.need_record: return input_fn = self.model_dump.input_fn if isinstance(results, (list, tuple)): features, label = results[0], results[1] else: features, label = results, None assert isinstance(features, dict) for key, ts in features.items(): if isinstance(ts, tf.RaggedTensor): input_fn.output_features[key] = repr({ "name": ts.values.name.split(':')[0], "is_ragged": True, }) else: input_fn.output_features[key] = repr({ "name": ts.name, "is_ragged": False }) if label is not None: input_fn.label = label.name input_fn.parser_type = 
get_default_parser_ctx().parser_type pools = tf.compat.v1.get_collection(POOL_KEY) if pools: assert len(pools) == 1 input_fn.item_pool = pools[0].name def add_sub_model(self, sub_model_type: str, name: str, graph: tf.Graph): if not self.need_record: return assert sub_model_type in {'ps', 'dense'} if sub_model_type == 'ps' and name in self._ps_sub_model: return if sub_model_type == 'dense' and name in self._dense_sub_model: return logging.info(f'add_sub_model: {sub_model_type}-{name}') proto = ProtoModel() graph_def = graph.as_graph_def() proto.graph_def = graph_def.SerializeToString() self.add_signature(proto, graph) if sub_model_type == 'ps': self._ps_sub_model[name] = proto elif sub_model_type == 'dense': self._dense_sub_model[name] = proto else: raise Exception('sub_model error!') def restore_sub_model(self, sub_model_type: str): export_ctx = get_current_export_ctx() if export_ctx: if sub_model_type == 'ps': for name, sub_model in self._ps_sub_model.items(): with export_ctx.sub_graph(name).as_default() as g: sub_graph = tf.compat.v1.GraphDef() sub_graph.ParseFromString(sub_model.graph_def) tf.import_graph_def(sub_graph, name="") self.restore_signature(sub_model, g) logging.info(f'restore_sub_model: {sub_model_type}-{name}') elif sub_model_type == 'dense': for name, sub_model in self._dense_sub_model.items(): with export_ctx.dense_sub_graph(name).as_default() as g: sub_graph = tf.compat.v1.GraphDef() sub_graph.ParseFromString(sub_model.graph_def) tf.import_graph_def(sub_graph, name="") self.restore_signature(sub_model, g) logging.info(f'restore_sub_model: {sub_model_type}-{name}') else: raise Exception(f'sub_model_type: {sub_model_type} error ') def add_optimizer(self, optimizer): if not self.need_record: return f = BytesIO() pickle.dump(optimizer, f) value = f.getvalue() self.model_dump.optimizer = value def dump(self, fname: str): if not self.enable: return md = ModelDump() if self._run_config: md.run_config = self._run_config if self._user_params: 
logging.info("xxx dump to md with user_params {}".format( self._user_params)) md.user_params.extend(self._user_params) md.model_dump['train'].CopyFrom(self.train) md.model_dump['train'].graph_def = self.train_graph.SerializeToString() if self.infer is not None: md.model_dump['infer'].CopyFrom(self.infer) md.model_dump['infer'].graph_def = self.infer_graph.SerializeToString() for name, sub_model in self._ps_sub_model.items(): if sub_model is not None: print('dump ps_sub_model: ', name, flush=True) md.ps_sub_model_dump[name].CopyFrom(sub_model) for name, sub_model in self._dense_sub_model.items(): if sub_model is not None: print('dump dense_sub_model: ', name, flush=True) md.dense_sub_model_dump[name].CopyFrom(sub_model) for table_config in self._table_configs: md.table_configs.add().CopyFrom(table_config) for feature_slice_dim in self._feature_slice_dims: md.feature_slice_dims.add().CopyFrom(feature_slice_dim) for feature_combiner in self._feature_combiners: md.feature_combiners.add().CopyFrom(feature_combiner) with tf.io.gfile.GFile(fname, 'wb') as ostream: ostream.write(file_content=md.SerializeToString()) def load(self, fname: str): with tf.io.gfile.GFile(fname, 'rb') as ostream: md = ModelDump() md.ParseFromString(ostream.read()) self._run_config = md.run_config self.train = md.model_dump['train'] train_graph = tf.compat.v1.GraphDef() train_graph.ParseFromString(self.train.graph_def) self.train_graph = train_graph self.train.graph_def = b'' if 'infer' in md.model_dump: self.infer = md.model_dump['infer'] infer_graph = tf.compat.v1.GraphDef() infer_graph.ParseFromString(self.infer.graph_def) self.infer_graph = infer_graph self.infer.graph_def = b'' else: self.infer = None self.infer_graph = None self._table_configs.extend(md.table_configs) self._feature_slice_dims.extend(md.feature_slice_dims) self._feature_combiners.extend(md.feature_combiners) self._user_params.extend(md.user_params) logging.info("xxx load from md with user_params {}".format( 
self._user_params)) for name, sub_model in md.ps_sub_model_dump.items(): self._ps_sub_model[name] = sub_model for name, sub_model in md.dense_sub_model_dump.items(): self._dense_sub_model[name] = sub_model def get_proto_model(self, mode: str = tf.estimator.ModeKeys.TRAIN) -> ProtoModel: graph = tf.compat.v1.get_default_graph() if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL): return self.train else: return self.infer def get_graph_helper(self, mode: str) -> GraphDefHelper: graph = tf.compat.v1.get_default_graph() if hasattr(graph, 'graph_def_helper'): return graph.graph_def_helper else: save_slice_info = {} if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL): for name, info_bytes in self.train.save_slice_info.items(): save_slice_info_def = variable_pb2.SaveSliceInfoDef() save_slice_info_def.ParseFromString(info_bytes) save_slice_info[name] = SaveSliceInfo( save_slice_info_def=save_slice_info_def) helper = GraphDefHelper(self.train_graph, save_slice_info) else: for name, info_bytes in self.infer.save_slice_info.items(): save_slice_info_def = variable_pb2.SaveSliceInfoDef() save_slice_info_def.ParseFromString(info_bytes) save_slice_info[name] = SaveSliceInfo( save_slice_info_def=save_slice_info_def) helper = GraphDefHelper(self.infer_graph, save_slice_info) setattr(graph, 'graph_def_helper', helper) return helper def restore_params(self) -> Dict[str, Any]: params = self.train.params if params is None or len(params) == 0: return None f = BytesIO(params) model_params = pickle.load(f) layout_dict = model_params.get('_layout_dict') if layout_dict: layout_dict_tmp = {} for name, out_cfg_str in layout_dict.items(): out_cfg = OutConfig() out_cfg.ParseFromString(out_cfg_str) layout_dict_tmp[name] = out_cfg layout_dict.update(layout_dict_tmp) else: raise Exception('layout_dict is empty') if '_training_hooks' in model_params: del model_params['_training_hooks'] return model_params def get_config(self): return self._run_config def 
get_user_params(self): return self._user_params @property def need_record(self) -> bool: graph = tf.compat.v1.get_default_graph() return self.enable and not hasattr(graph, DRY_RUN) @property def table_configs(self) -> Dict[str, entry.HashTableConfigInstance]: result = {} for tcfg in self._table_configs: table_config = embedding_hash_table_pb2.EmbeddingHashTableConfig() table_config.ParseFromString(tcfg.table_config) result[tcfg.name] = entry.HashTableConfigInstance( table_config, list(tcfg.learning_rates), list(tcfg.extra_restore_names)) return result @table_configs.setter def table_configs(self, table_confs: Dict[str, entry.HashTableConfigInstance]): if table_confs: assert isinstance(table_confs, dict) self._table_configs.clear() for name, tcfg in table_confs.items(): hash_table_config = HashTableConfig( name=name, table_config=tcfg._table_config.SerializeToString()) if tcfg.extra_restore_names: hash_table_config.extra_restore_names.extend(tcfg.extra_restore_names) if tcfg.learning_rate_fns: if all(isinstance(lr, (float, int)) for lr in tcfg.learning_rate_fns): hash_table_config.learning_rates.extend(tcfg.learning_rate_fns) else: raise Exception('learning_rate_fn is not support!') self._table_configs.append(hash_table_config) @property def feature_slice_dims(self) -> Dict[str, List[int]]: slice_dims: Dict[str, List[int]] = {} for fsd in self._feature_slice_dims: slice_dims[fsd.name] = list(fsd.dims) return slice_dims @feature_slice_dims.setter def feature_slice_dims(self, slice_dims: Dict[str, List[int]]): if slice_dims: assert isinstance(slice_dims, dict) self._feature_slice_dims.clear() for name, dims in slice_dims.items(): fsd = FeatureSliceDim(name=name) if dims: fsd.dims.extend(dims) self._feature_slice_dims.append(fsd) @property def feature_combiners(self): fcombs = {} for fcomb in self._feature_combiners: if fcomb.combiner == Combiner.ReduceSum: fcombs[fcomb.name] = ReduceSum() elif fcomb.combiner == Combiner.ReduceMean: fcombs[fcomb.name] = ReduceMean() 
else: fcombs[fcomb.name] = FirstN(seq_length=fcomb.max_seq_length) return fcombs @feature_combiners.setter def feature_combiners(self, fcombs): if fcombs: assert isinstance(fcombs, dict) self._feature_combiners.clear() for name, fcomb in fcombs.items(): fc_proto = FeatureCombiner(name=name, max_seq_length=fcomb.max_seq_length) if isinstance(fcomb, ReduceSum): fc_proto.combiner = Combiner.ReduceSum elif isinstance(fcomb, ReduceMean): fc_proto.combiner = Combiner.ReduceMean else: fc_proto.combiner = Combiner.FirstN self._feature_combiners.append(fc_proto) def get_slot_to_occurrence_threshold(self, mode: str = 'train'): if mode == tf.estimator.ModeKeys.TRAIN: model_dump = self.train else: model_dump = self.infer slot_to_ot = {} for feature in model_dump.features: feature_name, slot = get_feature_name_and_slot(feature.feature_name) if slot is not None: slot_to_ot[slot] = feature.occurrence_threshold else: logging.warning( "feature[{}] slot is None. pls check feature_list.conf".format( feature_name)) return slot_to_ot def get_slot_to_expire_time(self, mode: str = 'train'): if mode == tf.estimator.ModeKeys.TRAIN: model_dump = self.train else: model_dump = self.infer slot_to_et = {} for feature in model_dump.features: feature_name, slot = get_feature_name_and_slot(feature.feature_name) if slot is not None: slot_to_et[slot] = feature.expire_time else: logging.warning( "feature[{}] slot is None. 
pls check feature_list.conf".format( feature_name)) return slot_to_et @property def has_collected(self) -> bool: if self._table_configs and self._feature_slice_dims and self._feature_combiners: return True else: assert self._table_configs is None or len(self._table_configs) == 0 assert self._feature_slice_dims is None or len( self._feature_slice_dims) == 0 assert self._feature_combiners is None or len( self._feature_combiners) == 0 return False def parse_input_fn_result(result): input_hooks = [] if isinstance(result, dataset_ops.DatasetV2): iterator = dataset_ops.make_initializable_iterator(result) input_hooks.append(util._DatasetInitializerHook(iterator)) result = iterator.get_next() else: initializer = tf.compat.v1.get_collection('mkiter') if isinstance(initializer, (list, tuple)) and len(initializer) > 0: initializer = initializer[0] assert initializer is not None input_hooks.append(DatasetInitHook(initializer)) DumpUtils().add_input_fn(result) return util.parse_iterator_result(result) + (input_hooks,) util.parse_input_fn_result = parse_input_fn_result ================================================ FILE: monolith/native_training/model_dump/graph_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os, re, time import six, copy, time from absl import logging, flags from inspect import signature import pickle from io import BytesIO from typing import Dict, Any, Optional, Union, List, Set import tensorflow as tf from tensorflow.keras import layers from tensorflow.keras import models from google.protobuf import text_format from tensorflow.keras import initializers from tensorflow.python.framework import ops from tensorflow.python.ops.variables import Variable SaveSliceInfo = Variable.SaveSliceInfo from idl.matrix.proto.line_id_pb2 import LineId from idl.matrix.proto.example_pb2 import FeatureConfigs from tensorflow.python.ops.variables import PartitionedVariable from tensorflow.python.ops.gen_resource_variable_ops import read_variable_op from tensorflow.python.ops.ragged.row_partition import RowPartition from tensorflow.python.framework import ops from monolith.native_training.utils import add_to_collections, get_collection DRY_RUN = 'dry_run' FLAGS = flags.FLAGS class DatasetInitHook(tf.compat.v1.train.SessionRunHook): def __init__(self, initializer): self._initializer = initializer def after_create_session(self, session, coord): del coord session.run(self._initializer) def _node_name(name): if name.startswith("^"): return name[1:] else: return name.split(":")[0] def _colocated_node_name(name): """Decodes colocated node name and returns it without loc:@ prepended.""" colocated_node_decoded = name.decode("utf-8") if colocated_node_decoded.startswith("loc:@"): return colocated_node_decoded[5:] return colocated_node_decoded class EchoInitializer(tf.keras.initializers.Initializer): def __init__(self, init_value): self._init_value = init_value def __call__(self, shape, dtype=None, **kwargs): if isinstance(self._init_value, (list, tuple)): assert len(self._init_value) == 1 init_value = self._init_value[0] else: init_value = self._init_value if isinstance(init_value, tf.Tensor): return init_value else: assert len(init_value.outputs) == 1 return 
init_value.outputs[0] class VariableDef(object): def __init__(self, node: tf.compat.v1.NodeDef = None, helper: 'GraphDefHelper' = None): self.node = node self._helper = helper self._name_to_node = helper.name_to_node self._read_ops: List[tf.compat.v1.NodeDef] = [] self._variable = None self._initializer = None @property def dtype(self): return tf.dtypes.DType(self.node.attr['dtype'].type) @property def shape(self): return tuple(dim.size for dim in self.node.attr['shape'].shape.dim) @property def device(self): return self.node.device @property def name(self): return _node_name(self.node.name) @property def initializer(self): if self._initializer is None: assign = _node_name(f'{self.node.name}/Assign') assert assign in self._name_to_node assign_node = self._name_to_node[assign] assert len(assign_node.input) == 2 initializer = None vname = _node_name(self.node.name) for name in assign_node.input: if _node_name(name) != vname: initializer = name break assert initializer is not None sub_graph, _ = self._helper.sub_graph(dest_nodes=[initializer], source_nodes=None, with_library=False) init_ops = tf.compat.v1.import_graph_def(sub_graph, return_elements=[initializer], name="") self._initializer = EchoInitializer(init_ops) return self._initializer @property def variable(self): if self._variable is None: vs = tf.compat.v1.get_variable_scope() partitioner = vs._partitioner vs._partitioner = None if self.device is not None and len(self.device) > 0: with tf.compat.v1.device(self.device): self._variable = tf.compat.v1.get_variable( dtype=self.dtype, shape=self.shape, initializer=self.initializer, name=self.node.name) else: self._variable = tf.compat.v1.get_variable(dtype=self.dtype, shape=self.shape, initializer=self.initializer, name=self.node.name) vs._partitioner = partitioner if isinstance(self._variable, PartitionedVariable): self._variable = self._variable._variable_list[0] return self._variable def add_read(self, node: tf.compat.v1.NodeDef): assert node.op == 
'ReadVariableOp' self._read_ops.append(node) @property def read_ops(self): return self._read_ops class PartitionVariableDef(object): PVName = re.compile(r'^(.*)/part_\d+$') def __init__(self, base_name: str, helper: 'GraphDefHelper' = None): self.base_name = base_name self._helper = helper self._name_to_node = helper.name_to_node self._partitions: Dict[str, tf.compat.v1.NodeDef] = {} self._read_ops: Dict[str, List[tf.compat.v1.NodeDef]] = {} self._variable = None self._initializer = None self._partitioned_variable = None self._save_slice_info = self._helper._save_slice_info def add_partition(self, node: tf.compat.v1.NodeDef): assert node.op == 'VarHandleOp' self._partitions[_node_name(node.name)] = node def add_read(self, node: tf.compat.v1.NodeDef): assert node.op == 'ReadVariableOp' name = _node_name(node.input[0]) if name in self._read_ops: self._read_ops[name].append(node) else: self._read_ops[name] = [node] @classmethod def get_base_name(cls, node: tf.compat.v1.NodeDef) -> Optional[str]: name, op = node.name, node.op if op == "VarHandleOp": matched = cls.PVName.match(name) if matched: return matched.group(1) elif op == "ReadVariableOp": inputs = [name for name in node.input if not name.startswith('^')] assert len(inputs) == 1 matched = cls.PVName.match(inputs[0]) if matched: return matched.group(1) return None @property def dtype(self): return [ tf.dtypes.DType(node.attr['dtype'].type) for node in self._partitions.values() ] @property def shape(self): return [ tuple(dim.size for dim in node.attr['shape'].shape.dim) for node in self._partitions.values() ] @property def device(self): return [node.device for node in self._partitions.values()] @property def initializer(self): if self._initializer is None: if len(self._partitions) > 1: dest_nodes = [ _node_name(f'{pname}/PartitionedInitializer/Slice') for pname in self._partitions ] else: node_name = _node_name(f'{next(iter(self._partitions))}') slice_node_name = f'{node_name}/PartitionedInitializer/Slice' if 
slice_node_name in self._name_to_node: dest_nodes = [slice_node_name] else: assign_node_name = f'{node_name}/Assign' assert assign_node_name in self._name_to_node assign_node = self._name_to_node[assign_node_name] assert len(assign_node.input) == 2 initializer = None for name in assign_node.input: if _node_name(name) != node_name: initializer = name break assert initializer is not None dest_nodes = [initializer] sub_graph, _ = self._helper.sub_graph(dest_nodes=dest_nodes, source_nodes=None, with_library=False) init_ops = tf.compat.v1.import_graph_def(sub_graph, return_elements=dest_nodes, name="") self._initializer = [EchoInitializer(init_op) for init_op in init_ops] return self._initializer @property def variable(self): if self._variable is None: self._variable = {} dtypes, shapes, inits = self.dtype, self.shape, self.initializer group_device = None vs = tf.compat.v1.get_variable_scope() partitioner = vs._partitioner vs._partitioner = None for i, (name, device) in enumerate(zip(self._partitions, self.device)): group_device = device if i == 0 else group_device if group_device is not None and len(group_device) > 0: logging.warning("variable[{}] use group_device[{}]".format(name, group_device)) with tf.compat.v1.device(None): with tf.compat.v1.device(group_device): variable = tf.compat.v1.get_variable(dtype=dtypes[i], shape=shapes[i], initializer=inits[i], name=name) if isinstance(variable, PartitionedVariable): variable = variable._variable_list[0] else: save_slice_info = self._save_slice_info.get(name) variable._set_save_slice_info(save_slice_info) else: variable = tf.compat.v1.get_variable(dtype=dtypes[i], shape=shapes[i], initializer=inits[i], name=name) if isinstance(variable, PartitionedVariable): variable = variable._variable_list[0] else: save_slice_info = self._save_slice_info.get(name) variable._set_save_slice_info(save_slice_info) self._variable[name] = variable vs._partitioner = partitioner # make PartitionedVariable, to check save_slice_info if 
len(self._variable) > 1: names = sorted(self._variable, key=lambda x: int(_node_name(x).rsplit('_')[-1])) name = names[0].rsplit('/', maxsplit=1)[0] partitions = [ len(shapes) if i == 0 else 1 for i, d in enumerate(shapes[0]) ] first_dim = sum(s[0] for s in shapes) if len(shapes[0]) > 1: shape = [first_dim] + list(shapes[0][1:]) else: shape = [first_dim] self._partitioned_variable = PartitionedVariable( name=name, shape=shape, dtype=dtypes[0], variable_list=[self._variable[name] for name in names], partitions=partitions) return self._variable def read_ops(self, pname: str) -> List[tf.compat.v1.NodeDef]: return self._read_ops[pname] class GraphDefHelper(object): def __init__(self, graph_def: tf.compat.v1.GraphDef, save_slice_info: Dict[str, SaveSliceInfo]): if not isinstance(graph_def, tf.compat.v1.GraphDef): raise TypeError("graph_def must be a graph_pb2.GraphDef proto.") self.graph_def = graph_def self.name_to_vardef: Dict[str, Union[VariableDef, PartitionVariableDef]] = {} self.name_to_input_name: Dict[str, Set[str]] = {} # Keyed by the dest node name. 
self.name_to_node: Dict[str, tf.compat.v1.NodeDef] = {} self.seq_num_to_node: Dict[int, tf.compat.v1.NodeDef] = {} self.name_to_seq_num: Dict[str, int] = {} self._save_slice_info = save_slice_info self._file_name = None seq = 0 for node in graph_def.node: node.device = b'' name = _node_name(node.name) self.name_to_input_name[name] = set(_node_name(x) for x in node.input) if "_class" in node.attr: for colocated_node_name in node.attr["_class"].list.s: self.name_to_input_name[name].add( _colocated_node_name(colocated_node_name)) del node.attr["_class"] self.name_to_node[name] = node self.name_to_seq_num[name] = seq self.seq_num_to_node[seq] = node seq += 1 if node.name == "PBDataset/file_name" and node.op == "Const": self._file_name = node stop_names = {'global_step', 'WorkerCkptMetaInfo'} if node.op == "VarHandleOp" and name not in stop_names: base_name = PartitionVariableDef.get_base_name(node) if base_name is not None: if base_name in self.name_to_vardef: self.name_to_vardef[base_name].add_partition(node) else: pvd = PartitionVariableDef(base_name, self) pvd.add_partition(node) self.name_to_vardef[base_name] = pvd else: if name in self.name_to_vardef: if self.name_to_vardef[name].node is None: self.name_to_vardef[name].node = node else: logging.info("maybe error, because node is not None") else: self.name_to_vardef[name] = VariableDef(node, self) if node.op == "ReadVariableOp": inputs = [name for name in node.input if not name.startswith('^')] assert len(inputs) == 1 vname = inputs[0] if vname in stop_names: continue base_name = PartitionVariableDef.get_base_name(node) if base_name is not None: if base_name in self.name_to_vardef: self.name_to_vardef[base_name].add_read(node) else: pvd = PartitionVariableDef(base_name, self) pvd.add_read(node) self.name_to_vardef[base_name] = pvd else: base_name = _node_name(node.input[0]) if base_name in self.name_to_vardef: self.name_to_vardef[base_name].add_read(node) else: dummy = VariableDef(None, self) dummy.add_read(node) 
def sub_graph(self,
              dest_nodes: List[str],
              source_nodes: Optional[List[str]] = None,
              with_library: bool = True):
  """Extracts the sub-graph needed to compute `dest_nodes`.

  Walks backwards from `dest_nodes` along `self.name_to_input_name` (which,
  as built in `__init__`, holds both regular inputs and colocation
  dependencies) and keeps every node reached. The walk does not enter the
  nodes listed in `source_nodes`.

  Args:
    dest_nodes: names of the nodes/tensors that must be computable from the
      returned graph. Must be a list, not a single string.
    source_nodes: optional names at which the traversal stops; these nodes
      and everything only reachable through them are left out.
    with_library: when True, also copy library functions whose signature
      name contains "Dataset" or that are referenced by a `func` attribute
      of a kept node.

  Returns:
    A `(sub_gd, variables)` tuple. `sub_gd` is a new GraphDef holding deep
    copies of the kept nodes in their original sequence order; `VarHandleOp`
    nodes are not copied, their names are returned in the `variables` set
    instead.

  Raises:
    TypeError: if `dest_nodes` is a string.
    AssertionError: if a destination or source node is not in the graph.
  """
  # Local import keeps this block self-contained.
  from collections import deque

  if isinstance(dest_nodes, six.string_types):
    raise TypeError("dest_nodes must be a list.")

  if source_nodes is not None:
    source_nodes = list({_node_name(sn) for sn in source_nodes})
    for sn in source_nodes:
      assert sn in self.name_to_node, f"{sn} is not in graph"

  dest_nodes = list({_node_name(dn) for dn in dest_nodes})
  for dn in dest_nodes:
    assert dn in self.name_to_node, f"{dn} is not in graph"

  # Breadth first search to find all the nodes that we should keep.
  # A deque gives O(1) pops from the front; popping index 0 of a list is
  # O(n) and made the traversal quadratic on large graphs. Visit order does
  # not affect the result because the kept set is re-sorted below.
  nodes_to_keep = set()
  stop_nodes = set() if source_nodes is None else set(source_nodes)
  next_to_visit = deque(dest_nodes)
  while next_to_visit:
    node = next_to_visit.popleft()
    if node in nodes_to_keep or node in stop_nodes:
      # Already visited/stop this node.
      continue
    nodes_to_keep.add(node)
    next_to_visit.extend(self.name_to_input_name.get(node, ()))

  # Emit nodes in the order they appeared in the original GraphDef.
  nodes_to_keep_list = sorted(nodes_to_keep,
                              key=lambda name: self.name_to_seq_num[name])

  # Now construct the output GraphDef
  sub_gd = tf.compat.v1.GraphDef()
  variables = set()
  for n in nodes_to_keep_list:
    node = self.name_to_node[n]
    if node.op == "VarHandleOp":
      # Variable handles are reported by name rather than copied.
      variables.add(n)
    elif node.op == "ReadVariableOp":
      # Drop the read op named '<handle>/Read/ReadVariableOp'; other reads
      # are kept.
      default_read_name = f'{node.input[0]}/Read/ReadVariableOp'
      if node.name != default_read_name:
        sub_gd.node.extend([copy.deepcopy(node)])
    else:
      sub_gd.node.extend([copy.deepcopy(node)])

  if with_library:
    # Collect the names of functions referenced through node attributes.
    # An unset `func` sub-message simply yields an empty name.
    func_names = set()
    for node in sub_gd.node:
      for value in node.attr.values():
        name = value.func.name
        if name:
          func_names.add(name)

    for func in self.graph_def.library.function:
      if "Dataset" in func.signature.name or func.signature.name in func_names:
        ofunc = sub_gd.library.function.add()
        ofunc.CopyFrom(func)
        # Clear device placement on the copied function bodies.
        for node_def in ofunc.node_def:
          node_def.device = b''

  # out.versions.CopyFrom(self.graph_def.versions)
  return sub_gd, variables
dry_run: bool = hasattr(graph, DRY_RUN) dest_nodes = [] for feat_name, ts_repr in input_conf.output_features.items(): ts_dict = eval(ts_repr) dest_nodes.append(ts_dict['name']) if input_conf.label is not None and len(input_conf.label) > 0: dest_nodes.append(input_conf.label) if not dry_run: if "IteratorToStringHandle" in self.name_to_node: dest_nodes.append("IteratorToStringHandle") if "MakeIterator" in self.name_to_node: dest_nodes.append("MakeIterator") del self._file_name.attr['value'].tensor.string_val[:] file_name_bytes = bytes(file_name, encoding='utf-8') self._file_name.attr['value'].tensor.string_val.append(file_name_bytes) sub_graph, _ = self.sub_graph(dest_nodes=dest_nodes, source_nodes=None, with_library=True) data_type = getattr(FLAGS, 'data_type') logging.info(f"[INFO] using the data type {data_type}") if data_type: # replace pb dataset input pb type for node in sub_graph.node: if node.name == "PBDataset/input_pb_type": val_attr = node.attr['value'] # ValAttr val_ts = val_attr.WhichOneof('value') # tensor if val_ts == 'tensor': target_type = data_type.lower().encode() val_attr.tensor.string_val[0] = target_type logging.info(f"[INFO] using input_pb_type {val_attr.tensor.string_val[0]}") logging.info(f"[INFO] the pbdataset/input {node}") return_elements = tf.import_graph_def(sub_graph, input_map=None, return_elements=dest_nodes, name="") if not dry_run: if "IteratorToStringHandle" in self.name_to_node: idx = dest_nodes.index("IteratorToStringHandle") tf.compat.v1.add_to_collection("iterators", return_elements[idx]) if "MakeIterator" in self.name_to_node: idx = dest_nodes.index("MakeIterator") tf.compat.v1.add_to_collection("mkiter", return_elements[idx]) result = {} for i, (feat_name, ts_repr) in enumerate(input_conf.output_features.items()): ts_dict = eval(ts_repr) if ts_dict['is_ragged']: row_splits, values = return_elements[i].outputs assert ts_dict['name'] == values.name.split(':')[0] row_partition = 
def import_model_fn(self, input_map: Dict[str, tf.Tensor], proto_model):
  """Imports the dumped model graph into the default graph.

  Collects every tensor/op the model needs to expose (predictions, extra
  outputs, loss, labels, extra losses, signature outputs, summaries),
  extracts the corresponding sub-graph, re-creates its variables and imports
  it with `tf.import_graph_def`.

  Args:
    input_map: maps tensor names of the dumped graph to tensors of the
      current graph; may be None or empty. Entries not used by the sub-graph
      are removed from it by `_check_invalidate_node`.
    proto_model: the dumped model description; `model_fn`, `extra_output`
      and `signature` are read from it.

  Returns:
    A tuple `(label, loss, predict, head_name, extra_output_dict,
    is_classification)`. Single-element `label`, `predict`, `head_name` and
    `is_classification` values are unwrapped from their list.
  """
  source_nodes = list(input_map.keys()) if input_map else None
  model_fn = proto_model.model_fn

  def _add_output_once(ts_name):
    # Skip names already present under an equivalent spelling
    # (node name or '<node>:0').
    node_name = _node_name(ts_name)
    full_name = ts_name if ':' in ts_name else f'{node_name}:0'
    if node_name not in outputs and full_name not in outputs:
      outputs.append(ts_name)

  def _output_index(ts_name):
    # `_add_output_once` may have skipped `ts_name` because an equivalent
    # spelling was already listed, so a plain `outputs.index(ts_name)` can
    # raise ValueError. Try the exact name first (unchanged behaviour), then
    # the same equivalents the de-duplication considered.
    node_name = _node_name(ts_name)
    full_name = ts_name if ':' in ts_name else f'{node_name}:0'
    for candidate in (ts_name, node_name, full_name):
      if candidate in outputs:
        return outputs.index(candidate)
    raise ValueError(f'{ts_name} is not in outputs')

  # The predictions must stay at the front: `predict` is sliced from the
  # head of `result` below.
  outputs = list(model_fn.predict)
  for extra_output in proto_model.extra_output:
    for ts_name in extra_output.fetch_dict.values():
      _add_output_once(ts_name)

  if model_fn.loss is not None and len(model_fn.loss) > 0:
    outputs.append(model_fn.loss)
  if model_fn.label is not None and len(model_fn.label) > 0:
    outputs.extend([l for l in model_fn.label if l])
  for extra_loss in model_fn.extra_losses:
    if extra_loss not in outputs:
      outputs.append(extra_loss)

  signature_input_names = []
  for signature in proto_model.signature:
    for ts_name in signature.inputs.values():
      if ts_name not in signature_input_names:
        signature_input_names.append(ts_name)
    for ts_name in signature.outputs.values():
      _add_output_once(ts_name)

  summaries = []
  if len(model_fn.summary) > 0:
    logging.info("load user summaries {}".format(model_fn.summary))
    outputs.extend(list(model_fn.summary))
    summaries = model_fn.summary

  sub_graph, variables = self.sub_graph(dest_nodes=outputs,
                                        source_nodes=source_nodes,
                                        with_library=True)
  self._check_invalidate_node(sub_graph, input_map)
  vread_map = self._create_variables(variables)
  if input_map is not None and len(input_map) > 0:
    vread_map.update(input_map)

  # check input_map for import_graph_def
  nodes = {_node_name(node.name) for node in sub_graph.node}
  graph = tf.compat.v1.get_default_graph()
  if vread_map:
    # Inputs consumed by the sub-graph but produced outside of it.
    ts_names = set()
    for node in sub_graph.node:
      for ip_ts_name in node.input:
        if _node_name(ip_ts_name) not in nodes:
          ts_names.add(ip_ts_name)

    # Keep only mappings the sub-graph consumes, keyed the way the
    # sub-graph spells the input.
    vread_map = {
        key if key in ts_names else _node_name(key): value
        for key, value in vread_map.items()
        if key in ts_names or _node_name(key) in ts_names
    }
    unknown_input = ts_names - set(vread_map)
    if unknown_input:
      logging.info(f"Debug. unknown_input {unknown_input}")
      # Fall back to tensors that already exist in the current graph.
      for ts_name in unknown_input:
        vread_map[ts_name] = graph.get_tensor_by_name(
            ts_name if ':' in ts_name else f'{ts_name}:0')
  else:
    vread_map = None

  # in case some output tensor not include in graph
  direct_out, real_out = {}, []
  for op_ts_name in outputs:
    if _node_name(op_ts_name) not in nodes:
      try:
        direct_out[op_ts_name] = graph.get_tensor_by_name(
            op_ts_name if ':' in op_ts_name else f'{op_ts_name}:0')
      except Exception:  # best effort; do not swallow KeyboardInterrupt
        logging.warning(f'Cannot find {op_ts_name} in both graph and inputs')
        direct_out[op_ts_name] = None
    else:
      real_out.append(op_ts_name)

  real_result = tf.import_graph_def(sub_graph,
                                    input_map=vread_map,
                                    return_elements=real_out,
                                    name="")

  # collection sparse_feature
  for node in tf.compat.v1.get_default_graph().as_graph_def().node:
    if node.op.startswith("ShardingSparseFids"):
      feature_cfgs = FeatureConfigs()
      feature_cfgs.ParseFromString(node.attr["feature_cfgs"].s)
      sparse_features = get_collection("sparse_features")
      if sparse_features is None:
        sparse_features = []
      else:
        sparse_features = sparse_features[-1]
      for feat_name, _ in feature_cfgs.feature_configs.items():
        sparse_features.append(feat_name)
      sparse_features = list(set(sparse_features))
      add_to_collections('sparse_features', sparse_features)

  real_result = {name: value for name, value in zip(real_out, real_result)}
  # One entry per requested output, imported elements taking precedence
  # over tensors found directly in the current graph.
  result = [real_result.get(name, direct_out.get(name)) for name in outputs]

  for summary in summaries:
    for sum_ts in real_result.get(summary).outputs:
      # logging.info("[INFO] add summary {} to collection".format(sum_ts))
      ops.add_to_collection(ops.GraphKeys.SUMMARIES, sum_ts)

  # check sig_input in graph
  for op_ts_name in signature_input_names:
    graph.get_tensor_by_name(
        op_ts_name if ':' in op_ts_name else f'{op_ts_name}:0')

  if model_fn.label is not None and len(model_fn.label) > 0:
    # Empty label names keep a None placeholder at their position.
    labels = [None] * len(model_fn.label)
    for i, label in enumerate(model_fn.label):
      if label:
        labels[i] = result[outputs.index(label)]
    label = labels if len(labels) > 1 else labels[0]
  else:
    label = None

  if model_fn.loss is not None and len(model_fn.loss) > 0:
    loss = result[outputs.index(model_fn.loss)]
  else:
    loss = None

  predict = result[:len(list(model_fn.predict))]
  if len(predict) == 1:
    predict = predict[0]

  if model_fn.head_name is not None and len(model_fn.head_name) > 0:
    if len(model_fn.head_name) == 1:
      head_name = model_fn.head_name[0]
    else:
      head_name = list(model_fn.head_name)
  else:
    head_name = None

  # Default so the return statement cannot hit an unbound local.
  is_classification = None
  if model_fn.classification is not None:
    logging.info("load is_classificaiton {}".format(model_fn.classification))
    if len(model_fn.classification) == 1:
      is_classification = model_fn.classification[0]
    else:
      is_classification = list(model_fn.classification)

  extra_output_dict = {}
  for extra_output in proto_model.extra_output:
    real_extra = {}
    for key, ts_name in extra_output.fetch_dict.items():
      idx = _output_index(ts_name)
      real_extra[key] = result[idx]
    # A single fetch keyed by its own tensor name is exposed as the bare
    # value instead of a one-entry dict. `key`/`idx` refer to that single
    # entry here.
    if len(extra_output.fetch_dict) == 1 and key == result[idx].name:
      extra_output_dict[extra_output.signature_name] = next(
          iter(real_extra.values()))
    else:
      extra_output_dict[extra_output.signature_name] = real_extra

  return label, loss, predict, head_name, extra_output_dict, is_classification
@classmethod
def get_optimizer(cls, proto_model):
  """Restores the optimizer stored in a dumped model.

  Args:
    proto_model: object whose `optimizer` attribute holds the pickled
      optimizer bytes (or is None/empty when no optimizer was dumped).

  Returns:
    The unpickled optimizer object, or None when `proto_model.optimizer` is
    None or empty.
  """
  ser_opt = proto_model.optimizer
  if ser_opt is None or len(ser_opt) == 0:
    return None
  # SECURITY NOTE(review): unpickling can execute arbitrary code. Only load
  # model dumps that come from a trusted source.
  return pickle.loads(ser_opt)
self.assertTrue(isinstance(result[fname], tf.RaggedTensor)) else: self.assertTrue(isinstance(result[fname], tf.Tensor)) def test_load_receiver(self): proto_model = self.dump_utils.get_proto_model(mode='infer') graph_helper = self.dump_utils.get_graph_helper(mode='infer') features, receiver_tensors = graph_helper.import_receiver_fn( receiver_conf=proto_model.serving_input_receiver_fn) for fname, ts_repr in proto_model.serving_input_receiver_fn.features.items( ): ts_dict = eval(ts_repr) if ts_dict['is_ragged']: self.assertTrue(isinstance(features[fname], tf.RaggedTensor)) else: self.assertTrue(isinstance(features[fname], tf.Tensor)) self.assertTrue(len(receiver_tensors) == 1) def test_load_mode(self): mode = tf.estimator.ModeKeys.TRAIN proto_model = self.dump_utils.get_proto_model(mode=mode) graph_helper = self.dump_utils.get_graph_helper(mode=mode) self.assertTrue(isinstance(graph_helper, GraphDefHelper)) graph = tf.compat.v1.get_default_graph() graph.dry_run = True proto_model = self.dump_utils.get_proto_model(mode=mode) graph_helper = self.dump_utils.get_graph_helper(mode=mode) self.assertTrue(isinstance(graph_helper, GraphDefHelper)) mode = tf.estimator.ModeKeys.PREDICT proto_model = self.dump_utils.get_proto_model(mode=mode) graph_helper = self.dump_utils.get_graph_helper(mode=mode) self.assertTrue(isinstance(graph_helper, GraphDefHelper)) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) tf.test.main() ================================================ FILE: monolith/native_training/model_dump/monolith_model.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // All fields here should be hidden from the public version. syntax = "proto3"; package monolith.native_training.model_dump; option cc_enable_arenas = true; import public "monolith/native_training/runtime/hash_table/compressor/float_compressor.proto"; import public "monolith/native_training/runtime/hash_table/initializer/initializer_config.proto"; import public "monolith/native_training/runtime/hash_table/optimizer/optimizer.proto"; message FeatureColumn { } message LookupEmbeddingSlice { } message InputFn { } message ModelFn { } message ServingInputReceiverFn { } message ExtraOutput { } message Signature { } message ProtoModel { } message HashTableConfig { } message FeatureSliceDim { } enum Combiner { // Enum must have at least 1 value. 
ReduceSum = 0; } message FeatureCombiner { } message ModelDump { } ================================================ FILE: monolith/native_training/model_export/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") load("@rules_proto//proto:defs.bzl", "proto_library") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") package(default_visibility = ["//visibility:public"]) proto_library( name = "export_proto", srcs = ["export.proto"], ) py_proto_library( name = "export_py_proto", srcs = ["export.proto"], ) py_library( name = "export_context", srcs = ["export_context.py"], deps = [ "//monolith:utils", "//monolith/native_training:monolith_export", "//monolith/native_training:utils", ], ) py_library( name = "saved_model_exporters", srcs = ["saved_model_exporters.py"], deps = [ ":data_gen_utils", ":export_context", "//monolith/native_training:device_utils", "//monolith/native_training:hash_table_ops", "//monolith/native_training:monolith_export", "//monolith/native_training:multi_hash_table_ops", "//monolith/native_training:multi_type_hash_table", "//monolith/native_training/model_dump:dump_utils", ], ) py_test( name = "saved_model_exporters_test", srcs = ["saved_model_exporters_test.py"], deps = [ ":saved_model_exporters", "//monolith/native_training:test_utils", ], ) py_library( name = "export_state_utils", srcs = ["export_state_utils.py"], deps = [ ":export_py_proto", ], ) py_test( name = "export_state_utils_test", srcs = ["export_state_utils_test.py"], deps = [ ":export_state_utils", ], ) py_library( name = "export_hooks", srcs = ["export_hooks.py"], deps = [ ":export_py_proto", ":export_state_utils", ":saved_model_exporters", "//monolith/native_training:save_utils", ], ) py_test( name = "export_hooks_test", srcs = ["export_hooks_test.py"], deps = [ ":export_hooks", ":saved_model_exporters", "//monolith/native_training:save_utils", ], ) py_binary( name = "saved_model_visulizer", 
srcs = ["saved_model_visulizer.py"], deps = [ "//monolith/native_training:distribution_ops", "//monolith/native_training:hash_table_ops", ], ) py_binary( name = "warmup_data_gen", srcs = ["warmup_data_gen.py"], deps = [ ":data_gen_utils", "//monolith/native_training:cpu_training", "@org_tensorflow_serving//tensorflow_serving/apis:predict_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_log_proto_py_pb2", ], ) py_binary( name = "warmup_data_decoder", srcs = ["warmup_data_decoder.py"], deps = [ ":data_gen_utils", "//monolith/native_training:cpu_training", "@org_tensorflow_serving//tensorflow_serving/apis:predict_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_log_proto_py_pb2", ], ) py_binary( name = "warmup_example_batch", srcs = ["warmup_example_batch.py"], deps = [ ":data_gen_utils", "//monolith/native_training:cpu_training", "@org_tensorflow_serving//tensorflow_serving/apis:predict_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_log_proto_py_pb2", ], ) py_binary( name = "demo_export", srcs = ["demo_export.py"], deps = [ ":saved_model_exporters", "//monolith/native_training:cpu_training", "//monolith/native_training:model", "//monolith/native_training/data/training_instance:parse_instance_ops_py", ], ) py_test( name = "demo_export_test", srcs = ["demo_export_test.py"], deps = [ ":demo_export", ], ) py_binary( name = "demo_predictor", srcs = ["demo_predictor.py"], deps = [ "//idl:proto_parser_py_proto", "//monolith/native_training:distribution_ops", "//monolith/native_training:hash_filter_ops", "//monolith/native_training:hash_table_ops", "//monolith/native_training:logging_ops", "//monolith/native_training:model", "//monolith/native_training/data/training_instance:parse_instance_ops_py", ], ) py_binary( name = "demo_predictor_client", srcs = ["demo_predictor_client.py"], deps = [ ":demo_predictor", "@org_tensorflow//tensorflow:tensorflow_py", 
"@org_tensorflow_serving//tensorflow_serving/apis:predict_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto_py_pb2", ], ) py_library( name = "model_export", srcs = ["__init__.py"], srcs_version = "PY3", deps = [ ":export_context", ":saved_model_exporters", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_library( name = "data_gen_utils", srcs = ["data_gen_utils.py"], srcs_version = "PY3", deps = [ ":export_context", "//idl:example_py_proto", "//idl:line_id_py_proto", "//monolith/native_training:env_utils", "//monolith/native_training:utils", "//monolith/native_training/data:feature_list", "@org_tensorflow//tensorflow:tensorflow_py", "@org_tensorflow_serving//tensorflow_serving/apis:predict_proto_py_pb2", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_log_proto_py_pb2", ], ) py_test( name = "data_gen_utils_test", srcs = ["data_gen_utils_test.py"], data = ["//monolith/native_training/data/test_data:test_feature_lists"], srcs_version = "PY3", deps = [ ":data_gen_utils", "//monolith/native_training/data:datasets_py", "//monolith/native_training/data:feature_utils_py", "//monolith/native_training/data:parsers_py", ], ) py_library( name = "export_utils", srcs = ["export_utils.py"], deps = [ ":export_context", "//monolith/native_training:distributed_serving_ops", "//monolith/native_training:nested_tensors", ], ) py_test( name = "export_utils_test", srcs = ["export_utils_test.py"], deps = [ ":export_utils", ], ) ================================================ FILE: monolith/native_training/model_export/__init__.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys as _sys import monolith.native_training.model_export.export_context as export_context import monolith.native_training.model_export.saved_model_exporters as saved_model_exporters _sys.modules['monolith.model_export.export_context'] = export_context _sys.modules[ 'monolith.model_export.saved_model_exporters'] = saved_model_exporters del _sys ================================================ FILE: monolith/native_training/model_export/data_gen_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class FeatureMeta(object):
  """Description of a single feature used by the random data generators.

  Attributes:
    name: feature name.
    slot: slot id, or None for features that do not belong to a slot.
    shape: number of values to generate. When not given it defaults to 1 for
      features without a slot and to -1 for features with a slot.
    dtype: TensorFlow dtype of the generated values (see `__init__` for how it
      is inferred when omitted).
    extra: opaque payload attached by the caller; stored unchanged.
  """

  # LineId field name -> protobuf field descriptor.
  line_id_fields = {f.name: f for f in LineId.DESCRIPTOR.fields}

  # protobuf C++ field type -> TensorFlow dtype. Doubles are mapped to
  # float32 and unsigned integers to their signed counterparts.
  dtypes = {
      FieldDescriptor.CPPTYPE_FLOAT: tf.float32,
      FieldDescriptor.CPPTYPE_DOUBLE: tf.float32,
      FieldDescriptor.CPPTYPE_UINT32: tf.int32,
      FieldDescriptor.CPPTYPE_INT32: tf.int32,
      FieldDescriptor.CPPTYPE_UINT64: tf.int64,
      FieldDescriptor.CPPTYPE_INT64: tf.int64,
      FieldDescriptor.CPPTYPE_BOOL: tf.bool,
      FieldDescriptor.CPPTYPE_STRING: tf.string
  }

  def __init__(self,
               name: str,
               slot: int = None,
               shape: int = None,
               dtype: tf.compat.v1.dtypes.DType = None,
               extra=None):
    """Stores the fields and fills in `shape` / `dtype` when omitted.

    dtype resolution order: the explicit `dtype` argument; otherwise the type
    of the LineId field called `name` (if such a field exists and its C++ type
    is listed in `dtypes`); otherwise float32 when `slot` is None and int64
    when a slot is given.
    """
    self.name = name
    self.slot = slot

    if shape is not None:
      self.shape = shape
    elif slot is None:
      self.shape = 1
    else:
      self.shape = -1

    resolved = dtype
    if resolved is None:
      line_id_field = self.line_id_fields.get(name)
      if line_id_field is not None:
        # Yields None again when the C++ type has no entry in the table.
        resolved = self.dtypes.get(line_id_field.cpp_type)
    if resolved is None:
      resolved = tf.float32 if slot is None else tf.int64
    self.dtype = resolved

    self.extra = extra
@singledispatch
def fill_features(feature, meta=None, drop_rate: float = 0):
  """Fallback of the `fill_features` single-dispatch family.

  Concrete implementations are registered per feature proto type with
  `@fill_features.register(...)` and are called as
  `fill_features(feature, meta, drop_rate)`. This fallback runs when the type
  of `feature` has no registered implementation.

  The fallback must accept the same arguments as the registered
  implementations: `singledispatch` forwards every call argument to it, so a
  zero-parameter fallback fails with `TypeError` before it can raise the
  intended `NotImplementedError`.

  Args:
    feature: the object to fill; its type selects the implementation.
    meta: feature description (unused here).
    drop_rate: drop probability (unused here).

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError(
      f"Not implemented fill_features for {type(feature).__name__}")
@fill_features.register(IFeature)
def _(feature: IFeature, meta: FeatureMeta, drop_rate: float = 0):
  """Fills an Instance-style feature proto with random content.

  The feature name is always copied from `meta`. A `meta.shape` of -1 marks a
  sparse feature, which receives v2 fids for `meta.slot`; otherwise
  `meta.shape` random values are appended to the float or int64 value list
  depending on `meta.dtype`. Any other dtype only logs a warning.
  """
  feat_name, width, kind = meta.name, meta.shape, meta.dtype
  feature.name = feat_name

  if width == -1:  # sparse
    multi_valued = any(tag in feat_name for tag in ('_recent', '_cp', '-'))
    if not multi_valued:
      # Plain sparse feature: exactly one fid.
      feature.fid.extend(gen_fids_v2(meta.slot, 1))
    elif drop_rate > 0 and uniform(0, 1) > drop_rate:
      # NOTE(review): with drop_rate == 0 this branch never runs, so such
      # features stay empty -- confirm that this is intended.
      feature.fid.extend(gen_fids_v2(meta.slot, randint(0, 2)))
    return

  if kind in {tf.float64, tf.float32}:
    floats = [uniform(0, 1) for _ in range(width)]
    feature.float_value.extend(floats)
  elif kind == tf.int64:
    ints = [randint(sys.maxsize // 2, sys.maxsize) for _ in range(width)]
    feature.int64_value.extend(ints)
  else:
    logging.warning(f'{feat_name} is empty')
def lg_header(source: str):
  """Builds the 8-byte record header derived from a source name.

  For a non-empty `source` the header is: one zero byte, the three low-order
  bytes (little-endian) of the 32-bit Java-style string hash
  `h = 31 * h + ord(c)`, then a 4-byte zero integer. For an empty or missing
  `source` the header is eight zero bytes.

  The hash is computed with plain Python integers masked to 32 bits, so the
  result does not depend on NumPy's scalar promotion and overflow behaviour
  (which differs between NumPy versions and may emit overflow warnings).

  Args:
    source: source name; may be None or ''.

  Returns:
    An 8-byte `bytes` object.
  """
  if not source:
    return int.to_bytes(0, 8, byteorder='little')

  # calc java hash code, modulo 2**32
  seed, h = 31, 0
  for c in source:
    h = (seed * h + ord(c)) & 0xFFFFFFFF

  dfhc = h.to_bytes(4, 'little')
  # Only the three low-order hash bytes are stored; the top byte is dropped.
  return pack('4Bi', 0, dfhc[0], dfhc[1], dfhc[2], 0)
def gen_instance(fidv1_features: List[int] = None,
                 fidv2_features: List[str] = None,
                 dense_features: List[FeatureMeta] = None,
                 extra_features: List[FeatureMeta] = None,
                 feature_list: FeatureList = None,
                 drop_rate: float = 0,
                 actions: List[int] = None) -> Instance:
  """Generates one random `Instance` proto.

  Args:
    fidv1_features: slots for which v1 fids are appended to `inst.fid`. Slots
      1 and 200 get exactly one fid, every other slot gets 0-3 fids.
    fidv2_features: names of sparse features to look up in `feature_list`;
      each one found becomes a named feature filled via `fill_features`.
    dense_features: metas of dense features to fill; a meta named 'label' is
      not emitted as a feature but determines the number of labels.
    extra_features: metas describing the LineId fields to populate.
    feature_list: lookup table for `fidv2_features`.
    drop_rate: forwarded to `fill_features`.
    actions: candidate action values forwarded to `fill_line_id`.

  Returns:
    The populated `Instance`.
  """
  inst = Instance()
  if fidv1_features is not None:
    assert len(fidv1_features) > 0 and len(fidv1_features) == len(
        set(fidv1_features))
    for slot in fidv1_features:
      size = 1 if slot in {1, 200} else randint(0, 3)
      fids_v1 = gen_fids_v1(slot, size)
      if fids_v1:
        inst.fid.extend(fids_v1)

  name_to_info = {}
  if fidv2_features:
    assert len(fidv2_features) > 0 and len(fidv2_features) == len(
        set(fidv2_features))
    notfound_names = list()
    for name in fidv2_features:
      try:
        feat = feature_list.get(name)
      except Exception:
        # Covers a missing feature_list as well as a failing lookup.
        feat = None
      if feat is None:
        # Report features that resolve to None too, instead of dropping them
        # silently.
        notfound_names.append(name)
        continue
      name_to_info[name] = FeatureMeta(name,
                                       slot=feat.slot,
                                       dtype=tf.int64,
                                       extra=feat)
    # This function runs once per generated record; only warn when something
    # is actually missing.
    if notfound_names:
      logging.warning(
          f'Total {len(notfound_names)} features not found in feature_list: {notfound_names}'
      )

  if dense_features:
    name_to_info.update({meta.name: meta for meta in dense_features})

  # 'label' is emitted through inst.label, not as a regular feature.
  label_meta = name_to_info.pop('label', None)
  for name, meta in name_to_info.items():
    feature = inst.feature.add()
    fill_features(feature, meta, drop_rate)

  fill_line_id(inst.line_id, extra_features, actions=actions)
  if label_meta:
    inst.label.extend([choice([0, 1]) for _ in range(label_meta.shape)])
  else:
    inst.label.append(choice([0, 1]))

  return inst
def gen_prediction_log(
    args: ParserArgs) -> Iterable[prediction_log_pb2.PredictionLog]:
  """Yields random `PredictionLog` protos built from `args`.

  `args.max_records // args.batch_size` requests are generated, cycling
  through `args.signature_name`. When an export context is active, its
  signatures are added to that list (note: `args.signature_name` is mutated
  in place) and are used to decide whether a request carries inputs.

  Args:
    args: generation settings; `variant_type` must be one of 'example',
      'instance', 'example_batch' or 'examplebatch', and `batch_size` must be
      smaller than `max_records`.

  Yields:
    One `PredictionLog` per request. For the default serving signature, when
    it has several outputs (or a single output that is not the default single
    output name), one additional log per output head with `output_filter` set.
  """
  assert args.variant_type in {
      'example', 'instance', 'example_batch', 'examplebatch'
  } and args.batch_size < args.max_records

  # Name of the request input tensor for each variant.
  if args.variant_type == 'example':
    input_name = 'examples'
  elif args.variant_type == 'instance':
    input_name = 'instances'
  else:
    input_name = 'example_batch'

  dense_feature_meta = []
  if args.dense_features:
    for name, shape, dtype in zip(args.dense_features,
                                  args.dense_feature_shapes,
                                  args.dense_feature_types):
      try:
        assert shape >= 1
        try:
          feat = args.feature_list.get(name)
          if feat is not None:
            dense_feature_meta.append(
                FeatureMeta(name, shape=shape, dtype=dtype, extra=feat))
          else:
            dense_feature_meta.append(
                FeatureMeta(name, shape=shape, dtype=dtype))
        except:
          # Lookup failed (e.g. no feature_list): keep the feature without
          # the extra payload.
          dense_feature_meta.append(FeatureMeta(name, shape=shape, dtype=dtype))
      except:
        # Reached when the shape check above fails; the feature is skipped.
        logging.warning(f'cannot find name {name} in feature_list')
  else:
    dense_feature_meta = None

  extra_meta = None
  if args.extra_features:
    extra_meta = [
        FeatureMeta(name=name, shape=shape)
        for name, shape in zip(args.extra_features, args.extra_feature_shapes)
    ]

  if args.signature_name is None:
    args.signature_name = [DEFAULT_SERVING_SIGNATURE_DEF_KEY]

  num_log = args.max_records // args.batch_size
  # This check runs before export-context signatures are appended below.
  assert num_log >= len(args.signature_name)

  export_ctx = export_context.get_current_export_ctx()
  graph = tf.compat.v1.get_default_graph()
  if export_ctx is None:
    signatures = None
  else:
    signatures = {
        signature.name: signature for signature in export_ctx.signatures(graph)
    }
    for name in signatures:
      if name not in args.signature_name:
        args.signature_name.append(name)

  for i in range(num_log):
    request = predict_pb2.PredictRequest()
    request.model_spec.name = args.model_name
    # Round-robin over the signature names.
    signature_name = args.signature_name[i % len(args.signature_name)]
    request.model_spec.signature_name = signature_name

    # Without an export context inputs are always generated; otherwise only
    # for signatures that declare inputs.
    if signatures is None or signatures[signature_name].inputs:
      if signatures is not None:
        assert input_name in signatures[signature_name].inputs
      if args.variant_type == 'example':
        instances = [
            gen_example(args.sparse_features, dense_feature_meta, extra_meta,
                        args.feature_list, args.drop_rate).SerializeToString()
            for _ in range(args.batch_size)
        ]
      elif args.variant_type == 'instance':
        instances = [
            gen_instance(args.fidv1_features, args.fidv2_features,
                         dense_feature_meta, extra_meta, args.feature_list,
                         args.drop_rate).SerializeToString()
            for _ in range(args.batch_size)
        ]
      else:
        # An ExampleBatch already holds batch_size rows, so a single
        # serialized proto is sent.
        instances = [
            gen_example_batch(args.sparse_features, dense_feature_meta,
                              extra_meta, args.feature_list, args.batch_size,
                              args.drop_rate).SerializeToString()
        ]
      request.inputs[input_name].CopyFrom(tf.make_tensor_proto(instances))

    log = prediction_log_pb2.PredictionLog(
        predict_log=prediction_log_pb2.PredictLog(request=request))
    yield log

    if signatures:
      outputs = signatures[signature_name].outputs
      if signature_name == DEFAULT_SERVING_SIGNATURE_DEF_KEY and outputs is not None:
        if len(outputs) > 1 or (len(outputs) == 1 and
                                PredictOutput._SINGLE_OUTPUT_DEFAULT_NAME
                                not in outputs):
          # Re-emit the same request once per head, filtered to that head.
          for head_name in outputs:
            request.output_filter.append(head_name)
            log = prediction_log_pb2.PredictionLog(
                predict_log=prediction_log_pb2.PredictLog(request=request))
            yield log
            # Reset the filter so the next head is requested on its own.
            del request.output_filter[:]
def gen_random_data_file(data_file_name: str,
                         args: ParserArgs,
                         num_batch: int = 128,
                         source: str = None,
                         sort_id: bool = True,
                         kafka_dump: bool = False,
                         kafka_dump_prefix: bool = False,
                         actions: List[int] = None):
  """Writes randomly generated records to a binary data file.

  Every record is written as: header, 8-byte little-endian payload length,
  serialized proto. The header comes from
  `sort_header(sort_id, kafka_dump, kafka_dump_prefix)` when `sort_id` is
  true, otherwise from `lg_header(source)`.

  For the 'example' and 'instance' variants each of the `num_batch` rounds
  produces `args.batch_size` records; for any other variant each round
  produces one serialized `ExampleBatch`.

  Args:
    data_file_name: output path, opened with `open(..., 'wb')`.
    args: generation settings (feature names, shapes, variant type, ...).
    num_batch: number of generation rounds.
    source: source name used for the header when `sort_id` is false.
    sort_id: selects the header flavour, see above.
    kafka_dump: forwarded to `sort_header`.
    kafka_dump_prefix: forwarded to `sort_header`.
    actions: candidate action values forwarded to the generators.
  """
  if args.dense_features:
    dense_feature_meta = []
    for name, shape, dtype in zip(args.dense_features,
                                  args.dense_feature_shapes,
                                  args.dense_feature_types):
      # Validate the shape explicitly: an `assert` disappears under
      # `python -O`, and the failure is about the shape, not about
      # feature_list.
      try:
        shape_ok = shape >= 1
      except TypeError:
        shape_ok = False
      if not shape_ok:
        logging.warning(
            f'skip dense feature {name}: invalid shape {shape}, expect >= 1')
        continue

      try:
        feat = args.feature_list.get(name)
      except Exception:
        # No feature_list or failing lookup: generate without extra info.
        feat = None

      if feat is not None:
        dense_feature_meta.append(
            FeatureMeta(name, shape=shape, dtype=dtype, extra=feat))
      else:
        dense_feature_meta.append(FeatureMeta(name, shape=shape, dtype=dtype))
  else:
    dense_feature_meta = None

  extra_meta = None
  if args.extra_features:
    extra_meta = [
        FeatureMeta(name=name, shape=shape)
        for name, shape in zip(args.extra_features, args.extra_feature_shapes)
    ]

  instances = []
  for _ in range(num_batch):
    if args.variant_type == 'example':
      instances.extend([
          gen_example(args.sparse_features,
                      dense_feature_meta,
                      extra_meta,
                      args.feature_list,
                      args.drop_rate,
                      actions=actions).SerializeToString()
          for _ in range(args.batch_size)
      ])
    elif args.variant_type == 'instance':
      instances.extend([
          gen_instance(args.fidv1_features,
                       args.fidv2_features,
                       dense_feature_meta,
                       extra_meta,
                       args.feature_list,
                       args.drop_rate,
                       actions=actions).SerializeToString()
          for _ in range(args.batch_size)
      ])
    else:
      # One ExampleBatch already contains batch_size rows.
      instances.extend([
          gen_example_batch(args.sparse_features,
                            dense_feature_meta,
                            extra_meta,
                            args.feature_list,
                            args.batch_size,
                            args.drop_rate,
                            actions=actions).SerializeToString()
      ])

  if sort_id:
    header = sort_header(sort_id, kafka_dump, kafka_dump_prefix)
  else:
    header = lg_header(source)

  with open(data_file_name, 'wb') as ostream:
    for inst in instances:
      ostream.write(header)
      ostream.write(int.to_bytes(len(inst), 8, byteorder='little'))
      ostream.write(inst)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import getpass from absl import app from absl import flags import tensorflow as tf from monolith.native_training.data.training_instance.python.parse_instance_ops import parse_instances from monolith.native_training import model from monolith.native_training import cpu_training from monolith.native_training.model import TestFFMModel from monolith.native_training.model_export.export_context import ExportMode, enter_export_mode from monolith.native_training.model_export.saved_model_exporters import StandaloneExporter, DistributedExporter FLAGS = flags.FLAGS flags.DEFINE_integer( "num_ps", default=5, help=("Number of parameter servers. 
def export_saved_model(model_dir, export_base, num_ps, export_mode):
  """Exports the demo FFM model from a training checkpoint as a saved model.

  Args:
    model_dir: directory containing the training checkpoints.
    export_base: base directory the saved model is written under.
    num_ps: number of parameter servers used to build the training task.
    export_mode: `ExportMode.STANDALONE` or `ExportMode.DISTRIBUTED`.

  Raises:
    ValueError: if `export_mode` is neither of the two supported modes
      (previously this surfaced as an UnboundLocalError on `exporter`).
  """
  tf.compat.v1.disable_eager_execution()
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  params = TestFFMModel.params()
  params.name = "demo_export"
  params.train.per_replica_batch_size = 64
  task = params.instantiate()
  cpu_training_task = cpu_training.CpuTraining(
      cpu_training.CpuTrainingConfig(num_ps=num_ps), task)

  if export_mode == ExportMode.STANDALONE:
    exporter = StandaloneExporter(cpu_training_task.create_model_fn(),
                                  model_dir=model_dir,
                                  export_dir_base=export_base)
  elif export_mode == ExportMode.DISTRIBUTED:
    exporter = DistributedExporter(cpu_training_task.create_model_fn(),
                                   model_dir=model_dir,
                                   export_dir_base=export_base,
                                   shared_embedding=False)
  else:
    raise ValueError("Unsupported export_mode: {}".format(export_mode))

  def serving_input_receiver_fn():
    """Parses serialized instances into one feature tensor per slot."""
    receiver_tensors = {}
    features = {}
    instances_placeholder = tf.compat.v1.placeholder(dtype=tf.string,
                                                     shape=(None,))
    receiver_tensors["instances"] = instances_placeholder
    parsed_results = parse_instances(
        instances_placeholder,
        fidv1_features=list(range(model._NUM_SLOTS)),
        fidv2_features=None,
        misc_float_features=[],
        misc_int64_features=[])
    # The parser emits "slot_<i>"; the model is fed "feature_<i>".
    for i in range(model._NUM_SLOTS):
      features["feature_{}".format(i)] = parsed_results["slot_{}".format(i)]
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

  exporter.export_saved_model(serving_input_receiver_fn)
class DemoExportTest(tf.test.TestCase):
  """Trains the test FFM model locally, then exports it in both modes."""

  def test_demo_export(self):
    tmp_root = os.environ["TEST_TMPDIR"]
    num_ps = 5
    ckpt_dir = os.path.join(tmp_root, "test_ffm_model")

    params = TestFFMModel.params()
    params.name = "test_ffm_model"
    params.train.per_replica_batch_size = 64
    cpu_training.local_train(params, num_ps=num_ps, model_dir=ckpt_dir)

    # Export the same checkpoint standalone first, then distributed.
    export_plan = (
        ("standalone_saved_model", ExportMode.STANDALONE),
        ("distributed_saved_model", ExportMode.DISTRIBUTED),
    )
    for subdir, mode in export_plan:
      demo_export.export_saved_model(ckpt_dir, os.path.join(tmp_root, subdir),
                                     num_ps, mode)
def make_fid_v1(slot_id, fid):
  """Packs a slot id and a feature hash into a v1 fid.

  The slot id occupies the bits above bit 54 and the hash the low 54 bits.
  The hash is masked to 54 bits so that an oversized or negative value cannot
  spill into the slot bits. For hashes already within 54 bits the result is
  the same as without the mask.

  Args:
    slot_id: slot number.
    fid: feature hash; an int or an integer NumPy array (handled
      element-wise).

  Returns:
    `(slot_id << 54) | (fid & (2**54 - 1))`.
  """
  return (slot_id << 54) | (fid & ((1 << 54) - 1))
def predict():
  """Loads the saved model and runs one prediction on random inputs.

  The model at `FLAGS.saved_model_path` is loaded with tag `FLAGS.tag_set`
  into a fresh graph. For every input of signature `FLAGS.signature` a random
  feed is generated according to its dtype (string -> serialized demo
  instances, int64 -> random ints, float32 -> random floats); unknown dims
  (-1) are replaced by `FLAGS.batch_size`. All outputs are fetched and logged.

  Raises:
    ValueError: if an input has a dtype other than string, int64 or float32.
  """
  graph = tf.compat.v1.Graph()
  with tf.compat.v1.Session(graph=graph) as sess:
    meta_graph = tf.compat.v1.saved_model.load(sess, {FLAGS.tag_set},
                                               FLAGS.saved_model_path)
    signature = meta_graph.signature_def[FLAGS.signature]

    feeds = {}
    for key, info in signature.inputs.items():
      dims = []
      for dim in info.tensor_shape.dim:
        dims.append(FLAGS.batch_size if dim.size == -1 else dim.size)
      logging.info("Generate {} of shape {}".format(key, dims))

      if info.dtype == tf.dtypes.string.as_datatype_enum:
        # Serialized instances are fed as a 1-D batch of strings.
        assert len(dims) == 1
        value = random_generate_instances(dims[0])
      elif info.dtype == tf.dtypes.int64.as_datatype_enum:
        value = random_generate_int(dims)
      elif info.dtype == tf.dtypes.float32.as_datatype_enum:
        value = random_generate_float(dims)
      else:
        raise ValueError("{} has invalid setting {}.".format(key, info))
      feeds[info.name] = value

    # output key -> tensor name to fetch
    fetches = {}
    for out_key, out_info in signature.outputs.items():
      fetches[out_key] = out_info.name
    logging.info(sess.run(fetches, feed_dict=feeds))
def get_signature_def(stub):
  """Fetches the SignatureDef named `FLAGS.signature_name` from the server.

  Sends a `GetModelMetadata` request for model `FLAGS.model_name`, unpacks the
  returned "signature_def" metadata and returns the requested entry.

  The signature names are collected and the requested name is checked before
  the map is indexed: indexing a protobuf message map with a missing key
  inserts an empty entry, which would otherwise hide a wrong signature name
  and make it show up in the list of "available" signatures.

  Args:
    stub: a `PredictionServiceStub`.

  Returns:
    The `SignatureDef` for `FLAGS.signature_name`.

  Raises:
    ValueError: if the model does not export that signature.
  """
  request = get_model_metadata_pb2.GetModelMetadataRequest()
  request.model_spec.name = FLAGS.model_name
  request.metadata_field.append("signature_def")
  result = stub.GetModelMetadata(request)

  any_proto = result.metadata["signature_def"]
  signature_def_map = get_model_metadata_pb2.SignatureDefMap()
  assert any_proto.Is(signature_def_map.DESCRIPTOR)
  any_proto.Unpack(signature_def_map)

  available = [x for x in signature_def_map.signature_def]
  logging.info("Available signatures: {}".format(available))
  if FLAGS.signature_name not in available:
    raise ValueError("Signature {!r} not found in model {!r}; available: {}".format(
        FLAGS.signature_name, FLAGS.model_name, available))
  return signature_def_map.signature_def[FLAGS.signature_name]
input_infos.items(): shape = [ FLAGS.batch_size if dim.size == -1 else dim.size for dim in tensor_info.tensor_shape.dim ] logging.info("Generate {} of shape {}".format(input_name, shape)) if tensor_info.dtype == tf.dtypes.string.as_datatype_enum: assert len(shape) == 1 if FLAGS.use_example: examples = demo_predictor.random_generate_examples(shape[0]) else: examples = demo_predictor.random_generate_instances(shape[0]) request.inputs[input_name].CopyFrom(tf.make_tensor_proto(examples)) elif tensor_info.dtype == tf.dtypes.int64.as_datatype_enum: request.inputs[input_name].CopyFrom( tf.make_tensor_proto(demo_predictor.random_generate_int(shape))) elif tensor_info.dtype == tf.dtypes.float32.as_datatype_enum: request.inputs[input_name].CopyFrom( tf.make_tensor_proto(demo_predictor.random_generate_float(shape), dtype=tf.float32)) else: raise ValueError("{} has invalid setting {}.".format( input_name, tensor_info)) result = stub.Predict(request, 30) logging.info(result) if __name__ == "__main__": logging.set_verbosity(logging.INFO) app.run(main) ================================================ FILE: monolith/native_training/model_export/export.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax="proto2"; package monolith.model_export; message ServingEntry { optional string export_dir = 1; // Records the global step for exported model. 
@monolith_export
class ExportContext:
  """Holds the state collected while a model is being exported.

  It owns two families of lazily created graphs (sub graphs and dense sub
  graphs, both looked up by name) and the signatures registered for each
  graph. Signatures are stored per graph, keyed by `id(graph)`.
  """

  def __init__(self, with_remote_gpu=False):
    # A graph is created the first time a name is requested.
    self._sub_graphs = defaultdict(tf.Graph)
    self._dense_sub_graphs = defaultdict(tf.Graph)
    # id(graph) -> {signature name -> SavedModelSignature}
    self._signatures = defaultdict(dict)
    self._with_remote_gpu = with_remote_gpu

  @property
  def with_remote_gpu(self):
    return self._with_remote_gpu

  @property
  def sub_graphs(self):
    return self._sub_graphs

  @property
  def dense_sub_graphs(self):
    return self._dense_sub_graphs

  @property
  def sub_graph_num(self):
    """Number of sub graphs currently held by this export context."""
    return len(self._sub_graphs)

  def sub_graph(self, name: str) -> tf.Graph:
    """Returns the sub graph called `name`, creating it on first access."""
    return self._sub_graphs[name]

  def dense_sub_graph(self, name: str) -> tf.Graph:
    """Returns the dense sub graph called `name`, creating it on first access."""
    return self._dense_sub_graphs[name]

  def signatures(self, graph: tf.Graph) -> List[SavedModelSignature]:
    """Returns the signatures registered for `graph` (a dict values view)."""
    per_graph = self._signatures[id(graph)]
    return per_graph.values()

  def add_signature(self, graph: tf.Graph, name: str, inputs, outputs):
    """Registers (or replaces) the signature `name` of `graph`.

    The name is also recorded in the 'signature_name' collection.
    """
    add_to_collections('signature_name', name)
    per_graph = self._signatures[id(graph)]
    per_graph[name] = SavedModelSignature(name=name,
                                          inputs=inputs,
                                          outputs=outputs)

  def merge_signature(self, graph: tf.Graph, name: str, inputs, outputs):
    """Merges `inputs`/`outputs` into the signature `name` of `graph`.

    An empty signature is created first when `name` is not registered yet;
    existing keys are overwritten by the new values.
    """
    per_graph = self._signatures[id(graph)]
    if name not in per_graph:
      per_graph[name] = SavedModelSignature(name=name, inputs={}, outputs={})
    target = per_graph[name]
    target.inputs.update(inputs)
    target.outputs.update(outputs)
def get_global_step(checkpoint_path: str):
  """Extracts the global step from a checkpoint prefix.

  Args:
    checkpoint_path: a path ending with `model.ckpt-<step>`, for example
      `/tmp/model_dir/model.ckpt-123`. Surrounding whitespace is ignored.

  Returns:
    The step encoded in the path, as an int.

  Raises:
    AssertionError: if the path does not end with `model.ckpt-<digits>`.
  """
  # The dot is escaped so that only a literal "model.ckpt-" is accepted.
  matched = re.match(r'^.*model\.ckpt-(\d+)$', checkpoint_path.strip())
  if matched is None:
    # Raised explicitly instead of using `assert`, so the validation is not
    # stripped under `python -O`; the exception type is unchanged.
    raise AssertionError(
        "Not a valid checkpoint path: {!r}".format(checkpoint_path))
  return int(matched.group(1))
def _maybe_delete_old_entries(self, export_dir: bytes):
  """Deletes exported models whose checkpoint step is no longer kept.

  The state file in the parent directory of `export_dir` is read, every
  entry whose global step is neither an existing checkpoint step nor an
  exempt step is removed from disk, and the state file is rewritten with the
  surviving entries.

  Args:
    export_dir: path (as bytes) of an export directory; only its parent
      directory is used.
  """
  export_dir = export_dir.decode()
  export_dir_base = os.path.dirname(export_dir)
  old_state = export_state_utils.get_export_saver_listener_state(
      export_dir_base)

  existing_steps = (self._helper.get_existing_checkpoint_steps() |
                    self._exempt_checkpoint_steps)
  if self._dense_only:
    # Also keep the steps listed in the checkpoint state found two levels
    # above the export base directory.
    path = Path(export_dir_base)
    model_dir = str(path.parent.parent)
    full_stats = tf.train.get_checkpoint_state(model_dir)
    if full_stats:
      existing_steps |= {
          get_global_step(ckpt)
          for ckpt in full_stats.all_model_checkpoint_paths
      }

  new_state = export_pb2.ServingModelState()
  for entry in old_state.entries:
    if entry.global_step in existing_steps:
      new_state.entries.append(entry)
      continue
    try:
      tf.io.gfile.rmtree(entry.export_dir)
      # Logged only after the deletion succeeded, so a failed rmtree is not
      # reported as a deletion.
      logging.info("Deleted export dir: %s.", entry.export_dir)
    except tf.errors.NotFoundError:
      logging.warning(
          "Hit NotFoundError when deleting '%s', possibly because another "
          "process/thread is also deleting/moving the same file",
          entry.export_dir)

  export_state_utils.overwrite_export_saver_listener_state(
      export_dir_base, new_state)
class ExportHookTest(tf.test.TestCase):
  """Tests ExportSaverListener attached to a checkpoint saver hook.

  The exporter is a MagicMock whose `export_saved_model` returns directories
  created by the test itself, so no real SavedModel is written.
  """

  def testBasic(self):
    """One save at step 10 records exactly one entry in the export state."""
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "testBasic_model_dir")
    export_dir_base = os.path.join(os.environ["TEST_TMPDIR"],
                                   "testBasic_export_dir")
    exporter = mock.MagicMock()
    export_dir = os.path.join(export_dir_base, "12345678")
    os.makedirs(export_dir)

    def export_saved_model(serving_input_receiver_fn, checkpoint_file,
                           global_step):
      # The listener must pass the checkpoint prefix of the saved step.
      self.assertEqual(checkpoint_file, model_dir + "/model.ckpt-10")
      return export_dir.encode()

    exporter.export_saved_model.side_effect = export_saved_model
    saver_hook = save_utils.NoFirstSaveCheckpointSaverHook(
        model_dir,
        save_steps=10000,
        listeners=[
            export_hooks.ExportSaverListener(model_dir + "/model.ckpt", None,
                                             exporter)
        ])
    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step = tf.compat.v1.assign(global_step, 10)
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[saver_hook]) as sess:
      sess.run(global_step)

    state = export_state_utils.get_export_saver_listener_state(export_dir_base)
    # NOTE(review): the previous comment here spoke of two entries
    # (before_save and after_save), but exactly one entry is asserted -
    # confirm which is intended.
    self.assertEqual(len(state.entries), 1)
    entry = state.entries[0]
    self.assertEqual(entry.export_dir, export_dir)
    self.assertEqual(entry.global_step, 10)

  def testExporterReturnsDict(self):
    """The listener accepts an exporter that returns a dict of directories.

    Only checks that the run completes; no state is asserted.
    """
    model_dir = os.path.join(os.environ["TEST_TMPDIR"],
                             "testExporterReturnsDict")
    # NOTE(review): this reuses the "testBasic_export_dir" name of testBasic;
    # the sub directories differ (model1/, model2/) - confirm it is intended.
    export_dir_base = os.path.join(os.environ["TEST_TMPDIR"],
                                   "testBasic_export_dir")
    exporter = mock.MagicMock()
    export_dir1 = os.path.join(export_dir_base, "model1/12345678")
    export_dir2 = os.path.join(export_dir_base, "model2/12345678")
    os.makedirs(export_dir1)
    os.makedirs(export_dir2)

    def export_saved_model(serving_input_receiver_fn, checkpoint_file,
                           global_step):
      return {
          "model1": export_dir1.encode(),
          "model2": export_dir2.encode(),
      }

    exporter.export_saved_model.side_effect = export_saved_model
    saver_hook = save_utils.NoFirstSaveCheckpointSaverHook(
        model_dir,
        save_steps=10000,
        listeners=[
            export_hooks.ExportSaverListener(model_dir + "/model.ckpt", None,
                                             exporter)
        ])
    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step = tf.compat.v1.assign(global_step, 10)
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[saver_hook]) as sess:
      sess.run(global_step)

  def testDeleted(self):
    """With max_to_keep=1, the export of the dropped checkpoint is deleted."""
    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "testDeleted_model_dir")
    export_dir_base = os.path.join(os.environ["TEST_TMPDIR"],
                                   "testDeleted_export_dir")
    exporter = mock.MagicMock()

    def export_saved_model(serving_input_receiver_fn, checkpoint_file,
                           global_step):
      # Each export gets its own directory named after the current time.
      export_dir = os.path.join(export_dir_base, str(time.time()))
      os.makedirs(export_dir)
      return export_dir.encode()

    exporter.export_saved_model.side_effect = export_saved_model
    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step = tf.compat.v1.assign_add(global_step, 1)
    saver = save_utils.PartialRecoverySaver(tf.compat.v1.global_variables(),
                                            sharded=True,
                                            max_to_keep=1,
                                            keep_checkpoint_every_n_hours=2)
    saver_hook = save_utils.NoFirstSaveCheckpointSaverHook(
        model_dir,
        save_steps=1,
        saver=saver,
        listeners=[
            export_hooks.ExportSaverListener(model_dir + "/model.ckpt", None,
                                             exporter)
        ])
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[saver_hook]) as sess:
      sess.run(global_step)
      sess.run(global_step)

    state = export_state_utils.get_export_saver_listener_state(export_dir_base)
    # Saved model for step 1 is deleted.
    self.assertEqual(len(state.entries), 1)
    entry = state.entries[0]
    self.assertEqual(entry.global_step, 2)
    # The "*.*" pattern matches the time-named export directories (their names
    # contain a dot); exactly one must be left on disk.
    self.assertEqual(len(tf.io.gfile.glob(export_dir_base + "/*.*")), 1)
def get_export_saver_listener_state(
    export_dir_base: str) -> export_pb2.ServingModelState:
  """Reads the export state stored under `export_dir_base`.

  The state is kept as a text-format ServingModelState in the file
  `ExportSaverListenerState` inside `export_dir_base`.

  Args:
    export_dir_base: directory that contains the state file.

  Returns:
    The parsed state, or an empty ServingModelState when the file does not
    exist.
  """
  state_path = os.path.join(export_dir_base, _ExportSaverListenerStateFile)
  result = export_pb2.ServingModelState()
  try:
    with tf.io.gfile.GFile(state_path) as state_file:
      text_format.Merge(state_file.read(), result)
  except tf.errors.NotFoundError:
    # Nothing has been written yet: fall through and return the empty state.
    pass
  return result
class ExportStateUtilsTest(unittest.TestCase):
  """Round-trip test for the export state read/write helpers."""

  def test_basic(self):
    """A state written with overwrite_* is read back unchanged by get_*."""
    state = export_pb2.ServingModelState()
    entry = state.entries.add()
    entry.export_dir = "a"
    entry.global_step = 1

    # Named `state_dir` so that the builtin `dir` is not shadowed.
    state_dir = os.path.join(os.environ["TEST_TMPDIR"], "basic")
    export_state_utils.overwrite_export_saver_listener_state(state_dir, state)
    new_state = export_state_utils.get_export_saver_listener_state(state_dir)

    # assertEqual replaces the deprecated alias assertEquals, which was
    # removed in Python 3.12.
    self.assertEqual(new_state, state)
class RemotePredictHelper:
  """Exports a function as a named signature and builds remote calls to it.

  On construction, `remote_func` is invoked once on placeholders that mirror
  `input_tensors`; the placeholders and the produced tensors are registered
  as the signature `name` of the default graph on the current export context.
  The helper must therefore be created while an export context is active.
  """

  def __init__(self, name: str, input_tensors: object,
               remote_func: Callable[[object], object]):
    """
    Args:
      name - the signature name under which the function is exported.
      input_tensors - a (possibly nested) structure of tensors used as the
        template for the function inputs.
      remote_func - called with a structure shaped like `input_tensors` and
        returning a (possibly nested) structure of tensors.
    """
    self._name = name
    self._input_tensors = nested_tensors.NestedTensors(input_tensors)
    self._remote_func = remote_func
    self._define_remote_func()

  def _define_remote_func(self):
    """Defines the remote func and registers its signature."""
    self._func_defined = True
    flat_input_tensors = self._input_tensors.get_tensors()

    # One placeholder per flat input tensor, with matching dtype and shape.
    phs = [
        tf.compat.v1.placeholder(dtype=tensor.dtype,
                                 shape=tensor.shape,
                                 name=_get_tensor_signature_name(tensor) +
                                 "_remote_input_ph")
        for tensor in flat_input_tensors
    ]
    func_input = self._input_tensors.get_nested_result(phs)
    func_output = self._remote_func(func_input)
    self._output_tensors = nested_tensors.NestedTensors(func_output)
    flat_output_tensors = self._output_tensors.get_tensors()

    # Signature keys are derived from tensor names; two tensors mapping to the
    # same key would silently drop one of them, hence the length checks.
    self._sig_input = {
        _get_tensor_signature_name(t): ph
        for t, ph in zip(flat_input_tensors, phs)
    }
    assert len(self._sig_input) == len(
        flat_input_tensors), f"Name conflicts: {flat_input_tensors}"
    self._sig_output = {
        _get_tensor_signature_name(t): t for t in flat_output_tensors
    }
    # Report the output tensors here (the message used to print the inputs).
    assert len(self._sig_output) == len(
        flat_output_tensors), f"Name conflicts: {flat_output_tensors}"
    export_context.get_current_export_ctx().add_signature(
        tf.compat.v1.get_default_graph(), self._name, self._sig_input,
        self._sig_output)

  def call_remote_predict(self,
                          model_name: str,
                          input_tensors: object = None,
                          old_model_name: str = None,
                          task: int = 0):
    """
    Calls the remote function.

    Args:
      model_name - the remote model_name that will be used.
      input_tensors - if None, will use tensors in the __init__
      old_model_name & task - A deprecated args to support old remote predict

    Returns:
      The remote results arranged in the same nested structure as the output
      of `remote_func`.
    """
    # Compare with None explicitly: truth-testing a tf.Tensor is not allowed
    # in graph mode, so `if input_tensors:` fails for a single tensor.
    if input_tensors is not None:
      flat_input_tensors = nested_tensors.NestedTensors(
          input_tensors).get_tensors()
    else:
      flat_input_tensors = self._input_tensors.get_tensors()
    results = remote_predict(
        list(self._sig_input.keys()),
        flat_input_tensors,
        list(self._sig_output.keys()),
        model_name,
        task,
        old_model_name,
        output_types=[t.dtype for t in self._sig_output.values()],
        signature_name=self._name)
    return self._output_tensors.get_nested_result(results)
with export_context.enter_export_mode( export_context.EXPORT_MODE.STANDALONE): def remote_func(d): return d["a"] * 3 + d["b"] * 4 helper = export_utils.RemotePredictHelper("test_func", { "a": tf.constant(1), "b": tf.constant(2) }, remote_func) result = helper.call_remote_predict("model_name") self.assertIsInstance(result, tf.Tensor) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/model_export/saved_model_exporters.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def __init__(self,
             model_fn: Callable,
             model_dir: str,
             export_dir_base: str,
             shared_embedding=False,
             warmup_file: str = None,
             export_context_list: List = None):
  """Stores the exporter configuration; performs no I/O.

  Args:
    model_fn - the model fn which should have (features, mode, config) as
      args and return a EstimatorSpec
    model_dir - stored as given; presumably the directory holding the
      checkpoints (TODO confirm against callers).
    export_dir_base - stored as given; presumably the directory under which
      exports are written (TODO confirm against callers).
    shared_embedding - instead of exporting a hermetic SavedModel, we will
      use the embedding in the checkpoint instead of copying it.
    warmup_file - the warmup file name.
    export_context_list - optional list kept on the exporter; None is
      replaced by a new empty list.
  """
  if export_context_list is None:
    export_context_list = []
  self._export_context_list = export_context_list
  self._warmup_file = warmup_file
  self._shared_embedding = shared_embedding
  self._export_dir_base = export_dir_base
  self._model_dir = model_dir
  self._raw_model_fn = model_fn
def _freeze_dense_graph(self, graph_def, signature_def_map, session):
  """Freezes `graph_def` by turning its variables into constants.

  Args:
    graph_def: the GraphDef to freeze.
    signature_def_map: dict of signature name to SignatureDef; the nodes
      behind all signature inputs and outputs are used as the output nodes of
      the conversion.
    session: session used to read the current variable values.

  Returns:
    The frozen GraphDef. Every node that also exists in `graph_def` gets the
    device that was recorded for it in `graph_def`.
  """
  dest_nodes = []
  for signature in signature_def_map.values():
    dest_nodes.extend(
        self._get_tensor_name(item.name) for item in signature.inputs.values())
    dest_nodes.extend(
        self._get_tensor_name(item.name)
        for item in signature.outputs.values())
  logging.info("freeze output_node_names: {}".format(dest_nodes))

  # Remember the device of every node before the conversion.
  node_device = {node.name: node.device for node in graph_def.node}

  # No whitelist is passed: all variables are eligible for conversion.
  frozen_graph_def = graph_util.convert_variables_to_constants(
      session, graph_def, dest_nodes, variable_names_whitelist=None)

  # Re-apply the recorded device to the nodes that survived the conversion.
  for node in frozen_graph_def.node:
    if node.name in node_device:
      node.device = node_device[node.name]
  return frozen_graph_def
export_dir_base" if export_ctx is None: export_ctx = export_context.get_current_export_ctx() assert export_ctx is not None if not export_dir: export_dir = export_lib.get_timestamped_export_dir(export_dir_base) temp_export_dir = export_lib.get_temp_export_dir(export_dir) builder = tf.compat.v1.saved_model.Builder(temp_export_dir) with graph.as_default(): tf.compat.v1.train.get_or_create_global_step(graph) signature_def_map = {} # Add signatures collected in the export_context for signature in export_ctx.signatures(graph): signature_def_map[signature.name] = BaseExporter.build_signature( signature.inputs, signature.outputs) if assign_hashtable: # assign signature assign_inputs, assign_outputs = self.build_hashtable_assign_inputs_outputs( ) signature_def_map["hashtable_assign"] = BaseExporter.build_signature( assign_inputs, assign_outputs) self.add_multi_hashtable_assign_signatures(signature_def_map) ''' To export CPU-trained saved_model for GPU serving, it requires explicit GPU device placement at the exporting time. But it raised error here on CPU-only machines while exporting graph with ops explicitly placed on GPU (due to the necessity of loading the whole graph at runtime for calling save_op). So we use the soft placement here, to avoid raising runtime exceptions, but still successfully record the correct GPU placements to the saved_model. ''' session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True) update_session_config_for_gpu(session_config) with tf.compat.v1.Session(config=session_config) as session: graph_saver = tf.compat.v1.train.Saver(sharded=True) if restore_vars: try: graph_saver.restore(session, checkpoint_path) except tf.errors.NotFoundError as e: msg = ('Could not load all requested variables from checkpoint. ' 'Please make sure your model_fn does not expect variables ' 'that were not saved in the checkpoint.\n\n' 'Encountered error with mode `{}` while restoring ' 'checkpoint from: `{}`. 
Full Traceback:\n\n{}').format(
                       tf.estimator.ModeKeys.PREDICT, checkpoint_path, e)
            raise ValueError(msg)
        restore_op = None
        if restore_hashtable:
          restore_op = tf.group(
              self.create_multi_hashtable_restore_ops(checkpoint_path) +
              self.create_hashtable_restore_ops(checkpoint_path))
        restore_op = restore_op or tf.no_op()
        meta_graph_kwargs = dict(tags=export_tags,
                                 signature_def_map=signature_def_map,
                                 assets_collection=tf.compat.v1.get_collection(
                                     tf.compat.v1.GraphKeys.ASSET_FILEPATHS),
                                 clear_devices=clear_devices,
                                 main_op=restore_op,
                                 saver=graph_saver,
                                 strip_default_attrs=strip_default_attrs)
        builder.add_meta_graph_and_variables(session, **meta_graph_kwargs)
    builder.save()
    # Add the extra assets
    if assets_extra:
      assets_extra_path = os.path.join(tf.compat.as_bytes(temp_export_dir),
                                       tf.compat.as_bytes('assets.extra'))
      for dest_relative, source in assets_extra.items():
        dest_absolute = os.path.join(tf.compat.as_bytes(assets_extra_path),
                                     tf.compat.as_bytes(dest_relative))
        dest_path = os.path.dirname(dest_absolute)
        tf.compat.v1.gfile.MakeDirs(dest_path)
        tf.compat.v1.gfile.Copy(source, dest_absolute)
    tf.io.gfile.rename(temp_export_dir, export_dir)
    return export_dir if isinstance(export_dir, bytes) else export_dir.encode()

  def _export_frozen_saved_model_from_graph(
      self,
      graph: tf.Graph,
      checkpoint_path: str,
      export_dir_base: str = None,
      export_dir: str = None,
      restore_vars=True,
      export_ctx: export_context.ExportContext = None,
      export_tags=['serve'],
      assets_extra=None,
      clear_devices=False,
      strip_default_attrs=True) -> bytes:
    """
    Export a frozen saved_model from a user constructed graph.

    A graph can have multiple signatures; they are read from
    export_ctx.signatures(graph).

    The export runs in two passes:
      1. `graph` is loaded into a session (restoring variables from
         `checkpoint_path` when `restore_vars` is true) and handed to
         self._freeze_dense_graph together with the signature map.
      2. The GraphDef returned by that call is imported into a new graph,
         which is written with a SavedModel builder into a temporary
         directory. The temporary directory is renamed to `export_dir` last.

    Args:
      graph: the graph to export.
      checkpoint_path: checkpoint passed to Saver.restore.
      export_dir_base: base directory; a timestamped sub-directory is derived
        from it when `export_dir` is not given.
      export_dir: final export directory. One of `export_dir` and
        `export_dir_base` must be provided.
      restore_vars: whether to call Saver.restore (in both passes).
      export_ctx: export context holding the signatures; defaults to
        export_context.get_current_export_ctx().
      export_tags: tags passed to the builder. The default list is never
        mutated here.
      assets_extra: optional dict of {destination path relative to
        `assets.extra`: source file path} copied into the export.
      clear_devices: forwarded to builder.add_meta_graph_and_variables.
      strip_default_attrs: forwarded to builder.add_meta_graph_and_variables.

    Returns:
      `export_dir` as bytes.

    Raises:
      ValueError: if Saver.restore raises tf.errors.NotFoundError.
    """
    assert export_dir or export_dir_base, "must provide export_dir or export_dir_base"
    if export_ctx is None:
      export_ctx = export_context.get_current_export_ctx()
    assert export_ctx is not None
    if not export_dir:
      export_dir = export_lib.get_timestamped_export_dir(export_dir_base)
    # Everything is written to a temp dir first and renamed at the very end.
    temp_export_dir = export_lib.get_temp_export_dir(export_dir)
    frozen_graph_def = None
    # Pass 1: restore the original graph and freeze it.
    with graph.as_default():
      tf.compat.v1.train.get_or_create_global_step(graph)
      signature_def_map = {}
      # Add signatures collected in the export_context
      for signature in export_ctx.signatures(graph):
        signature_def_map[signature.name] = BaseExporter.build_signature(
            signature.inputs, signature.outputs)
      # Soft placement is enabled for the export session.
      session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
      update_session_config_for_gpu(session_config)
      with tf.compat.v1.Session(config=session_config) as session:
        graph_saver = tf.compat.v1.train.Saver(sharded=True)
        if restore_vars:
          try:
            graph_saver.restore(session, checkpoint_path)
          except tf.errors.NotFoundError as e:
            msg = ('Could not load all requested variables from checkpoint. '
                   'Please make sure your model_fn does not expect variables '
                   'that were not saved in the checkpoint.\n\n'
                   'Encountered error with mode `{}` while restoring '
                   'checkpoint from: `{}`. Full Traceback:\n\n{}').format(
                       tf.estimator.ModeKeys.PREDICT, checkpoint_path, e)
            raise ValueError(msg)
        frozen_graph_def = self._freeze_dense_graph(graph.as_graph_def(),
                                                    signature_def_map, session)
    builder = tf.compat.v1.saved_model.Builder(temp_export_dir)
    # Pass 2: import the frozen GraphDef into a fresh graph and save it.
    with tf.Graph().as_default() as final_graph:
      tf.graph_util.import_graph_def(frozen_graph_def, name='')
      tf.compat.v1.train.get_or_create_global_step(final_graph)
      session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
      update_session_config_for_gpu(session_config)
      with tf.compat.v1.Session(config=session_config) as session:
        graph_saver = tf.compat.v1.train.Saver(sharded=True)
        if restore_vars:
          try:
            graph_saver.restore(session, checkpoint_path)
          except tf.errors.NotFoundError as e:
            msg = ('Could not load all requested variables from checkpoint. '
                   'Please make sure your model_fn does not expect variables '
                   'that were not saved in the checkpoint.\n\n'
                   'Encountered error with mode `{}` while restoring '
                   'checkpoint from: `{}`. Full Traceback:\n\n{}').format(
                       tf.estimator.ModeKeys.PREDICT, checkpoint_path, e)
            raise ValueError(msg)
        # No hash table restore here: main_op is a plain no-op.
        restore_op = tf.no_op()
        meta_graph_kwargs = dict(tags=export_tags,
                                 signature_def_map=signature_def_map,
                                 assets_collection=tf.compat.v1.get_collection(
                                     tf.compat.v1.GraphKeys.ASSET_FILEPATHS),
                                 clear_devices=clear_devices,
                                 main_op=restore_op,
                                 saver=graph_saver,
                                 strip_default_attrs=strip_default_attrs)
        builder.add_meta_graph_and_variables(session, **meta_graph_kwargs)
    builder.save()
    # Add the extra assets
    if assets_extra:
      assets_extra_path = os.path.join(tf.compat.as_bytes(temp_export_dir),
                                       tf.compat.as_bytes('assets.extra'))
      for dest_relative, source in assets_extra.items():
        dest_absolute = os.path.join(tf.compat.as_bytes(assets_extra_path),
                                     tf.compat.as_bytes(dest_relative))
        dest_path = os.path.dirname(dest_absolute)
        tf.compat.v1.gfile.MakeDirs(dest_path)
        tf.compat.v1.gfile.Copy(source, dest_absolute)
    tf.io.gfile.rename(temp_export_dir, export_dir)
    return export_dir if isinstance(export_dir, bytes) else export_dir.encode()

  def create_hashtable_restore_ops(self, checkpoint_path):
    """
    Find all the hashtables in the current graph and create restore_op for them.
When shared_embedding is False, it adds the hashtable ckpt files into the assets folder """ ckpt_asset_base = save_utils.SaveHelper.get_ckpt_asset_dir(checkpoint_path) restore_ops = [] for table in ops.get_collection(hash_table_ops._HASH_TABLE_GRAPH_KEY): tensor_prefix = hash_table_ops._table_tensor_prefix(table) share_embedding = self._shared_embedding if table.export_share_embedding is None else table.export_share_embedding if not share_embedding: BaseExporter.add_ckpt_to_assets(checkpoint_path, pattern=tensor_prefix + "*") asset_base = BaseExporter.create_asset_base() else: asset_base = ckpt_asset_base table_prefix = tf.strings.join([asset_base, tensor_prefix]) table_prefix = tf.strings.join( [asset_base, hash_table_ops._table_tensor_prefix(table)]) restore_ops.append(table.restore(table_prefix).as_op()) return restore_ops def create_multi_hashtable_restore_ops(self, checkpoint_path): """ Find all the multi-hashtables in the current graph and create restore_op for them. When shared_embedding is False, it adds the hashtable ckpt files into the assets folder """ ckpt_asset_base = save_utils.SaveHelper.get_ckpt_asset_dir(checkpoint_path) restore_ops = [] for table in ops.get_collection( multi_hash_table_ops._MULTI_HASH_TABLE_GRAPH_KEY): if not self._shared_embedding: BaseExporter.add_ckpt_to_assets(checkpoint_path, pattern=table.shared_name + "*") asset_base = BaseExporter.create_asset_base() else: asset_base = ckpt_asset_base table_basename = tf.strings.join([asset_base, table.shared_name]) with tf.control_dependencies([table.initializer]): restore_op = table.restore(basename=table_basename).as_op() restore_ops.append(restore_op) return restore_ops def build_hashtable_assign_inputs_outputs(self): """ For all hashtables in the current graph, create assign tensors for them """ assign_input_tensors, assign_output_tensors = {}, {} for table in ops.get_collection(hash_table_ops._HASH_TABLE_GRAPH_KEY): assign_id = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) 
assign_value = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None, table.dim_size)) assign_input_tensors[table.name + "_id"] = assign_id assign_input_tensors[table.name + "_value"] = assign_value updated_table = table.assign(assign_id, assign_value) with tf.control_dependencies(control_inputs=[updated_table.as_op()]): # The size of id tensor is returned here as a dummy value assign_output_tensors[table.name + "_result"] = tf.size(assign_id) return assign_input_tensors, assign_output_tensors def add_multi_hashtable_assign_signatures(self, signature_def_map: Dict): """ For all hashtables in the current graph, create assign tensors for them """ for table in ops.get_collection( multi_hash_table_ops._MULTI_HASH_TABLE_GRAPH_KEY): name = table.shared_name + "/raw_assign" assert name not in signature_def_map, f"{name} has already been defined in signature" input_tensors, output_tensors = dict(), dict() id = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) input_tensors["id"] = id id_split = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,)) input_tensors["id_split"] = id_split flat_value = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None,)) input_tensors["flat_value"] = flat_value assign_op = table.raw_assign( tf.RaggedTensor.from_row_splits(id, id_split), flat_value).as_op() with tf.control_dependencies([assign_op]): dummy_tensor = tf.constant(0) output_tensors["result"] = dummy_tensor signature_def_map[name] = self.build_signature(input_tensors, output_tensors) def _model_fn_with_input_reveiver(self, serving_input_receiver_fn): input_receiver = serving_input_receiver_fn() estimator_spec = self._raw_model_fn(input_receiver.features, mode=tf.estimator.ModeKeys.PREDICT, config=tf.estimator.RunConfig( self._model_dir)) export_outputs = export_lib.get_export_outputs( estimator_spec.export_outputs, estimator_spec.predictions) signature_def_map = export_lib.build_all_signature_defs( input_receiver.receiver_tensors, export_outputs, getattr(input_receiver, 
'receiver_tensors_alternatives', None), serving_only=True) for signature_name, signature in signature_def_map.items(): export_context.get_current_export_ctx().add_signature( tf.compat.v1.get_default_graph(), signature_name, signature.inputs, signature.outputs) @abc.abstractmethod def export_saved_model(self, serving_input_receiver_fn, checkpoint_path=None, global_step=None) -> Union[bytes, Dict[str, bytes]]: """ Export the saved model and returns the path of exported model. Args: checkpoint_path - If None, the latest one will be used. """ pass def gen_warmup_assets(self) -> Dict[str, str]: if not self._warmup_file: return None if not tf.io.gfile.exists(self._warmup_file): try: flag = gen_warmup_file(self._warmup_file) if flag is None: return None else: return {'tf_serving_warmup_requests': self._warmup_file} except Exception as e: logging.error(str(e)) return None else: return {'tf_serving_warmup_requests': self._warmup_file} @monolith_export class StandaloneExporter(BaseExporter): """单机模式的saved model导出器 Args: model_fn: 和tf.estimator兼容的model_fn, 以(features, mode, config)作为参数并且返回EstimatorSpec model_dir: 保存checkpoint的目录 export_dir_base: 导出saved_model的目标路径 shared_embedding: 是否复用checkpoint中的 embedding 文件, False的话会将embedding文件拷贝至saved_model, 可能会降低导出速度 warmup_file: warmup文件, 参考 https://www.tensorflow.org/tfx/serving/saved_model_warmup """ def __init__(self, model_fn: Callable, model_dir: str, export_dir_base: str, shared_embedding=False, warmup_file: str = None, export_context_list: List = None): super(StandaloneExporter, self).__init__(model_fn, model_dir, export_dir_base, shared_embedding, warmup_file, export_context_list) def export_saved_model(self, serving_input_receiver_fn, checkpoint_path=None, global_step=None): """ 导出saved_model Args: serving_input_receiver_fn: 返回 tf.estimator.export.ServingInputReceiver 的函数, 用来将serving 请求映射到模型输入 checkpoint_path: 可选的checkpoint路径, 为空则使用tf.train.latest_checkpoint(self._model_dir) """ if not checkpoint_path: checkpoint_path = 
tf.train.latest_checkpoint(self._model_dir) with export_context.enter_export_mode( export_context.ExportMode.STANDALONE ) as export_ctx, contextlib.ExitStack() as stack: other_contexts = [ stack.enter_context(ctx()) for ctx in self._export_context_list ] saved_tf_config = os.environ.pop("TF_CONFIG", None) try: with tf.Graph().as_default() as g: self._model_fn_with_input_reveiver(serving_input_receiver_fn) return self._export_saved_model_from_graph( g, checkpoint_path=checkpoint_path, export_dir_base=self._export_dir_base, export_ctx=export_ctx, assets_extra=self.gen_warmup_assets()) finally: if saved_tf_config: os.environ["TF_CONFIG"] = saved_tf_config @monolith_export class DistributedExporter(BaseExporter): """分布式模型导出器 Args: model_fn: 和tf.estimator兼容的model_fn, 以(features, mode, config)作为参数并且返回EstimatorSpec model_dir: 保存checkpoint的目录 export_dir_base: 导出saved_model的目标路径 shared_embedding: 是否复用checkpoint中的 embedding 文件, False的话会将embedding文件拷贝至saved_model, 可能会降低导出速度 warmup_file: warmup文件, 参考 https://www.tensorflow.org/tfx/serving/saved_model_warmup include_graphs: Only export saved_models from include_graphs if the param not None, otherwise export all graphs in export context global_step_as_timestamp: whether to use use global_step export folder name, useful when we do parallel export in sync_training """ def __init__(self, model_fn: Callable, model_dir: str, export_dir_base: str, shared_embedding=False, warmup_file: str = None, export_context_list: List = None, dense_only=False, allow_gpu=False, with_remote_gpu=False, clear_entry_devices=False, include_graphs: List[str] = None, global_step_as_timestamp: bool = False, freeze_variable: bool = True): super(DistributedExporter, self).__init__(model_fn, model_dir, export_dir_base, shared_embedding, warmup_file, export_context_list) self._dense_only = dense_only self._allow_gpu = allow_gpu self._with_remote_gpu = with_remote_gpu self._clear_entry_devices = clear_entry_devices self._include_graphs = include_graphs 
self._global_step_as_timestamp = global_step_as_timestamp self._freeze_variable = freeze_variable def _should_export(self, graph_name, export_dir): if tf.io.gfile.exists(export_dir): logging.info("skipping duplicated model exportings") return False return self._include_graphs is None or graph_name in self._include_graphs def export_saved_model(self, serving_input_receiver_fn, checkpoint_path=None, global_step=None): """ 导出saved_model Args: serving_input_receiver_fn: 返回 tf.estimator.export.ServingInputReceiver 的函数, 用来将serving 请求映射到模型输入 checkpoint_path: 可选的checkpoint路径, 为空则使用tf.train.latest_checkpoint(self._model_dir) """ if not checkpoint_path: checkpoint_path = tf.train.latest_checkpoint(self._model_dir) export_ctx = export_context.ExportContext( with_remote_gpu=self._with_remote_gpu) with export_context.enter_export_mode( export_context.ExportMode.DISTRIBUTED, export_ctx), contextlib.ExitStack() as stack: other_contexts = [ stack.enter_context(ctx()) for ctx in self._export_context_list ] saved_tf_config = os.environ.pop("TF_CONFIG", None) result = {} try: # Run model fn and export entry part if self._allow_gpu: device_utils.enable_gpu_training() with tf.Graph().as_default() as g, g.device( device_utils.default_device_fn): self._model_fn_with_input_reveiver(serving_input_receiver_fn) if global_step and self._global_step_as_timestamp: timestamp = str(global_step) entry_export_dir = os.path.join(self._export_dir_base, "entry", timestamp) else: entry_export_dir = export_lib.get_timestamped_export_dir( os.path.join(self._export_dir_base, "entry")).decode() timestamp = os.path.basename(entry_export_dir) if self._should_export("entry", entry_export_dir): result["entry"] = self._export_saved_model_from_graph( g, checkpoint_path=checkpoint_path, export_dir=entry_export_dir, export_ctx=export_ctx, restore_hashtable=False, assign_hashtable=False, assets_extra=self.gen_warmup_assets(), clear_devices=self._clear_entry_devices, ) # Export additional dense graph stored in 
export_ctx DumpUtils().restore_sub_model('dense') for name, graph in export_ctx.dense_sub_graphs.items(): DumpUtils().add_sub_model('dense', name, graph) ps_export_dir = os.path.join( self._export_dir_base, name, str(timestamp) + getattr(graph, "export_suffix", "")) if not self._should_export(name, ps_export_dir): continue result[name] = self._export_saved_model_from_graph( graph, checkpoint_path=checkpoint_path, export_dir=ps_export_dir, export_ctx=export_ctx, ) if self._dense_only: return result # Export PS from graph stored in export_ctx DumpUtils().restore_sub_model('ps') for name, graph in export_ctx.sub_graphs.items(): if name.startswith("dense"): continue DumpUtils().add_sub_model('ps', name, graph) ps_export_dir = os.path.join( self._export_dir_base, name, str(timestamp) + getattr(graph, "export_suffix", "")) if not self._should_export(name, ps_export_dir): continue result[name] = self._export_saved_model_from_graph( graph, checkpoint_path=checkpoint_path, export_dir=ps_export_dir, export_ctx=export_ctx, ) # Export GPU Dense from graph stored in export_ctx for name, graph in export_ctx.sub_graphs.items(): if not name.startswith("dense"): continue DumpUtils().add_sub_model('ps', name, graph) ps_export_dir = os.path.join( self._export_dir_base, name, str(timestamp) + getattr(graph, "export_suffix", "")) if not self._should_export(name, ps_export_dir): continue if not self._freeze_variable: result[name] = self._export_saved_model_from_graph( graph, checkpoint_path=checkpoint_path, export_dir=ps_export_dir, export_ctx=export_ctx, restore_hashtable=False, assign_hashtable=False) else: result[name] = self._export_frozen_saved_model_from_graph( graph, checkpoint_path=checkpoint_path, export_dir=ps_export_dir, export_ctx=export_ctx) finally: if saved_tf_config: os.environ["TF_CONFIG"] = saved_tf_config return result ================================================ FILE: monolith/native_training/model_export/saved_model_exporters_test.py 
================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import tensorflow as tf from monolith.native_training import hash_table_ops from monolith.native_training import test_utils from monolith.native_training import multi_hash_table_ops from monolith.native_training.model_export import export_context from monolith.native_training.model_export import saved_model_exporters def input_fn(): return tf.data.Dataset.from_tensor_slices([1]).repeat() class ModelFnCreator: def __init__(self): self._called_in_exported_mode = False @property def called_in_exported_mode(self): return self._called_in_exported_mode def create_model_fn(self): def model_fn(features, mode, config): if export_context.EXPORT_MODE != None: self._called_in_exported_mode = True table = hash_table_ops.test_hash_table(2) mtable = multi_hash_table_ops.MultiHashTable.from_configs( configs={"test": test_utils.generate_test_hash_table_config(1)}) global_step = tf.compat.v1.train.get_or_create_global_step() if mode == tf.estimator.ModeKeys.PREDICT: output_tensor = table.lookup([0]) output = tf.estimator.export.PredictOutput(output_tensor) moutput = tf.estimator.export.PredictOutput( mtable.lookup({"test": [0]})["test"]) return tf.estimator.EstimatorSpec( mode=mode, predictions=output_tensor, export_outputs={ tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: output, "table/lookup": output, 
"mtable/lookup": moutput, }) add_op = table.assign_add([0], [[1, 2]]).as_op() add_op2 = mtable.assign_add({"test": ([0], [[1]])}).as_op() global_step = tf.compat.v1.assign_add(global_step, 1) with tf.control_dependencies([global_step]): print_op = tf.print("tensor value:", table.lookup([0])) print_op2 = tf.print("mtable tensor value: ", mtable.lookup({"test": [0]})) ckpt_prefix = config.model_dir + "/model.ckpt" return tf.estimator.EstimatorSpec( mode=mode, train_op=tf.group([global_step, add_op, add_op2, print_op, print_op2]), training_hooks=[ tf.estimator.CheckpointSaverHook( config.model_dir, save_steps=1000, listeners=[ hash_table_ops.HashTableCheckpointSaverListener( ckpt_prefix), multi_hash_table_ops. MultiHashTableCheckpointSaverListener(ckpt_prefix), ]) ], loss=tf.constant(0.0)) return model_fn def dummy_input_receiver_fn(): return tf.estimator.export.ServingInputReceiver({}, tf.compat.v1.placeholder( tf.string)) class SavedModelExportersTest(tf.test.TestCase): def setUp(self): self._model_dir = os.path.join(os.environ["TEST_TMPDIR"], self._testMethodName + "_model_dir") self._export_dir_base = os.path.join(os.environ["TEST_TMPDIR"], self._testMethodName + "_export_dir") def run_pred(self, export_path, key=tf.compat.v1.saved_model.signature_constants. DEFAULT_SERVING_SIGNATURE_DEF_KEY): g = tf.Graph() with g.as_default(), self.session() as sess: imported = tf.compat.v1.saved_model.load( sess, {tf.compat.v1.saved_model.tag_constants.SERVING}, export_path) pred_name = imported.signature_def[key].outputs["output"].name pred = g.get_tensor_by_name(pred_name) return sess.run(pred) def testBasic(self): creator = ModelFnCreator() est = tf.estimator.Estimator(creator.create_model_fn(), model_dir=self._model_dir) # Train twice so we guarantee there are 2 ckpts. 
est.train(input_fn, steps=1) exporter = saved_model_exporters.StandaloneExporter( creator.create_model_fn(), self._model_dir, self._export_dir_base) export_path = exporter.export_saved_model(dummy_input_receiver_fn) self.assertAllEqual(self.run_pred(export_path), [[1, 2]]) self.assertAllEqual(self.run_pred(export_path, "mtable/lookup"), [[1]]) self.assertTrue(creator.called_in_exported_mode) # TODO(leqi.zou) : Add test case for checkpoint_path is not None def testSharedEmebdding(self): creator = ModelFnCreator() est = tf.estimator.Estimator(creator.create_model_fn(), model_dir=self._model_dir) est.train(input_fn, steps=1) exporter = saved_model_exporters.StandaloneExporter( creator.create_model_fn(), self._model_dir, self._export_dir_base, shared_embedding=True) export_path = exporter.export_saved_model(dummy_input_receiver_fn) self.assertAllEqual(self.run_pred(export_path), [[1, 2]]) self.assertAllEqual(self.run_pred(export_path, "mtable/lookup"), [[1]]) # TODO(leqi.zou): Add more tests for the distributed hash tables. if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/model_export/saved_model_visulizer.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Imports a protobuf model as a graph in Tensorboard.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import sys import tensorflow as tf from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.client import session from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.summary import summary from tensorflow.python.util import compat from monolith.native_training.runtime.ops import gen_monolith_ops hash_table_ops = gen_monolith_ops gen_distribution_ops = gen_monolith_ops def import_to_tensorboard(model_dir, log_dir): """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. Args: model_dir: The location of the protobuf (`pb`) model to visualize log_dir: The location for the Tensorboard log to begin visualization from. Usage: Call this function with your model location and desired log directory. Launch Tensorboard by pointing it to the log directory. View your imported `.pb` model as a graph. """ with session.Session(graph=ops.Graph()) as sess: with gfile.FastGFile(model_dir, "rb") as f: data = compat.as_bytes(f.read()) sm = saved_model_pb2.SavedModel() sm.ParseFromString(data) if 1 != len(sm.meta_graphs): print('More than one graph found. Not sure which to write') sys.exit(1) importer.import_graph_def(sm.meta_graphs[0].graph_def) pb_visual_writer = summary.FileWriter(log_dir) pb_visual_writer.add_graph(sess.graph) print("Model Imported. 
Visualize by running: " "tensorboard --logdir={} --bind_all".format(log_dir)) def main(unused_args): import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( "--model_dir", type=str, default="", required=True, help="The location of the protobuf (\'pb\') model to visualize.") parser.add_argument( "--log_dir", type=str, default="", required=True, help="The location for the Tensorboard log to begin visualization from.") FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed) ================================================ FILE: monolith/native_training/model_export/testdata/BUILD ================================================ exports_files(["saved_model"]) ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_0/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_1/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_2/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_3/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_3-00001-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_0b9721ec6fc5396c38499b5be394b722_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00000-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_3-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00001-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_3fc25c64637605aa3983374cc61db982_4-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_0-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_1-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00000-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_e3997af6324e55640d4611001fa3a15b_4-00002-of-00004 
================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_0-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_1-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00001-of-00004 ================================================ ================================================ FILE: 
monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_2-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00002-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_3-00003-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00000-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00001-of-00004 ================================================ ================================================ FILE: monolith/native_training/model_export/testdata/saved_model/ps_4/1622716114/assets/MonolithHashTable_f6962510869b682fd764009be0e4e9c3_4-00003-of-00004 ================================================ 
================================================ FILE: monolith/native_training/model_export/warmup_data_decoder.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import app from absl import flags from absl import logging import re import tensorflow as tf from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_log_pb2 from monolith.native_training import env_utils FLAGS = flags.FLAGS flags.DEFINE_string("file_name", None, "input file name") def main(_): try: env_utils.setup_hdfs_env() except: pass tf.compat.v1.enable_eager_execution() tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) logging.set_verbosity(logging.INFO) def decode_fn(record_bytes): log = prediction_log_pb2.PredictionLog() log.ParseFromString(record_bytes) return log for i, batch in enumerate(tf.data.TFRecordDataset([FLAGS.file_name])): prediction_log = decode_fn(batch.numpy()) predict_log = prediction_log.predict_log request = predict_log.request simple_request_string = re.sub('string_val:.*', 'string_val: ...', str(request)) logging.info('%dth model_spec:\n%s', i, simple_request_string) if __name__ == "__main__": app.run(main) ================================================ FILE: monolith/native_training/model_export/warmup_data_gen.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys from absl import app from absl import flags from absl import logging import tensorflow as tf from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_log_pb2 from monolith.native_training import env_utils from monolith.native_training.model_export.data_gen_utils import gen_prediction_log from monolith.native_training.data.feature_list import FeatureList FLAGS = flags.FLAGS flags.DEFINE_string("file_name", None, "input file name") flags.DEFINE_integer("batch_size", 256, "Batch size of prediction request.") flags.DEFINE_bool("lagrangex_header", False, "kafka_dump_prefix") flags.DEFINE_bool("kafka_dump_prefix", False, "kafka_dump_prefix") flags.DEFINE_bool("has_sort_id", True, "has_sort_id") flags.DEFINE_bool("kafka_dump", False, "kafka_dump") flags.DEFINE_integer("max_records", 1000, "Maximum number of warmup records.") flags.DEFINE_string("model_name", "default", "mode name") flags.DEFINE_string("signature_names", "serving_default", "signature names") flags.DEFINE_string("output_path", "/tmp/tf_warmup_data", "output path of warmup data.") flags.DEFINE_enum("variant_type", "instance", ['instance', 'example', 'example_batch'], "variant_type") flags.DEFINE_string("sparse_features", None, "sparse_features") flags.DEFINE_string("dense_features", None, "dense_features") flags.DEFINE_integer("dense_feature_shapes", None, "dense_feature_shapes") flags.DEFINE_integer("dense_feature_types", None, 
class PBReader(object):
  """Iterator over length-prefixed serialized protobuf records.

  Every record is an 8-byte little-endian size followed by that many payload
  bytes, optionally preceded by header fields selected by the constructor
  flags (see `_read_header`). Each `next()` call returns
  `tf.make_tensor_proto(records)`: one record when `variant_type` is
  'example_batch', otherwise `batch_size` records.

  Args:
    file_name: path readable by `tf.io.gfile`; None or '' reads from stdin.
    batch_size: records per returned tensor (must be > 0).
    lagrangex_header: see `_read_header`.
    has_sort_id: see `_read_header`.
    kafka_dump_prefix: see `_read_header`.
    kafka_dump: see `_read_header`.
    variant_type: 'instance', 'example' or 'example_batch'.
  """

  def __init__(self,
               file_name: str,
               batch_size: int,
               lagrangex_header: bool = False,
               has_sort_id: bool = False,
               kafka_dump_prefix: bool = False,
               kafka_dump: bool = False,
               variant_type: str = 'instance'):
    self.file_name = file_name
    assert batch_size > 0
    self.batch_size = batch_size
    if self.file_name is None or len(self.file_name) == 0:
      self._stream = sys.stdin.buffer
    else:
      # The payload is binary protobuf data, so the file must be opened in
      # binary mode; GFile's default text mode would try to decode it.
      self._stream = tf.io.gfile.GFile(self.file_name, 'rb')
    self.lagrangex_header = lagrangex_header
    self.has_sort_id = has_sort_id
    self.kafka_dump_prefix = kafka_dump_prefix
    self.kafka_dump = kafka_dump
    self.variant_type = variant_type
    self._curr = 0
    self._max_iter = None

  def __iter__(self):
    return self

  def __next__(self):
    """Returns the next batch as a tensor proto.

    Raises:
      StopIteration: when the iteration limit is reached, the stream is
        exhausted, or reading fails. A batch that cannot be filled completely
        is dropped.
    """
    try:
      self._curr += 1
      if self._max_iter is not None and self._curr > self._max_iter:
        raise StopIteration
      # An example_batch record already holds a whole batch.
      if self.variant_type == 'example_batch':
        num_records = 1
      else:
        num_records = self.batch_size
      pb_items = []
      for _ in range(num_records):
        self._read_header()
        pb_items.append(self._stream.read(self._read_size()))
      return tf.make_tensor_proto(pb_items)
    except Exception:
      # Best effort: any read failure (including EOF) ends the iteration.
      # `Exception` rather than a bare except, so KeyboardInterrupt and
      # SystemExit are not silently converted into StopIteration.
      if self.file_name:
        self._stream.close()
      raise StopIteration

  def _read_size(self) -> int:
    """Reads one 8-byte little-endian unsigned size field.

    Raises:
      EOFError: if fewer than 8 bytes are available.
    """
    size_t = 8
    size_binary = self._stream.read(size_t)
    if len(size_binary) != size_t:
      raise EOFError
    return int.from_bytes(size_binary, byteorder="little")

  def _read_header(self):
    """Consumes the header fields that precede a record's size field.

    * lagrangex_header: one size field is read and discarded.
    * otherwise:
      - kafka_dump_prefix: one size field is read; if it is zero a second
        size field is read and discarded, else its value is kept as the
        sort-id length.
      - has_sort_id: the sort-id length is read (unless already obtained
        from the prefix) and that many bytes are skipped.
      - kafka_dump: one more size field is read and discarded.
    """
    if self.lagrangex_header:
      self._read_size()
      return

    aggregate_page_sortid_size = 0
    if self.kafka_dump_prefix:
      size = self._read_size()
      if size == 0:
        self._read_size()
      else:
        aggregate_page_sortid_size = size
    if self.has_sort_id:
      if aggregate_page_sortid_size == 0:
        size = self._read_size()
      else:
        size = aggregate_page_sortid_size
      self._stream.read(size)  # The sort id itself is not used.
    if self.kafka_dump:
      self._read_size()

  def set_max_iter(self, max_records):
    """Limits the number of `next()` calls so at most `max_records` are read.

    For 'example_batch' every call reads exactly one record, so the limit is
    `max_records` itself. For the other variants every call reads
    `batch_size` records, so the limit is `max_records // batch_size`.
    """
    if self.variant_type == 'example_batch':
      self._max_iter = max_records
    else:
      assert self.batch_size < max_records
      self._max_iter = (max_records // self.batch_size)


def gen_prediction_log_from_file(file_name: str = None,
                                 batch_size: int = 64,
                                 lagrangex_header: bool = False,
                                 kafka_dump_prefix=False,
                                 has_sort_id=True,
                                 kafka_dump=False,
                                 max_records=1000,
                                 variant_type: str = 'instance'):
  """Yields `PredictionLog` protos built from records read by `PBReader`.

  The request input is named 'instances', 'examples' or 'example_batch'
  depending on `variant_type`. Requests cycle through the signature names in
  `FLAGS.signature_names` (with 'serving_default' appended when missing) and
  use `FLAGS.model_name` as the model name.

  Args:
    file_name: input file; None or '' reads from stdin.
    batch_size: records per request (ignored by the reader for
      'example_batch').
    lagrangex_header, kafka_dump_prefix, has_sort_id, kafka_dump: record
      header layout, see `PBReader._read_header`.
    max_records: upper bound forwarded to `PBReader.set_max_iter`.
    variant_type: 'instance', 'example' or 'example_batch'; 'example_batch'
      requires `lagrangex_header` to be True.
  """
  assert variant_type in {'instance', 'example', 'example_batch'}
  if variant_type == 'instance':
    input_name = 'instances'
  elif variant_type == 'example':
    input_name = 'examples'
  else:
    assert lagrangex_header == True
    input_name = 'example_batch'

  reader = PBReader(file_name, batch_size, lagrangex_header, has_sort_id,
                    kafka_dump_prefix, kafka_dump, variant_type)
  reader.set_max_iter(max_records)

  signature_names = [name.strip() for name in FLAGS.signature_names.split(',')]
  if 'serving_default' not in signature_names:
    signature_names.append('serving_default')

  for i, batch in enumerate(reader):
    request = predict_pb2.PredictRequest()
    request.model_spec.name = FLAGS.model_name
    request.model_spec.signature_name = signature_names[i %
                                                        len(signature_names)]
    request.inputs[input_name].CopyFrom(batch)
    log = prediction_log_pb2.PredictionLog(
        predict_log=prediction_log_pb2.PredictLog(request=request))
    yield log


def tf_dtype(dtype: str) -> "tf.compat.v1.dtypes.DType":
  """Maps a dtype name (or numeric code given as a string) to a tf dtype.

  Raises:
    ValueError: if `dtype` is not one of the recognized spellings.
  """
  if dtype in {'int', 'int32', 'short', 'uint', 'uint32', '3', '22'}:
    return tf.int32
  elif dtype in {'int64', 'long', 'uint64', '9', '23'}:
    return tf.int64
  elif dtype in {'float', 'float32', '1'}:
    return tf.float32
  elif dtype in {'float64', 'double', '2'}:
    return tf.float64
  elif dtype in {'bool', 'boolean', '10'}:
    return tf.bool
  elif dtype in {'str', 'string', 'char', '7'}:
    return tf.string
  else:
    raise ValueError(f'{dtype} error')


def main(_):
  """Writes warm-up `PredictionLog` records to `FLAGS.output_path`.

  With `--gen_type=file` the requests are read from `FLAGS.file_name`;
  otherwise they are generated by `gen_prediction_log` from the feature
  flags.
  """
  env_utils.setup_hdfs_env()
  with tf.io.TFRecordWriter(FLAGS.output_path) as writer:
    if FLAGS.gen_type == 'file':
      for log in gen_prediction_log_from_file(
          FLAGS.file_name, FLAGS.batch_size, FLAGS.lagrangex_header,
          FLAGS.kafka_dump_prefix, FLAGS.has_sort_id, FLAGS.kafka_dump,
          FLAGS.max_records, FLAGS.variant_type):
        writer.write(log.SerializeToString())
    else:
      assert FLAGS.sparse_features is not None
      sparse_features = FLAGS.sparse_features.split(',')

      if FLAGS.dense_features is not None:
        dense_features = FLAGS.dense_features.split(',')
        dense_feature_shapes = [
            int(shape) for shape in FLAGS.dense_feature_shapes.split(',')
        ]
        dense_feature_types = [
            tf_dtype(dtype) for dtype in FLAGS.dense_feature_types.split(',')
        ]
      else:
        dense_features = FLAGS.dense_features
        dense_feature_shapes = FLAGS.dense_feature_shapes
        dense_feature_types = FLAGS.dense_feature_types

      if FLAGS.extra_features is not None:
        extra_features = FLAGS.extra_features.split(',')
        # `extra_feature_shapes` is declared with DEFINE_integer, so its value
        # is an int; str() keeps the split working for it.
        extra_feature_shapes = [
            int(shape) for shape in str(FLAGS.extra_feature_shapes).split(',')
        ]
      else:
        extra_features = FLAGS.extra_features
        extra_feature_shapes = FLAGS.extra_feature_shapes

      feature_list = FeatureList.parse(FLAGS.feature_list)
      for log in gen_prediction_log(FLAGS.model_name, sparse_features,
                                    dense_features, dense_feature_shapes,
                                    dense_feature_types, extra_features,
                                    extra_feature_shapes, feature_list,
                                    FLAGS.batch_size, FLAGS.max_records, None,
                                    FLAGS.variant_type, FLAGS.drop_rate):
        writer.write(log.SerializeToString())
def gen_prediction_log(input_folder):
  """Yields one `PredictionLog` per file found in `input_folder`.

  Every file is parsed as a serialized `PredictRequest`; the model name and
  signature name of the request are then overwritten with "default" and
  "serving_default". The value returned by `ParseFromString` is printed.
  """
  for entry_name in tf.io.gfile.listdir(input_folder):
    request_path = os.path.join(input_folder, entry_name)
    with tf.io.gfile.GFile(request_path, 'rb') as request_file:
      serialized_request = compat.as_bytes(request_file.read())

    request = predict_pb2.PredictRequest()
    parse_result = request.ParseFromString(serialized_request)
    print(parse_result)
    request.model_spec.name = "default"
    request.model_spec.signature_name = "serving_default"

    predict_log = prediction_log_pb2.PredictLog(request=request)
    yield prediction_log_pb2.PredictionLog(predict_log=predict_log)
def monolith_export(obj):
  """Marks `obj` as exported and hands it back unchanged.

  A no-op decorator: it only sets the `__monolith_doc` attribute of the
  decorated class/function to None.
  """
  setattr(obj, "__monolith_doc", None)
  return obj
syntax="proto2";

package monolith;

// Serialized form of a MultiHashTable python object. It is filled in by
// MultiHashTable.to_proto and read back by MultiHashTable.from_proto in
// multi_hash_table_ops.py.
message MultiHashTableProto {
  // Embedding dim sizes of the table.
  repeated int32 dims = 1;
  // Serialized slot expire time config of the table.
  optional bytes slot_expire_time_config = 2;
  // Names of the sub tables.
  repeated string table_names = 3;
  // Graph name of the learning rate tensor (export scope stripped).
  optional string learning_rate_tensor = 4;
  // Shared name of the underlying hash table resource.
  optional string shared_name = 5;
  // Value passed as `nshards` when the table is saved.
  optional int32 saver_parallel = 6;
  // Graph name of the table initializer op (export scope stripped).
  optional string initializer_op = 7;
  // Graph name of the table handle tensor (export scope stripped).
  optional string handle_tensor = 8;
}
class CachedConfig(NamedTuple):
  """Cache the config object to reduce the graph size.

  Instances are produced by `convert_to_cached_config`, which derives every
  field below from `configs`.
  """
  # Original configs, keyed by table name.
  configs: Dict[str, entry.HashTableConfigInstance]

  # Generated data
  # The table names, sorted.
  table_names: Tuple[str, ...]
  # The multi_hash_table serialized config
  # (a serialized MultiEmbeddingHashTableConfig).
  mconfig: bytes
  # Tensor holding `mconfig`.
  # table creation may request the data from other devices.
  mconfig_tensor: tf.Tensor
  # The dim size in each config to make tf function reused.
  # Ordered like `table_names`.
  dims: Tuple[int, ...]
  # This is generated from mconfig: the serialized slot_expire_time_config of
  # a table config.
  slot_expire_time_config: bytes
class RawMultiTypeHashTable(abc.ABC):
  """Raw lookup API that minimizes the overhead of transferring data between
  different devices.

  Ids of all sub tables travel as one ragged tensor and values as one flat
  tensor; the `get_*` helpers convert between that form and per-slot dicts.
  """

  @abc.abstractmethod
  def get_ragged_id(self, slot_to_id: Dict[str, tf.Tensor]) -> tf.RaggedTensor:
    """Converts ids to a single ragged id. Graph independent."""

  @abc.abstractmethod
  def get_flat_value(self, slot_to_value: Dict[str, tf.Tensor]) -> tf.Tensor:
    """Converts values to a single float tensor. Graph independent."""

  @abc.abstractmethod
  def get_embeddings(self, ragged_id: tf.RaggedTensor,
                     value: tf.Tensor) -> Dict[str, tf.Tensor]:
    """Converts returned flat value into the dict of embeddings. Graph independent."""

  @abc.abstractmethod
  def raw_lookup(self, ragged_id: tf.RaggedTensor) -> tf.Tensor:
    """Looks up `ragged_id` and returns the embeddings as one flat tensor."""

  @abc.abstractmethod
  def raw_apply_gradients(self, ragged_id: tf.RaggedTensor,
                          flat_grad: tf.Tensor, global_step: tf.Tensor, *args,
                          **kwargs) -> "RawMultiTypeHashTable":
    """Applies the flat gradients for `ragged_id` and returns the updated table."""

  # This is an instance method, so it is declared with `abstractmethod`;
  # `abc.abstractclassmethod` is deprecated and would bind the class instead.
  @abc.abstractmethod
  def raw_assign(self, ragged_id: tf.RaggedTensor, flat_value: tf.Tensor,
                 *args, **kwargs):
    """Assigns the flat values to `ragged_id`."""
def __init__(self,
             cc: CachedConfig = None,
             hash_filter: tf.Tensor = None,
             sync_client: tf.Tensor = None,
             learning_rate_list: List[tf.Tensor] = None,
             name_suffix: str = "",
             saver_parallel: int = -1,
             table_proto: multi_hash_table_ops_pb2.MultiHashTableProto = None,
             import_scope: str = None,
             device='/device:CPU:0'):
  """Creates the table ops, or rebuilds the object from `table_proto`.

  Args:
    cc: cached table config; read only when `table_proto` is None.
    hash_filter: passed to the create op as `filter_handle`.
    sync_client: passed to the create op as `sync_client_handle`.
    learning_rate_list: tensors that are stacked and cast to float32 to form
      the table's learning rate tensor.
    name_suffix: joined with NAME_PREFIX by "_" to form the shared name.
    saver_parallel: stored for later use when the table is saved.
    table_proto: when given, every other argument except `import_scope` is
      ignored and the state is taken from the proto.
    import_scope: name scope used to resolve the names stored in
      `table_proto`.
    device: device on which the table ops are created.
  """
  if table_proto is not None:
    # Restoring path: nothing below runs, including the collection append.
    self._init_from_proto(table_proto, import_scope)
    return
  self._dims = cc.dims
  self._slot_expire_time_config = cc.slot_expire_time_config
  self._table_names = cc.table_names
  self._learning_rate = tf.cast(tf.stack(learning_rate_list), tf.float32)
  self._saver_parallel = saver_parallel
  self._shared_name = "_".join([MultiHashTable.NAME_PREFIX, name_suffix])
  # Rejects a shared name that was already registered.
  self._check_and_insert_name(self._shared_name)
  self._device = device
  with tf.device(device):
    # We separate the table creation and use by using a dummy var.
    init_handle = hash_table_ops.create_monolith_multi_hash_table(
        filter_handle=hash_filter,
        sync_client_handle=sync_client,
        config=cc.mconfig_tensor,
        shared_name=self._shared_name)
    self._initializer = init_handle.op
    # The handle used by the other ops is read back through the shared name.
    self._handle = hash_table_ops.read_monolith_multi_hash_table(
        shared_name=self._shared_name)
    self._is_initialized = hash_table_ops.is_hash_table_initialized(
        init_handle)
    resources.register_resource(init_handle, self._initializer,
                                self._is_initialized)
  # The saver/restorer listeners find the tables through this collection.
  tf.compat.v1.get_collection_ref(_MULTI_HASH_TABLE_GRAPH_KEY).append(self)
@classmethod
def from_cached_config(cls,
                       cc: CachedConfig,
                       hash_filter: tf.Tensor = None,
                       sync_client: tf.Tensor = None,
                       name_suffix: str = "",
                       saver_parallel: int = -1):
  """Builds a MultiHashTable from a `CachedConfig`.

  The table is placed on "/device:GPU:0" when the first table config has
  type "gpucuco", otherwise on "/device:CPU:0". Dummy hash filter / sync
  client are created when none are given.

  Args:
    cc: cached config produced by `convert_to_cached_config`.
    hash_filter: optional hash filter handle.
    sync_client: optional sync client handle.
    name_suffix: suffix of the table's shared name.
    saver_parallel: forwarded to the constructor.

  Raises:
    ValueError: if a table has a different number of learning rate
      functions than entry config segments.
  """
  table_config = next(iter(cc.configs.values())).table_config
  assert table_config.HasField("type")
  table_type = table_config.WhichOneof("type")
  logging.info("Hash table type: {}".format(table_type))
  use_gpu = table_type == "gpucuco"
  d = "/device:GPU:0" if use_gpu else "/device:CPU:0"
  with tf.device(d):
    if hash_filter is None:
      hash_filter = hash_filter_ops.create_dummy_hash_filter(
          name_suffix=name_suffix)
    if sync_client is None:
      sync_client = distributed_serving_ops.create_dummy_sync_client()

  learning_rate_list = []
  # Iterate in the same (sorted) order that was used to build cc.mconfig and
  # cc.dims, so the learning rates line up with their tables even when
  # cc.configs was not inserted in sorted order.
  for table_name in cc.table_names:
    config = cc.configs[table_name]
    if len(config.learning_rate_fns) != len(
        config.table_config.entry_config.segments):
      raise ValueError(
          "Size of learning_rate_fns and size of segments must be equal.")
    learning_rate_list.extend(config.call_learning_rate_fns_fewer_ops())

  if tf.compat.v1.get_default_graph() != cc.mconfig_tensor.graph:
    # In this case, we can't reuse mconfig_tensor
    cc = cc._replace(mconfig_tensor=tf.convert_to_tensor(cc.mconfig))
  return cls(cc=cc,
             hash_filter=hash_filter,
             sync_client=sync_client,
             learning_rate_list=learning_rate_list,
             name_suffix=name_suffix,
             saver_parallel=saver_parallel,
             device=d)
def assign(self,
           slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]],
           req_time: tf.Tensor = None,
           enable_inter_table_parallelism: bool = False) -> "MultiHashTable":
  """Assigns values to ids and returns the updated table.

  Args:
    slot_to_id_and_value: maps a table name to an (ids, values) pair; ids
      are converted to int64 and values to float32 when they are not
      already tensors.
    req_time: update time forwarded to `raw_assign`; when None,
      `raw_assign` substitutes an int64 constant 0.
    enable_inter_table_parallelism: accepted for interface compatibility;
      not used by this implementation.
  """
  ragged_id = self.get_ragged_id(
      {k: _convert_to_int64(v[0]) for k, v in slot_to_id_and_value.items()})
  flat_value = self.get_flat_value(
      {k: _convert_to_float32(v[1]) for k, v in slot_to_id_and_value.items()})
  # Forward req_time instead of silently dropping it.
  return self.raw_assign(ragged_id, flat_value, req_time)
def lookup(self, slot_to_id: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Looks up embeddings for the given ids.

  Args:
    slot_to_id: maps a table name to the ids to look up in that table.

  Returns:
    A dict from table name to embeddings, restricted to the names present
    in `slot_to_id`.
  """
  int64_ids = {}
  for slot, ids in slot_to_id.items():
    int64_ids[slot] = _convert_to_int64(ids)
  ragged_id = self.get_ragged_id(int64_ids)
  all_embeddings = self.get_embeddings(ragged_id, self.raw_lookup(ragged_id))

  # get_embeddings yields an entry per table; keep the requested ones only.
  requested = {}
  for slot, embedding in all_embeddings.items():
    if slot in slot_to_id:
      requested[slot] = embedding
  return requested
# Fused lookup entry point: the caller passes the ids directly and does not
# have to map them to their slots first.
def fused_lookup(self,
                 ids: tf.Tensor,
                 fused_slot_size: tf.Tensor,
                 num_of_shards: int,
                 req_time=None) -> Tuple[tf.Tensor]:
  """Runs the fused lookup op on this table and returns its outputs.

  `req_time` defaults to an int64 constant 0 when it is not provided.
  """
  request_time = (tf.constant(0, dtype=tf.int64)
                  if req_time is None else req_time)
  fused_outputs = hash_table_ops.monolith_multi_hash_table_fused_lookup(
      mtable=self._handle,
      ids=ids,
      fused_slot_size=fused_slot_size,
      num_of_shards=num_of_shards,
      req_time=request_time)
  return fused_outputs
def raw_assign(self,
               ragged_id: tf.RaggedTensor,
               flat_value: tf.Tensor,
               req_time: tf.Tensor = None):
  """Assigns `flat_value` to the ids in `ragged_id`.

  Returns a copy of this table that wraps the handle produced by the assign
  op. `req_time` defaults to an int64 constant 0.
  """
  logging.info(f"raw_assign {self._handle}")
  update_time = req_time
  if update_time is None:
    update_time = tf.constant(0, dtype=tf.int64)
  assigned_handle = hash_table_ops.monolith_multi_hash_table_assign(
      mtable=self._handle,
      id=ragged_id.values,
      id_split=ragged_id.row_splits,
      value=flat_value,
      update_time=update_time)
  return self._copy_with_new_table(assigned_handle)
def before_save(self, sess, global_step_value):
  """Saves every registered MultiHashTable, then optionally writes CkptInfo.

  We use before save so the checkpoint file is updated after we
  successfully save the hash table.

  Args:
    sess: session used to run the save ops.
    global_step_value: global step, used to derive the checkpoint prefix and
      the `ckpt.info-<step>` file name.

  Raises:
    ValueError: if a table reports a different number of features and counts.
  """
  logging.info("Starting saving MultiHashTables.")
  feed_dict = {}
  base_dir = self._helper.get_ckpt_asset_dir(
      self._helper.get_ckpt_prefix(global_step_value))
  tf.io.gfile.makedirs(base_dir)
  # Feed each table's basename placeholder (created in _build_save_graph).
  for table in tf.compat.v1.get_collection(_MULTI_HASH_TABLE_GRAPH_KEY):
    # Plain concatenation: presumably base_dir ends with a path separator --
    # TODO confirm against SaveHelper.get_ckpt_asset_dir.
    table_basename = base_dir + table.shared_name
    feed_dict.update(
        {self._table_id_to_placeholder[id(table)]: table_basename})
  sess.run(self._save_op,
           feed_dict=feed_dict,
           options=tf.compat.v1.RunOptions(timeout_in_ms=_TIMEOUT_IN_MS))
  logging.info("Finished saving MultiHashTables.")

  if self._write_ckpt_info:
    logging.info("Start collecting slot fid count.")
    # NOTE(review): the feature stat ops are built under a control dependency
    # on the save op, so this run appears to execute the save again -- confirm.
    features_counts_list = sess.run(fetches=self._features_counts_tuples,
                                    feed_dict=feed_dict)
    logging.info("Start writing CkptInfo.")
    # Sum the counts of a feature over all tables.
    feature_to_fid_count = collections.defaultdict(int)
    for features_counts in features_counts_list:
      features = features_counts[0].tolist()
      counts = features_counts[1].tolist()
      if not len(features) == len(counts):
        raise ValueError(
            "Number of features [{}] does not match number of fid counts [{}]"
            .format(len(features), len(counts)))
      for feature, count in zip(features, counts):
        feature_to_fid_count[feature] += count
    info = ckpt_info_pb2.CkptInfo()
    for feature, count in feature_to_fid_count.items():
      info.feature_counts[feature] = count
    ckpt_dir = os.path.dirname(self._helper._basename)
    # The proto is written as its str() representation.
    with tf.io.gfile.GFile(
        os.path.join(ckpt_dir, f"ckpt.info-{global_step_value}"), "w") as f:
      f.write(str(info))
    logging.info("Finished writing CkptInfo.")
class MultiHashTableCheckpointRestorerListener(
    basic_restore_hook.CheckpointRestorerListener):
  """Restores the hash tables from basename

  The restore ops for every table in the multi hash table collection are
  built once at construction time; `before_restore` feeds them with the asset
  paths of the latest checkpoint.
  """

  def __init__(self, basename: str, ps_monitor=None):
    """
    Args:
      basename: checkpoint basename; its directory is searched for the
        latest checkpoint.
      ps_monitor: optional object with `is_ps_uninitialized(session, device)`;
        when given, only tables on devices it reports as uninitialized are
        restored.
    """
    super().__init__()
    self._basename = basename
    self._ps_monitor = ps_monitor
    self._helper = save_utils.SaveHelper(basename)
    # Maps id(table) to the string placeholder holding its basename.
    self._table_id_to_placeholder = {}
    self._restore_ops_per_device = self._build_restore_graph()

  def before_restore(self, session):
    """
    We use before restore so as to strictly control the order of restorer listeners.

    Does nothing when no checkpoint is found. Otherwise runs the initializer
    of every table and then the selected restore ops.
    """
    ckpt_prefix = tf.train.latest_checkpoint(os.path.dirname(self._basename))
    if not ckpt_prefix:
      logging.info("No checkpoint found in %s. Skip the hash tables restore.",
                   self._basename)
      return
    logging.info("Restore hash tables from %s.", ckpt_prefix)
    asset_dir = self._helper.get_ckpt_asset_dir(ckpt_prefix)
    init_ops = []
    feed_dict = {}
    for mtable in tf.compat.v1.get_collection(_MULTI_HASH_TABLE_GRAPH_KEY):
      init_ops.append(mtable.initializer)
      table_basename = asset_dir + mtable.shared_name
      feed_dict.update(
          {self._table_id_to_placeholder[id(mtable)]: table_basename})
    # Tables are initialized before anything is restored into them.
    session.run(init_ops)
    restore_ops_all = []
    for device, restore_ops in self._restore_ops_per_device.items():
      # Without a monitor every device is restored.
      if not self._ps_monitor or self._ps_monitor.is_ps_uninitialized(
          session, device):
        restore_ops_all.extend(restore_ops)
    session.run(restore_ops_all,
                feed_dict=feed_dict,
                options=tf.compat.v1.RunOptions(timeout_in_ms=_TIMEOUT_IN_MS))
    logging.info("Finished restore.")

  def _build_restore_graph(self):
    """Creates one restore op per table, grouped by the table handle's device.

    Returns:
      A dict from device string to the list of restore ops on that device.
    """
    restore_ops_per_device = collections.defaultdict(list)
    for table in ops.get_collection(_MULTI_HASH_TABLE_GRAPH_KEY):
      table_basename = tf.compat.v1.placeholder(tf.string, shape=[])
      self._table_id_to_placeholder.update({id(table): table_basename})
      restore_op = table.restore(basename=table_basename).as_op()
      restore_ops_per_device[table.handle.device].append(restore_op)
    return restore_ops_per_device
# Register how items of the multi hash table collection are converted
# to/from MultiHashTableProto.
ops.register_proto_function(
    _MULTI_HASH_TABLE_GRAPH_KEY,
    proto_type=multi_hash_table_ops_pb2.MultiHashTableProto,
    to_proto=MultiHashTable.to_proto,
    from_proto=MultiHashTable.from_proto)

# No gradient is defined for the initialization check op.
ops.NotDifferentiable("IsHashTableInitialized")
from absl.testing import parameterized
import hashlib
import os
from typing import Dict, List
from unittest import mock

import tensorflow as tf
import tensorflow.python.ops.resources as resources

from monolith.native_training import basic_restore_hook
from monolith.native_training import entry
from monolith.native_training import learning_rate_functions
from monolith.native_training import multi_hash_table_ops
from monolith.native_training import test_utils


def _id(x):
  """Wraps `x` in an int64 constant (used for ids)."""
  return tf.constant(x, dtype=tf.int64)


def _value(x):
  """Wraps `x` in a float32 constant (used for embedding values)."""
  return tf.constant(x, dtype=tf.float32)


def from_configs(configs, *args, **kwargs):
  """We do a serialization/deserialization to make sure it worked in all cases."""
  with tf.name_scope("scope") as scope:
    table = multi_hash_table_ops.MultiHashTable.from_configs(
        configs, *args, **kwargs)
    # Round-trip through the proto so the table under test is the
    # deserialized one.
    proto = table.to_proto(export_scope=scope)
    table = multi_hash_table_ops.MultiHashTable.from_proto(proto,
                                                           import_scope=scope)
    return table


class MultiTypeHashTableTest(tf.test.TestCase, parameterized.TestCase):
  """Graph-mode tests for `multi_hash_table_ops.MultiHashTable`."""

  def test_lookup_assign_add_reinitialize(self):
    """assign_add then lookup returns the added values; reinitialize zeroes ids."""
    multi_table = multi_hash_table_ops.MultiHashTable.from_configs(
        configs={
            "slot0": test_utils.generate_test_hash_table_config(1),
            "not_used": test_utils.generate_test_hash_table_config(2),
            "slot1": test_utils.generate_test_hash_table_config(2),
            "slot2": test_utils.generate_test_hash_table_config(2),
        })
    multi_table = multi_table.assign_add(
        slot_to_id_and_value={
            "slot0": (_id([0]), _value([[1]])),
            "slot1": (_id([1]), _value([[2, 2]])),
            "slot2": (_id([2, 3]), _value([[4, 4], [8, 8]]))
        })
    values_dict = multi_table.lookup(slot_to_id={
        "slot0": _id([0]),
        "slot1": _id([1]),
        "slot2": _id([2, 3]),
    })
    # Only the run call lives inside the session block; the ops created
    # further down have to be built after the session has been closed.
    with tf.compat.v1.train.SingularMonitoredSession() as sess:
      values_dict = sess.run(values_dict)
    expected_values_dict = {
        "slot0": [[1]],
        "slot1": [[2, 2]],
        "slot2": [[4, 4], [8, 8]]
    }
    for slot, values in values_dict.items():
      self.assertAllEqual(values, expected_values_dict[slot])

    # "slot2" exists in the table; "slot3" was never configured.
    multi_table, status1 = multi_table.reinitialize("slot2", _id([1, 2, 3]))
    multi_table, status2 = multi_table.reinitialize("slot3", _id([1, 2, 3]))
    values_dict = multi_table.lookup(slot_to_id={
        "slot0": _id([0]),
        "slot1": _id([1]),
        "slot2": _id([1, 2, 3]),
    })
    with tf.compat.v1.train.SingularMonitoredSession() as sess:
      values_dict, status1, status2 = sess.run([values_dict, status1, status2])
    expected_values_dict = {
        "slot0": [[1]],
        "slot1": [[2, 2]],
        "slot2": [[0, 0], [0, 0], [0, 0]]
    }
    for slot, values in values_dict.items():
      self.assertAllEqual(values, expected_values_dict[slot])
    # Expected per-id status codes: id 1 was absent from slot2 (0), ids 2 and
    # 3 were present (1); the unknown slot yields -1 for every id.
    # NOTE(review): code meanings inferred from this test only -- confirm
    # against the reinitialize op.
    self.assertAllEqual(status1, [0, 1, 1])
    self.assertAllEqual(status2, [-1, -1, -1])

  def test_apply_gradients(self):
    """Applying gradients to fresh ids yields the negated gradients."""
    with self.session() as sess:
      multi_table = from_configs(
          configs={
              "slot0": test_utils.generate_test_hash_table_config(1),
              "slot1": test_utils.generate_test_hash_table_config(2),
          })
      sess.run(multi_table.initializer)
      values_dict = multi_table.lookup(slot_to_id={
          "slot0": _id([0]),
          "slot1": _id([1, 2]),
      })
      grads = [tf.constant(2.0), tf.constant([[1.0, 3.0], [2.0, 4.0]])]
      global_step = tf.constant(0, dtype=tf.int64)
      multi_table = multi_table.apply_gradients(slot_to_id_and_grad={
          "slot0": (_id([0]), grads[0]),
          "slot1": (_id([1, 2]), grads[1]),
      },
                                                global_step=global_step)
      values_dict = multi_table.lookup(slot_to_id={
          "slot0": _id([0]),
          "slot1": _id([1, 2]),
      })
      values_dict = sess.run(values_dict)
      expected_dict = {"slot0": [[-2]], "slot1": [[-1, -3], [-2, -4]]}
      for key in expected_dict:
        self.assertAllEqual(values_dict[key], expected_dict[key])

  def test_save_restore(self):
    """Saves from one graph and restores into a table with other slots."""
    with tf.Graph().as_default(), self.session() as sess:
      table_0 = from_configs(
          configs={
              "slot0": test_utils.generate_test_hash_table_config(1),
              "slot1": test_utils.generate_test_hash_table_config(2),
              "slot2": test_utils.generate_test_hash_table_config(2),
          })
      table_0 = table_0.assign_add(
          slot_to_id_and_value={
              "slot0": (_id([0, 1]), _value([[1], [2]])),
              "slot1": (_id([2, 3, 4, 5]),
                        _value([[1, 2], [2, 3], [3, 4], [4, 5]])),
              "slot2": (_id([6, 7, 8, 9, 10]),
                        _value([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]))
          })
      basename = os.path.join(os.environ["TEST_TMPDIR"], "test_save_restore",
                              table_0.shared_name)
      table_0 = table_0.save(basename)
      sess.run(table_0.initializer)
      sess.run(table_0.as_op())

    # The restoring table drops "slot1" and adds "slot3"; only the slots
    # present in both tables are looked up and checked.
    with tf.Graph().as_default(), self.session() as sess:
      table_1 = from_configs(
          configs={
              "slot0": test_utils.generate_test_hash_table_config(1),
              "slot2": test_utils.generate_test_hash_table_config(2),
              "slot3": test_utils.generate_test_hash_table_config(3),
          })
      table_1 = table_1.restore(basename)
      values_dict = table_1.lookup(slot_to_id={
          "slot0": _id([0, 1]),
          "slot2": _id([6, 7, 8, 9, 10]),
      })
      sess.run(table_1.initializer)
      values_dict = sess.run(values_dict)
      expected_values_dict = {
          "slot0": [[1], [2]],
          "slot2": [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
      }
      for slot, values in values_dict.items():
        self.assertAllEqual(values, expected_values_dict[slot])

  def test_save_restore_hook(self):
    """The saver/restorer listeners round-trip the table through a checkpoint."""
    basename = os.path.join(os.environ["TEST_TMPDIR"], "test_save_restore_hook",
                            "model.ckpt")
    table = from_configs(
        configs={
            "slot0": test_utils.generate_test_hash_table_config(1),
            "slot1": test_utils.generate_test_hash_table_config(2),
            "slot2": test_utils.generate_test_hash_table_config(2),
        })
    add_op = table.assign_add(
        slot_to_id_and_value={
            "slot0": (_id([0]), _value([[1]])),
            "slot1": (_id([1]), _value([[2, 2]])),
            "slot2": (_id([2, 3]), _value([[4, 4], [8, 8]]))
        }).as_op()
    sub_op = table.assign_add(
        slot_to_id_and_value={
            "slot0": (_id([0]), _value([[-1]])),
            "slot1": (_id([1]), _value([[-2, -3]])),
            "slot2": (_id([2, 3]), _value([[-4, -5], [-6, -7]]))
        }).as_op()
    values_dict = table.lookup(slot_to_id={
        "slot0": _id([0]),
        "slot1": _id([1]),
        "slot2": _id([2, 3]),
    })
    saver_listener = multi_hash_table_ops.MultiHashTableCheckpointSaverListener(
        basename)
    # We need to create some variables to make saver happy.
    tf.compat.v1.train.create_global_step()
    saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(),
                                     sharded=True,
                                     max_to_keep=10,
                                     keep_checkpoint_every_n_hours=2)
    saver_hook = tf.estimator.CheckpointSaverHook(os.path.dirname(basename),
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])
    restorer_listener = multi_hash_table_ops.MultiHashTableCheckpointRestorerListener(
        basename)
    restore_hook = basic_restore_hook.CheckpointRestorerHook(
        listeners=[restorer_listener])

    with self.session() as sess:
      saver_hook.begin()
      sess.run(tf.compat.v1.global_variables_initializer())
      resources.initialize_resources(resources.shared_resources()).run()
      # In the estimator API, graph will be finalized before calling hook
      g = tf.compat.v1.get_default_graph()
      g.finalize()
      sess.run(add_op)
      saver_hook.after_create_session(sess, None)
      sess.run(sub_op)
      # restore will override sub_op
      restore_hook.after_create_session(sess, None)
      values_dict = sess.run(values_dict)
      # The expected values are the ones written by add_op only.
      expected_values_dict = {
          "slot0": [[1]],
          "slot1": [[2, 2]],
          "slot2": [[4, 4], [8, 8]]
      }
      for slot, values in values_dict.items():
        self.assertAllEqual(values, expected_values_dict[slot])

  def test_meta_graph_export(self):
    """The table collection shows up in an exported meta graph."""
    multi_table = from_configs(
        configs={
            "slot0": test_utils.generate_test_hash_table_config(1),
            "slot1": test_utils.generate_test_hash_table_config(2),
        })
    meta = tf.compat.v1.train.export_meta_graph()
    self.assertIn(multi_hash_table_ops._MULTI_HASH_TABLE_GRAPH_KEY,
                  meta.collection_def)


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

================================================
FILE: monolith/native_training/multi_type_hash_table.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import abc from collections import defaultdict import copy import dataclasses import hashlib import itertools import re from typing import Callable, Dict, Iterable, List, Tuple from absl import logging import tensorflow as tf from monolith.native_training import device_utils from monolith.native_training import distribution_ops from monolith.native_training import entry from monolith.native_training import hash_table_ops from monolith.native_training import utils from monolith.native_training import prefetch_queue from monolith.native_training.hash_table_utils import infer_dim_size class BaseMultiTypeHashTable(abc.ABC): # https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations # Allow nested instances. _table: BaseMultiTypeHashTable _tables: List[BaseMultiTypeHashTable] # Allow pipelined graph execution. 
_local_queue_hooks: List[prefetch_queue.EnqueueHook | prefetch_queue.AsyncPushHook] @abc.abstractmethod def lookup(self, slot_to_id: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: pass @abc.abstractmethod def assign( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> BaseMultiTypeHashTable: pass @abc.abstractmethod def assign_add( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> BaseMultiTypeHashTable: pass @abc.abstractmethod def reinitialize(self, slot: str, ids: tf.Tensor) -> Tuple[BaseMultiTypeHashTable, tf.Tensor]: pass @abc.abstractmethod def apply_gradients(self, slot_to_id_and_grad: Dict[str, Tuple[tf.Tensor, tf.Tensor]], *args) -> BaseMultiTypeHashTable: pass @abc.abstractmethod def as_op(self, name=None) -> tf.Operation: pass def add_queue_hook(self, hook): # Allow pipelined graph execution. if not getattr(self, "_local_queue_hooks", None): self._local_queue_hooks = [] self._local_queue_hooks.append(hook) def get_queue_hooks(self): hooks = copy.copy(getattr(self, "_local_queue_hooks", [])) if getattr(self, "_table", None): hooks.extend(self._table.get_queue_hooks()) if getattr(self, "_tables", None): hooks.extend( itertools.chain.from_iterable( [t.get_queue_hooks() for t in self._tables])) return hooks @abc.abstractmethod def get_table_dim_sizes(self) -> List[int]: pass # TODO(leqi.zou): Makes this have a better name. class MultiTypeHashTable(BaseMultiTypeHashTable): """ A hash tables that support different types of embeddings (they may have different dimensions/ optimizers). Different types are distinguished by "Slot". Slot is the type of ids, the embeddings in the same slot has the same dimension. The functionality are same as BaseHashTable. The only difference is that now the input is a map, which maps slot to ids/values. 
hash_table_factory has two params: name_suffix & hash_table_config """ def __init__( self, slot_to_config: Dict[str, entry.HashTableConfigInstance], hash_table_factory: Callable[[str, entry.HashTableConfigInstance], hash_table_ops.BaseHashTable]): self._slot_to_config = slot_to_config self._hash_tables = {} self._hash_table_resources = [] learning_rate_tensors = [] for slot in sorted(self._slot_to_config.keys()): # We need to keep the order here. config = self._slot_to_config[slot] self._hash_tables[slot] = hash_table_factory(slot, config) # Here we setup the hashtable resources based on # self._slot_to_config.keys() self._hash_table_resources.append(self._hash_tables[slot].as_op()) learning_rate_tensors.append(config.call_learning_rate_fns()) # Build flattened learning rate tensor for fused apply gradient. self._learning_rate_tensors = tf.concat(learning_rate_tensors, 0) def lookup(self, slot_to_id: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: slot_to_embedding = {} for slot, id in slot_to_id.items(): embedding = self._hash_tables[slot].lookup(id) slot_to_embedding[slot] = embedding return slot_to_embedding def assign( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> MultiTypeHashTable: return self._update("assign", slot_to_id_and_value) def assign_add( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> MultiTypeHashTable: return self._update("assign_add", slot_to_id_and_value) def reinitialize(self, slot: str, ids: tf.Tensor) -> Tuple[MultiTypeHashTable, tf.Tensor]: raise NotImplementedError( "MultiTypeHashTable dost not support reinitialize!") def apply_gradients( self, slot_to_id_and_grad: Dict[str, Tuple[tf.Tensor, tf.Tensor]], *args, **kwargs, ) -> MultiTypeHashTable: return self._update("apply_gradients", slot_to_id_and_grad, *args, **kwargs) def _update( self, method_name: str, slot_to_id_and_tensor: Dict[str, Tuple[tf.Tensor, tf.Tensor]], *args, **kwargs, ) -> MultiTypeHashTable: updated_tables = 
dict(self._hash_tables) for slot, (id, tensor) in slot_to_id_and_tensor.items(): updated_tables[slot] = getattr(self._hash_tables[slot], method_name)(id, tensor, *args, **kwargs) return self._copy_with_new_tables(updated_tables) def as_op(self, name=None) -> tf.Operation: name = name or "mtht_ao" with tf.control_dependencies( [table.as_op() for table in self._hash_tables.values()]): c = tf.no_op(name=("{}/done".format(name))) return c def _copy_with_new_tables( self, tables: Dict[int, tf.Tensor]) -> "MultiTypeHashTable": copied = copy.copy(self) # Update the hash_table_resources everytime when there is a table update. hash_table_resources = [] for slot in sorted(self._slot_to_config.keys()): hash_table_resources.append(tables[slot].as_op()) copied.__dict__["_hash_tables"] = tables copied.__dict__["_hash_table_resources"] = hash_table_resources return copied # This is a very concise API that supports fused lookup, without mapping the # IDs to its slots. def fused_lookup(self, ids: tf.Tensor, fused_slot_size: tf.Tensor, num_of_shards: int, req_time=None) -> Tuple[tf.Tensor]: if req_time is None: req_time = tf.constant(0, dtype=tf.int64) return hash_table_ops.fused_lookup(self._hash_table_resources, ids, fused_slot_size, num_of_shards, req_time) # This is a very concise API that supports fused optimize, without mapping the # IDs to its slots. 
def fused_apply_gradient( self, ids: tf.Tensor, indices: tf.Tensor, fused_slot_size: tf.Tensor, id_grads: tf.Tensor, id_offsets: tf.Tensor, grad_offsets: tf.Tensor, global_step: tf.Tensor, req_time: tf.Tensor, num_of_shards: int, enable_grad_accumulation: bool = False) -> MultiTypeHashTable: table_handles_output = hash_table_ops.fused_apply_gradient( self._hash_table_resources, ids, indices, fused_slot_size, id_grads, id_offsets, grad_offsets, self._learning_rate_tensors, req_time, global_step, num_of_shards, enable_grad_accumulation) copied = copy.copy(self) updated_tables = dict(self._hash_tables) for slot, handle in zip(sorted(self._slot_to_config.keys()), table_handles_output): updated_tables[slot] = self._hash_tables[slot].table_update(handle) copied.__dict__["_hash_tables"] = updated_tables copied.__dict__["_hash_table_resources"] = table_handles_output return copied def get_table_dim_sizes(self) -> List[int]: return [ self._hash_tables[slot].dim_size for slot in sorted(self._slot_to_config.keys()) ] @dataclasses.dataclass class _IndexedValues: """ _IndexedValues represents tensors merged from multiple slots. slots are a list of string represents where values are coming from indices are a tensor of 1-D int64 ranged in [0, len(slots)) represents the slot of that value is slots[index] values are a tensor represents a list tensors which merged from multiple slots. """ slots: List[tf.Tensor] index: tf.Tensor value: tf.Tensor class MergedMultiTypeHashTable(BaseMultiTypeHashTable): """A decorator that merge slots which have the same embedding config. This helps reduce the size of graph. 
However, the caller need to make sure ids in different slots are different.""" def __init__(self, slot_to_config: Dict[str, entry.HashTableConfigInstance], factory: Callable[[Dict[str, entry.HashTableConfigInstance]], BaseMultiTypeHashTable]): self._slot_to_config = slot_to_config logging.info( "Create MergedMultiTypeHashTable: 1) reverse feature_name -> config into config -> feature_name_list" ) self._slot_mapping: Dict[str, str] = {} # feature/slot -> merged_slot deduped_config_to_slots = defaultdict(list) for slot in sorted(slot_to_config): config = slot_to_config[slot] # Use str of config as the key for merging slots. deduped_config_to_slots[str(config)].append(slot) logging.info( "Create MergedMultiTypeHashTable: 2) gen merged_slot and map merged_slot -> config" ) merged_slot_to_config = {} for key, slots in deduped_config_to_slots.items(): def get_merged_str(strs: List[str]): concat = ",".join(strs) return concat, hashlib.md5(concat.encode()).hexdigest() slots_str, merged_slot = get_merged_str(slots) logging.info("Merged '{}' into '{}'".format(slots_str, merged_slot)) merged_config = copy.copy(slot_to_config[slots[0]]) # replace "fc_slot_*" to "slot_*" old_slots = [ slot[3:] if re.match("^fc_slot_[0-9]*$", slot) else slot for slot in slots ] _, old_merged_slot = get_merged_str(old_slots) if old_merged_slot != merged_slot: merged_config.extra_restore_names.append(old_merged_slot) merged_slot_to_config[merged_slot] = merged_config for slot in slots: self._slot_mapping[slot] = merged_slot self._merged_slot_to_config = merged_slot_to_config logging.info( f"Create MergedMultiTypeHashTable: 3) sub hash tables {factory}") self._table = factory(merged_slot_to_config) @property def slot_mapping(self): """Returns slot mapping.""" return self._slot_mapping def lookup(self, slot_to_id: Dict[str, tf.Tensor], auxiliary_bundle=None, early_reorder_indicies_res_pack=None): if auxiliary_bundle is None: auxiliary_bundle = {} if early_reorder_indicies_res_pack: 
merged_slot_to_sizes, res_pack = early_reorder_indicies_res_pack merged_slot_to_id = { k: None for k in merged_slot_to_sizes.keys() # None to keep interface, we will only use the keys } auxiliary_bundle['merged_slot_to_sizes'] = merged_slot_to_sizes merged_slot_to_embedding, auxiliary_bundle = self._table.lookup( merged_slot_to_id, auxiliary_bundle, res_pack) else: logging.info( "Lookup MergedMultiTypeHashTable: 1) merged the features ids belong the same hash table" ) merged_slot_to_id, merged_slot_to_sizes = self._get_merged_to_indexed_tensor( slot_to_id) auxiliary_bundle['merged_slot_to_sizes'] = merged_slot_to_sizes logging.info( f"Lookup MergedMultiTypeHashTable: 2) lookup sub hash table {self._table} for embeddings" ) merged_slot_to_embedding = self._table.lookup(merged_slot_to_id) merged_slot_to_slots = defaultdict(list) for k in sorted(self._slot_mapping.keys()): if k in slot_to_id: merged_slot_to_slots[self._slot_mapping[k]].append(k) merged_slot_to_sizes = auxiliary_bundle.pop('merged_slot_to_sizes') logging.info( "Lookup MergedMultiTypeHashTable: 3) split the lookuped embeddings into feature_name -> embedding" ) slot_to_embedding = {} for merged_slot, emb in merged_slot_to_embedding.items(): sizes = merged_slot_to_sizes[merged_slot] slots = merged_slot_to_slots[merged_slot] with device_utils.maybe_device_if_allowed("/device:GPU:0"): embs = tf.split(emb, sizes, axis=0, num=len(slots)) for slot, emb in zip(slots, embs): slot_to_embedding[slot] = emb if auxiliary_bundle: return slot_to_embedding, auxiliary_bundle return slot_to_embedding def assign( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> MergedMultiTypeHashTable: return self._update(self._table.assign, slot_to_id_and_value) def assign_add( self, slot_to_id_and_value: Dict[str, Tuple[tf.Tensor, tf.Tensor]] ) -> MergedMultiTypeHashTable: return self._update(self._table.assign_add, slot_to_id_and_value) def reinitialize( self, slot: str, ids: tf.Tensor) -> 
Tuple[MergedMultiTypeHashTable, tf.Tensor]: raise NotImplementedError( "MergedMultiTypeHashTable dost not support reinitialize!") def apply_gradients(self, slot_to_id_and_grad: Dict[str, Tuple[tf.Tensor, tf.Tensor]], *args, **kwargs) -> MergedMultiTypeHashTable: return self._update(self._table.apply_gradients, slot_to_id_and_grad, *args, **kwargs) def _update(self, method, slot_to_id_and_tensor: Dict[str, Tuple[tf.Tensor, tf.Tensor]], *args, **kwargs): if kwargs.pop("skip_merge_id", False): # To avoid redundant cpu usage, in MergedMultiTypeHashTable # sync training only passes the slot_to_grad for apply_gradient slot_to_grad = slot_to_id_and_tensor with device_utils.maybe_device_if_allowed("/device:GPU:0"): merged_slot_to_grad, _ = self._get_merged_to_indexed_tensor( slot_to_grad) return self._copy_with_new_table( method(merged_slot_to_grad, *args, **kwargs)) slot_to_id = {k: v[0] for k, v in slot_to_id_and_tensor.items()} merged_slot_to_id, _ = self._get_merged_to_indexed_tensor(slot_to_id) slot_to_tensor = {k: v[1] for k, v in slot_to_id_and_tensor.items()} with device_utils.maybe_device_if_allowed("/device:GPU:0"): merged_slot_to_tensor, _ = self._get_merged_to_indexed_tensor( slot_to_tensor) merged_slot_to_id_and_tensor = {} for slot in merged_slot_to_id: merged_slot_to_id_and_tensor[slot] = (merged_slot_to_id[slot], merged_slot_to_tensor[slot]) return self._copy_with_new_table( method(merged_slot_to_id_and_tensor, *args, **kwargs)) def as_op(self, name=None) -> tf.Operation: return self._table.as_op(name) def _get_merged_to_indexed_tensor( self, slot_to_tensor: Dict[str, tf.Tensor] ) -> Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]: merged_slot_to_tensors = defaultdict(list) merged_slot_to_sizes = defaultdict(list) for slot in sorted(slot_to_tensor.keys()): # We sorted the merged slot keys here to guarantee, the merging order. 
tensor = slot_to_tensor[slot] merged_slot = self._slot_mapping[slot] merged_slot_to_sizes[merged_slot].append(tf.shape(tensor)[0]) merged_slot_to_tensors[merged_slot].append(tensor) return {k: tf.concat(v, axis=0) for k, v in merged_slot_to_tensors.items()}, \ {k: tf.stack(v) for k, v in merged_slot_to_sizes.items()} def _copy_with_new_table( self, table: BaseMultiTypeHashTable) -> MergedMultiTypeHashTable: copied = copy.copy(self) copied._table = table return copied def get_table_dim_sizes(self) -> List[int]: return [ infer_dim_size(self._merged_slot_to_config[slot].table_config) for slot in sorted(self._merged_slot_to_config.keys()) ] ================================================ FILE: monolith/native_training/multi_type_hash_table_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import hashlib from typing import Dict, List from unittest import mock import tensorflow as tf from monolith.native_training import entry from monolith.native_training import hash_table_ops from monolith.native_training import learning_rate_functions from monolith.native_training import multi_type_hash_table from monolith.native_training import test_utils def factory(idx: int, config): return hash_table_ops.hash_table_from_config(config=config, name_suffix=str(idx)) def _id(x): return tf.constant(x, dtype=tf.int64) def _value(x): return tf.constant(x, dtype=tf.float32) class MultiTypeHashTableTest(tf.test.TestCase): def test_basic(self): with self.session() as sess: hash_table = multi_type_hash_table.MultiTypeHashTable( { "slot0": test_utils.generate_test_hash_table_config(1), "slot1": test_utils.generate_test_hash_table_config(2), "slot2": test_utils.generate_test_hash_table_config(2), }, factory) hash_table = hash_table.assign_add({ "slot0": (_id([0]), _value([[1]])), "slot1": (_id([1]), _value([[2, 2]])), "slot2": (_id([2, 3]), _value([[4, 4], [8, 8]])) }) values_dict = hash_table.lookup({ "slot0": _id([0]), "slot1": _id([1]), "slot2": _id([2, 3]), }) values_dict = sess.run(values_dict) expected_values_dict = { "slot0": [[1]], "slot1": [[2, 2]], "slot2": [[4, 4], [8, 8]] } for slot, values in values_dict.items(): self.assertAllEqual(values, expected_values_dict[slot]) def test_apply_gradients(self): with self.session() as sess: hash_table = multi_type_hash_table.MultiTypeHashTable( { "slot0": test_utils.generate_test_hash_table_config(1), "slot1": test_utils.generate_test_hash_table_config(2), }, factory) values_dict = hash_table.lookup({ "slot0": _id([0]), "slot1": _id([1, 2]), }) grads = [tf.constant(2.0), tf.constant([[1.0, 3.0], [2.0, 4.0]])] global_step = tf.constant(0, dtype=tf.int64) hash_table = hash_table.apply_gradients( { "slot0": (_id([0]), grads[0]), "slot1": (_id([1, 2]), grads[1]), }, global_step) values_dict = hash_table.lookup({ "slot0": 
_id([0]), "slot1": _id([1, 2]), }) values_dict = sess.run(values_dict) expected_dict = {"slot0": [[-2]], "slot1": [[-1, -3], [-2, -4]]} for key in expected_dict: self.assertAllEqual(values_dict[key], expected_dict[key]) def test_apply_gradients_with_learning_rate_decay(self): with self.session() as sess: global_step = tf.compat.v1.train.get_or_create_global_step() self.evaluate(tf.compat.v1.global_variables_initializer()) hash_table = multi_type_hash_table.MultiTypeHashTable( { "slot0": test_utils.generate_test_hash_table_config( 1, learning_rate=learning_rate_functions.PolynomialDecay( initial_learning_rate=0.1, decay_steps=10, end_learning_rate=1.1)), "slot1": test_utils.generate_test_hash_table_config( 2, learning_rate=learning_rate_functions.PolynomialDecay( initial_learning_rate=0.1, decay_steps=10, end_learning_rate=1.1)), }, factory) values_dict = hash_table.lookup({ "slot0": _id([0]), "slot1": _id([1, 2]), }) grads = [tf.constant(2.0), tf.constant([[1.0, 3.0], [2.0, 4.0]])] hash_table = hash_table.apply_gradients( { "slot0": (_id([0]), grads[0]), "slot1": (_id([1, 2]), grads[1]), }, global_step) values_dict = hash_table.lookup({ "slot0": _id([0]), "slot1": _id([1, 2]), }) values_dict = sess.run(values_dict) expected_dict = {"slot0": [[-0.2]], "slot1": [[-0.1, -0.3], [-0.2, -0.4]]} for key in expected_dict: self.assertAllClose(values_dict[key], expected_dict[key]) def test_apply_gradients_without_lookup(self): with self.session() as sess: hash_table = multi_type_hash_table.MultiTypeHashTable( { "slot0": test_utils.generate_test_hash_table_config(1), "slot1": test_utils.generate_test_hash_table_config(2) }, factory) global_step = tf.constant(0, dtype=tf.int64) hash_table = hash_table.apply_gradients( { "slot0": (_id([1]), tf.constant(3.0)), "slot1": (_id([2, 2]), tf.constant([[1.1, 3.1], [2.2, 4.2]])), }, global_step) values_dict = hash_table.lookup({"slot0": _id([1]), "slot1": _id([1, 2])}) values_dict = sess.run(values_dict) expected_dict = {"slot0": 
[[-3]], "slot1": [[0, 0], [-3.3, -7.3]]} for key in expected_dict: self.assertAllClose(values_dict[key], expected_dict[key]) def test_fused_lookup(self): with self.session() as sess: hash_table = multi_type_hash_table.MultiTypeHashTable( { "slot0": test_utils.generate_test_hash_table_config(1), "slot1": test_utils.generate_test_hash_table_config(2), "slot2": test_utils.generate_test_hash_table_config(2), }, factory) hash_table = hash_table.assign_add({ "slot0": (_id([0]), _value([[1]])), "slot1": (_id([1]), _value([[2, 2]])), "slot2": (_id([2, 3]), _value([[4, 4], [8, 8]])) }) values_dict = hash_table.fused_lookup([0, 1, 2, 3], [1, 1, 2], 1) embeddings, recv_splits, id_offsets, emb_offsets, indices = sess.run( values_dict) self.assertAllEqual(embeddings, [1, 2, 2, 4, 4, 8, 8]) self.assertAllEqual(recv_splits, [7]) self.assertAllEqual(id_offsets, [0, 1, 2, 4]) self.assertAllEqual(emb_offsets, [0, 1, 3, 7]) def test_fused_lookup_multi_shards(self): with self.session() as sess: hash_table = multi_type_hash_table.MultiTypeHashTable( { "slot0": test_utils.generate_test_hash_table_config(1), "slot1": test_utils.generate_test_hash_table_config(2), "slot2": test_utils.generate_test_hash_table_config(2), }, factory) hash_table = hash_table.assign_add({ "slot0": (_id([0]), _value([[1]])), "slot1": (_id([1]), _value([[2, 2]])), "slot2": (_id([2, 3]), _value([[4, 4], [8, 8]])) }) values_dict = hash_table.fused_lookup([0, 2, 1, 3], [1, 0, 1, 0, 1, 1], 2) embeddings, recv_splits, id_offsets, emb_offsets, indices = sess.run( values_dict) self.assertAllEqual(embeddings, [1, 4, 4, 2, 2, 8, 8]) self.assertAllEqual(recv_splits, [3, 4]) self.assertAllEqual(id_offsets, [0, 1, 1, 2, 2, 3, 4]) self.assertAllEqual(emb_offsets, [0, 1, 1, 3, 3, 5, 7]) def test_fused_apply_gradients(self): with self.session() as sess: hash_table = multi_type_hash_table.MultiTypeHashTable( { "slot0": test_utils.generate_test_hash_table_config(1), "slot1": test_utils.generate_test_hash_table_config(2) }, 
class MergedMultiTypeHashTable(tf.test.TestCase):
  """Tests for `multi_type_hash_table.MergedMultiTypeHashTable`."""

  def testBasic(self):
    """Checks assign_add / lookup / apply_gradients on merged slots."""
    with self.session() as sess:
      hash_table = multi_type_hash_table.MergedMultiTypeHashTable(
          {
              "slot0": test_utils.generate_test_hash_table_config(1),
              "slot1": test_utils.generate_test_hash_table_config(2),
              "slot2": test_utils.generate_test_hash_table_config(2),
          }, _multi_type_factory)
      # slot 1 & 2 should be merged.
      updated_hash_table = hash_table.assign_add({
          "slot0": (_id([0]), _value([[1]])),
          "slot1": (_id([1]), _value([[2, 2]])),
          "slot2": (_id([1]), _value([[4, 4]]))
      })
      values_dict = updated_hash_table.lookup({
          "slot0": _id([0]),
          "slot1": _id([1]),
          "slot2": _id([1]),
      })
      values_dict = sess.run(values_dict)
      expected_values_dict = {
          "slot0": [[1]],
          "slot1": [[6, 6]],
          "slot2": [[6, 6]]
      }
      for slot, values in values_dict.items():
        self.assertAllEqual(values, expected_values_dict[slot])

      global_step = tf.constant(0, dtype=tf.int64)
      updated_hash_table = hash_table.apply_gradients(
          {
              "slot0": (_id([0]), _value([[-1]])),
              "slot1": (_id([1]), _value([[1, 1]])),
              "slot2": (_id([1]), _value([[1, 1]]))
          }, global_step)
      values_dict = updated_hash_table.lookup({
          "slot0": _id([0]),
          "slot1": _id([1]),
      })
      values_dict = sess.run(values_dict)
      expected_values_dict = {"slot0": [[2]], "slot1": [[4, 4]]}
      for slot, values in values_dict.items():
        self.assertAllEqual(values, expected_values_dict[slot])

  def testNameStability(self):
    """Checks that the merged slot name handed to the factory stays stable."""
    factory = mock.MagicMock()

    def call(slot_to_config: Dict[str, entry.HashTableConfigInstance]):
      self.assertListEqual(list(slot_to_config.keys()),
                           ["e21904dd414d1780e5fc904866dc69c2"])
      return _multi_type_factory(slot_to_config)

    factory.side_effect = call
    hash_table = multi_type_hash_table.MergedMultiTypeHashTable(
        {
            "slot0": test_utils.generate_test_hash_table_config(1),
            "slot1": test_utils.generate_test_hash_table_config(1),
        }, factory)

  def testRestoreName(self):
    """Checks the extra restore name attached to the merged config."""
    factory = mock.MagicMock()

    def call(slot_to_config: Dict[str, entry.HashTableConfigInstance]):
      config = next(iter(slot_to_config.values()))
      expected_name = hashlib.md5("slot_0".encode()).hexdigest()
      self.assertListEqual(config.extra_restore_names, [expected_name])
      return _multi_type_factory(slot_to_config)

    # Without this the mock never invokes `call`, and the assertion above is
    # silently skipped.
    factory.side_effect = call
    hash_table = multi_type_hash_table.MergedMultiTypeHashTable(
        {
            "fc_slot_0": test_utils.generate_test_hash_table_config(1),
        }, factory)
    # Guard against the factory not being used at all.
    factory.assert_called()
monolith.native_training import file_ops from monolith.native_training import hash_table_ops from monolith.native_training.native_task_context import get import monolith.native_training.feature_utils as feature_utils from monolith.native_training.estimator import EstimatorSpec from monolith.native_training.embedding_combiners import FirstN, Combiner from monolith.native_training.graph_utils import add_batch_norm_into_update_ops from monolith.native_training.layers import LogitCorrection from monolith.native_training.native_task import NativeTask, NativeContext from monolith.native_training.metric import utils as metric_utils from monolith.native_training.model_export import export_context from monolith.native_training.model_export.export_context import is_exporting, is_exporting_distributed from monolith.native_training.data.feature_list import get_feature_name_and_slot from monolith.native_training.monolith_export import monolith_export from monolith.native_training.runtime.hash_table import \ embedding_hash_table_pb2 from monolith.native_training.data.feature_utils import switch_slot, switch_slot_batch from monolith.native_training.data.utils import get_slot_feature_name, enable_tob_env, register_slots from monolith.native_training.utils import add_to_collections from monolith.native_training.model_dump.dump_utils import DumpUtils from monolith.native_training.dense_reload_utils import CustomRestoreListener, CustomRestoreListenerKey from monolith.native_training.layers.utils import dim_size from monolith.native_training.metric.metric_hook import KafkaMetricHook, FileMetricHook, vepfs_key_fn, vepfs_layout_fn from idl.matrix.proto.example_pb2 import OutConfig, OutType, TensorShape from monolith.native_training.data.datasets import POOL_KEY, PBDataset, PbType from monolith.native_training.data.parsers import ParserCtx, get_default_parser_ctx, parse_instances, parse_examples from monolith.native_training.model_dump.graph_utils import _node_name from 
@monolith_export
def get_sigmoid_loss_and_pred(
    name,
    logits,
    label,
    batch_size: int,
    sample_rate: Union[tf.Tensor, float] = 1.0,
    sample_bias: bool = False,
    mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN,
    instance_weight: tf.Tensor = None,
    mask: tf.Tensor = None,
    logit_clip_threshold: Optional[float] = None,
    predict_before_correction: bool = True):
  """Computes the sigmoid loss and prediction for binary classification.

  Because of negative sampling, fast_emit, etc., the logits have to be
  corrected before the loss is computed; this function applies that
  correction (through `LogitCorrection`) transparently.

  Args:
    name (:obj:`str`): Prefix of the names given to the created ops.
    logits (:obj:`tf.Tensor`): Unbiased sample logits. They are flattened to
      rank 1; they can be used directly for prediction but not for the loss.
    label (:obj:`tf.Tensor`): Sample labels, flattened to rank 1.
    batch_size (:obj:`int`): Ignored. The value is overwritten with the first
      dimension of the flattened logits; the argument is kept only so existing
      callers keep working.
    sample_rate (:obj:`tf.Tensor`): Negative-sampling rate. A Python number is
      broadcast to one value per example; `None` is treated as 1.0.
    sample_bias (:obj:`bool`): Whether fast_emit is enabled.
    mode (:obj:`str`): Run mode, one of train/eval/predict.
    instance_weight (:obj:`tf.Tensor`): Optional per-example weight that is
      multiplied into the per-example losses.
    mask (:obj:`tf.Tensor`): Boolean mask applied to the losses (and to
      `instance_weight`) before reduce_sum.
    logit_clip_threshold (:obj:`float`): Optional value in (0, 1). When set,
      the corrected logits are clipped to [-t, t] with
      t = log((1 - threshold) / threshold) before the loss is computed.
    predict_before_correction (:obj:`bool`): If True the prediction is the
      sigmoid of the unbiased logits, otherwise of the corrected (not yet
      clipped) logits.

  Returns:
    A `(loss, pred)` tuple. `loss` is the summed sigmoid cross entropy, or
    None in PREDICT mode.
  """
  logits = tf.reshape(logits, shape=(-1,))
  batch_size = dim_size(logits, 0)
  pred_name = '{name}_sigmoid_pred'.format(name=name)

  if mode != tf.estimator.ModeKeys.PREDICT:
    if sample_rate is None:
      sample_rate = tf.fill(dims=(batch_size,), value=1.0)
    elif isinstance(sample_rate, (int, float)):
      # Ints used to bypass this normalisation; treat them like floats so the
      # correction layer always receives a per-example float tensor.
      sample_rate = tf.fill(dims=(batch_size,), value=float(sample_rate))

    src = LogitCorrection(activation=None,
                          sample_bias=sample_bias,
                          name='sample_rate_correction')
    logits_biased = src((logits, sample_rate))

    # The prediction is taken before any clipping of the corrected logits.
    if predict_before_correction:
      pred = tf.nn.sigmoid(logits, name=pred_name)
    else:
      pred = tf.nn.sigmoid(logits_biased, name=pred_name)

    if logit_clip_threshold is not None:
      assert 0 < logit_clip_threshold < 1
      threshold = math.log((1 - logit_clip_threshold) / logit_clip_threshold)
      logits_biased = tf.clip_by_value(logits_biased, -threshold, threshold)

    losses = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.reshape(label, shape=(-1,)),
        logits=logits_biased,
        name='{name}_sigmoid_loss'.format(name=name))

    if instance_weight is not None:
      instance_weight = tf.reshape(instance_weight, shape=(-1,))
    if mask is not None:
      mask = tf.reshape(mask, shape=(-1,))
      losses = tf.boolean_mask(losses, mask)
      # Keep the weights aligned with the masked losses.
      if instance_weight is not None:
        instance_weight = tf.boolean_mask(instance_weight, mask)
    if instance_weight is not None:
      losses = tf.multiply(losses, instance_weight)

    loss = tf.reduce_sum(losses)
  else:
    loss = None
    pred = tf.nn.sigmoid(logits, name=pred_name)

  return loss, pred
serving_input_device_fn else: return self._current = tf.compat.v1.device(self._device_fn) return self._current.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): if self._current is not None: if self.ctx_type == DeviceCtxType.MODEL_FN: self.ensure_variables_in_device() self._current.__exit__(exc_type, exc_val, exc_tb) self._current = None self._device_fn = None def ensure_variables_in_device(self): graph = tf.compat.v1.get_default_graph() for op in graph.get_operations(): if op.name.startswith('global_step'): graph._apply_device_functions(op) @monolith_export class MonolithBaseModel(NativeTask, ABC): """模型开发的基类""" @classmethod def params(cls): p = super(MonolithBaseModel, cls).params() p.define("output_path", None, "The output path of predict/eval") p.define("output_fields", None, "The output fields") p.define("delimiter", '\t', "The delimiter of output file") p.define('file_name', '', 'the test input file name') p.define('enable_grads_and_vars_summary', False, 'enable_grads_and_vars_summary') # p.define("only_save_item_cache_hashtable", False, "if set, then only save item cache table in next run") p.define('dense_weight_decay', 0.0, 'dense_weight_decay') p.define("clip_norm", 1000.0, "float, clip_norm") p.define("sparse_norm_warmup_steps", None, "int, sparse norm warmup steps") p.define('default_occurrence_threshold', 0, 'int') return p def __init__(self, params): super(MonolithBaseModel, self).__init__(params) enable_tob_env() self.fs_dict = {} self.fc_dict = {} # feature_name -> slice_name -> FeatureSlice(feature_slot, start, end) self.slice_dict = {} self._layout_dict = {} self._occurrence_threshold = {} self._use_dense_allreduce = FLAGS.enable_sync_training self._share_slot_mapping = {} def __getattr__(self, name): if "p" in self.__dict__: if hasattr(self.p, name): return getattr(self.p, name) elif name == 'batch_size': if self.p.mode == tf.estimator.ModeKeys.EVAL: return self.p.eval.per_replica_batch_size else: return 
self.p.train.per_replica_batch_size if (hasattr(type(self), name) and isinstance(getattr(type(self), name), property)): return getattr(type(self), name).fget(self) else: return super(MonolithBaseModel, self).__getattr__(name) def __setattr__(self, key, value): if 'p' in self.__dict__: if hasattr(self.p, key): setattr(self.p, key, value) return value elif key == 'batch_size': self.p.eval.per_replica_batch_size = value self.p.train.per_replica_batch_size = value return value super(MonolithBaseModel, self).__setattr__(key, value) return value def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for name, value in self.__dict__.items(): if name == 'dump_utils': result.__dict__[name] = value else: result.__dict__[name] = deepcopy(value) return result def _get_file_ops(self, features, pred): assert self.p.output_fields is not None output_path = os.path.join(self.p.output_path, f"part-{get().worker_index:05d}") op_file = file_ops.WritableFile(output_path) op_fields = [features[field] for field in self.p.output_fields.split(',')] if isinstance(pred, (tuple, list)): op_fields.extend(pred) elif isinstance(pred, dict): sorted_keys = list(sorted(pred.keys())) logging.info('sorted_keys: %s', sorted_keys) op_fields.extend([pred[k] for k in sorted_keys]) else: op_fields.append(pred) fmt = self.p.delimiter.join(["{}"] * len(op_fields)) + "\n" try: op_fields_tmp = [ tf.squeeze(tensor, axis=-1) if tensor.shape.rank > 1 and tensor.shape.as_list()[-1] == 1 else tensor for tensor in op_fields if tensor is not None] op_fields = op_fields_tmp except Exception as e: pass result = tf.nest.map_structure( tf.stop_gradient, tf.map_fn(fn=lambda t: tf.strings.format(fmt, t, summarize=-1), elems=tuple(op_fields), dtype=tf.string, fn_output_signature=tf.string) ) write_op = op_file.append(tf.strings.reduce_join(result)) return op_file, write_op def _dump_item_embedding_ops(self, features): assert isinstance(self, DeepRoughSortBaseModel) assert 
'item_id' in features and 'item_bias' in features and 'item_vec' in features item_cache_table_path = self._cal_item_cache_table_path() cache_table_file_name = "MonolithHashTable_cached_item_embeddings-00000-of-00001" output_path = os.path.join(item_cache_table_path, cache_table_file_name) logging.info(f"_dump_item_embedding_ops: output_path={output_path}") op_file = file_ops.WritableFile(output_path) write_op = op_file.append_entry_dump(features['item_id'], features['item_bias'], features['item_vec']) return op_file, write_op def _get_real_mode(self, mode: tf.estimator.ModeKeys): if mode == tf.estimator.ModeKeys.PREDICT: return mode elif mode == tf.estimator.ModeKeys.TRAIN: return self.mode else: raise ValueError('model error!') def is_fused_layout(self) -> bool: return self.ctx.layout_factory is not None def instantiate(self): """实例化对像""" return self def add_loss(self, losses): """用于追加辅助loss, 如layer loss等 Args: losses (:obj:`List[tf.Tensor]`): 辅助loss列表 """ if losses: if isinstance(losses, (list, tuple)): self.losses.extend(losses) else: self.losses.append(losses) @property def losses(self): graph = tf.compat.v1.get_default_graph() if hasattr(graph, '__losses'): return getattr(graph, '__losses') else: setattr(graph, '__losses', []) return graph.__losses @losses.setter def losses(self, losses): graph = tf.compat.v1.get_default_graph() if hasattr(graph, '__losses'): graph.__losses = losses else: setattr(graph, '__losses', losses) @property def _global_step(self): with maybe_device_if_allowed('/device:GPU:0'): return tf.compat.v1.train.get_or_create_global_step() @property def _training_hooks(self): graph = tf.compat.v1.get_default_graph() if hasattr(graph, '__training_hooks'): return getattr(graph, '__training_hooks') else: setattr(graph, '__training_hooks', []) return graph.__training_hooks @_training_hooks.setter def _training_hooks(self, hooks): graph = tf.compat.v1.get_default_graph() if hasattr(graph, '__training_hooks'): graph.__training_hooks = hooks else: 
setattr(graph, '__training_hooks', hooks) def clean(self): # update fs_dict, fc_dict, slice_dict self.fs_dict = {} self.fc_dict = {} self.slice_dict = {} # slot_id -> Dict[slot_id, slice] self._occurrence_threshold = {} def create_input_fn(self, mode): """生成input_fn""" def input_fn_internal(): with MonolithDeviceCtx(ctx_type=DeviceCtxType.INPUT_FN): return self.input_fn(mode) return input_fn_internal def create_model_fn(self): """生成model_fn""" self.clean() def model_fn_internal( features: Dict[str, tf.Tensor], mode: tf.estimator.ModeKeys, config: tf.estimator.RunConfig) -> tf.estimator.EstimatorSpec: global_step = self._global_step real_mode = self._get_real_mode(mode) with MonolithDeviceCtx(ctx_type=DeviceCtxType.MODEL_FN): local_spec = self.model_fn(features, real_mode) # get label, loss, pred and head_name from model_fn result if isinstance(local_spec, EstimatorSpec): label, loss, pred = local_spec.label, local_spec.loss, local_spec.pred if isinstance(pred, dict): assert label is None or isinstance(label, dict) head_name, pred = list(zip(*pred.items())) else: head_name = local_spec.head_name or self.metrics.deep_insight_target.split( ',') is_classification = local_spec.classification elif isinstance(local_spec, (tuple, list)): label, loss, pred = local_spec if isinstance(pred, dict): assert label is None or isinstance(label, dict) head_name, pred = list(zip(*pred.items())) else: head_name = self.metrics.deep_insight_target assert head_name is not None is_classification = True logging.warning( 'if this is not a classification task, pls. 
return EstimatorSpec in model_fn and specify it' ) else: raise Exception("EstimatorSpec Error!") # check label/pred/head_name if isinstance(pred, (list, tuple, dict)): assert isinstance(head_name, (list, tuple)) assert isinstance(pred, (list, tuple)) if label is not None: assert len(head_name) == len(label) assert len(label) == len(pred) else: if isinstance(head_name, (list, tuple)): assert len(head_name) == 1 head_name = head_name[0] assert isinstance(head_name, str) if label is not None: assert isinstance(label, tf.Tensor) if isinstance(pred, (list, tuple)): assert len(pred) == 1 pred = pred[0] assert isinstance(pred, tf.Tensor) if label is not None: if isinstance(label, dict): label = { key: None if value is None else tf.identity(value, name=key) for key, value in label.items() } elif isinstance(label, (list, tuple)): label = [ None if l is None else tf.identity( l, name=f'label_{_node_name(l.name)}') for l in label ] else: label = label if label is None else tf.identity( label, name=f'label_{_node_name(label.name)}') dump_utils.add_model_fn(self, mode, features, label, loss, pred, head_name, is_classification) if self.losses: loss = loss + tf.add_n(self.losses) # in predict mode, when enable_resource_constrained_roughsort, only generate item_cache_hashtable if not is_exporting() and real_mode == tf.estimator.ModeKeys.PREDICT and FLAGS.enable_resource_constrained_roughsort: assert isinstance(self, DeepRoughSortBaseModel) if isinstance(pred, (list, tuple)): assert isinstance(head_name, (list, tuple)) and len(pred) == len(head_name) predictions = dict(zip(head_name, pred)) else: predictions = pred item_cache_op_file, item_cache_write_op = self._dump_item_embedding_ops(features) close_hook = file_ops.FileCloseHook([item_cache_op_file]) with tf.control_dependencies(control_inputs=[item_cache_write_op]): if isinstance(predictions, dict): predictions = {k: tf.identity(v) for k, v in predictions.items()} else: predictions = tf.identity(predictions) return 
tf.estimator.EstimatorSpec(tf.estimator.ModeKeys.PREDICT, loss=tf.constant(1), train_op=tf.no_op(), training_hooks=[close_hook] + self._training_hooks, predictions=predictions) if real_mode == tf.estimator.ModeKeys.PREDICT: if isinstance(pred, (list, tuple)): assert isinstance(head_name, (list, tuple)) and len(pred) == len(head_name) predictions = dict(zip(head_name, pred)) else: predictions = pred if is_exporting() or self.p.output_path is None: spec = tf.estimator.EstimatorSpec(real_mode, predictions=predictions, training_hooks=self._training_hooks) else: op_file, write_op = self._get_file_ops(features, pred) close_hook = file_ops.FileCloseHook([op_file]) with tf.control_dependencies(control_inputs=[write_op]): if isinstance(pred, dict): predictions = {k: tf.identity(v) for k, v in predictions.items()} else: predictions = tf.identity(predictions) spec = tf.estimator.EstimatorSpec(mode, training_hooks=[close_hook] + self._training_hooks, predictions=predictions) if is_exporting() and self._export_outputs: self._export_outputs.update(spec.export_outputs) return spec._replace(export_outputs=self._export_outputs) else: return spec train_ops = [] targets, labels_list, preds_list = [], [], [] if isinstance(pred, (list, tuple, dict)): assert isinstance(label, (list, tuple, dict)) and len(pred) == len(label) assert isinstance(head_name, (list, tuple)) and len(pred) == len(head_name) if isinstance(is_classification, (tuple, list, dict)): assert len(pred) == len(is_classification) else: is_classification = [is_classification] * len(pred) for i, name in enumerate(head_name): label_tensor = label[i] if isinstance(label, (list, tuple)) else label[name] pred_tensor = pred[i] if isinstance(pred, (list, tuple)) else pred[name] head_classification = is_classification[i] if isinstance( is_classification, (list, tuple)) else is_classification[name] targets.append(name) labels_list.append(label_tensor) preds_list.append(pred_tensor) if not FLAGS.disable_native_metrics: if 
head_classification: auc_per_core, auc_update_op = tf.compat.v1.metrics.auc( labels=label_tensor, predictions=pred_tensor, name=name) auc_head_name = "{}_auc".format(name) print_op = tf.print(auc_head_name, auc_per_core, output_stream=sys.stderr) with tf.control_dependencies([print_op]): tf.compat.v1.summary.scalar(auc_head_name, tf.identity(auc_per_core)) train_ops.append(auc_update_op) else: mean_squared_error, mse_update_op = tf.compat.v1.metrics.mean_squared_error( labels=label_tensor, predictions=pred_tensor, name=name) mse_head_name = "{}_mse".format(name) print_op = tf.print(mse_head_name, mean_squared_error, output_stream=sys.stderr) with tf.control_dependencies([print_op]): tf.compat.v1.summary.scalar(mse_head_name, tf.identity(mean_squared_error)) train_ops.append(mse_update_op) else: targets.append(head_name) labels_list.append(label) preds_list.append(pred) if not FLAGS.disable_native_metrics: if is_classification: auc_per_core, auc_update_op = tf.compat.v1.metrics.auc( labels=label, predictions=pred, name=head_name) auc_head_name = "{}_auc".format(head_name) print_op = tf.print(auc_head_name, auc_per_core, output_stream=sys.stderr) with tf.control_dependencies([print_op]): tf.compat.v1.summary.scalar(auc_head_name, tf.identity(auc_per_core)) train_ops.append(auc_update_op) else: mean_squared_error, mse_update_op = tf.compat.v1.metrics.mean_squared_error( labels=label, predictions=pred, name=head_name) mse_head_name = "{}_mse".format(head_name) print_op = tf.print(mse_head_name, mean_squared_error, output_stream=sys.stderr) with tf.control_dependencies([print_op]): tf.compat.v1.summary.scalar(mse_head_name, tf.identity(mean_squared_error)) train_ops.append(mse_update_op) enable_metrics = self.metrics.enable_kafka_metrics or self.metrics.enable_file_metrics or self.metrics.enable_deep_insight if enable_metrics and self.metrics.deep_insight_sample_ratio > 0: model_name = self.metrics.deep_insight_name sample_ratio = self.metrics.deep_insight_sample_ratio 
extra_fields_keys = self.metrics.extra_fields_keys dump_filename = f"{self.metrics.dump_filename}.part-{get().worker_index:05d}" if self.metrics.dump_filename else None deep_insight_op = metric_utils.write_deep_insight( features=features, sample_ratio=self.metrics.deep_insight_sample_ratio, labels=label, preds=pred, model_name=model_name or "model_name", target=self.metrics.deep_insight_target, targets=targets, labels_list=labels_list, preds_list=preds_list, extra_fields_keys=extra_fields_keys, enable_kafka_metrics=self.metrics.enable_kafka_metrics or self.metrics.enable_file_metrics, dump_filename=dump_filename) logging.info("model_name: {}, target: {}.".format( model_name, self.metrics.deep_insight_target)) train_ops.append(deep_insight_op) tf.compat.v1.add_to_collection("deep_insight_op", deep_insight_op) if self.metrics.enable_kafka_metrics: self.add_training_hook(KafkaMetricHook(deep_insight_op)) elif self.metrics.enable_file_metrics: self.add_training_hook( FileMetricHook(deep_insight_op, worker_id=get().worker_index, parse_fn=self.metrics.parse_fn, key_fn=self.metrics.key_fn or vepfs_key_fn, layout_fn=self.metrics.layout_fn or vepfs_layout_fn, base_name=self.metrics.file_base_name, file_ext=self.metrics.file_ext)) logging.info("model_name: {}, target {}".format(model_name, head_name)) if real_mode == tf.estimator.ModeKeys.EVAL: if is_exporting() or self.output_path is None: if isinstance(pred, (list, tuple)): train_ops.extend(pred) else: train_ops.append(pred) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=tf.group(train_ops), training_hooks=self._training_hooks) else: op_file, write_op = self._get_file_ops(features, pred) close_hook = file_ops.FileCloseHook([op_file]) with tf.control_dependencies(control_inputs=[write_op]): if isinstance(pred, (list, tuple)): train_ops.extend([tf.identity(p) for p in pred]) else: train_ops.append(tf.identity(pred)) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=tf.group(train_ops), 
training_hooks=[close_hook] + self._training_hooks) else: # training if hasattr(local_spec, 'optimizer'): dense_optimizer = local_spec.optimizer elif hasattr(self, '_default_dense_optimizer'): dense_optimizer = self._default_dense_optimizer else: raise Exception("dense_optimizer not found!") dump_utils.add_optimizer(dense_optimizer) train_ops.append( feature_utils.apply_gradients_with_var_optimizer( self.ctx, self.fc_dict.values(), dense_optimizer, loss, clip_type=feature_utils.GradClipType.ClipByGlobalNorm, clip_norm=self.clip_norm, dense_weight_decay=self.dense_weight_decay, global_step=self._global_step, grads_and_vars_summary=self.enable_grads_and_vars_summary, sparse_norm_warmup_steps=self.sparse_norm_warmup_steps, is_fused_layout=self.is_fused_layout(), use_allreduce=self._use_dense_allreduce)) add_batch_norm_into_update_ops() update_ops = tf.compat.v1.get_collection( tf.compat.v1.GraphKeys.UPDATE_OPS) logging.info('update_ops: %s', update_ops) with tf.compat.v1.control_dependencies(update_ops): train_op = tf.group(train_ops) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, training_hooks=self._training_hooks) return model_fn_internal def create_serving_input_receiver_fn(self): """生在Serving数据流, serving_input_receiver_fn""" def serving_input_receiver_fn_internal(): with MonolithDeviceCtx(ctx_type=DeviceCtxType.INPUT_RECEIVER_FN): return self.serving_input_receiver_fn() return dump_utils.record_receiver(serving_input_receiver_fn_internal) @abstractmethod def input_fn(self, mode: tf.estimator.ModeKeys) -> DatasetV2: """抽象方法, 定义数据流 Args: mode (:obj:`str`): 训练模式, train/eval/predict等 Returns: DatasetV2, TF数据集 """ raise NotImplementedError('input_fn() not Implemented') @abstractmethod def model_fn( self, features: Dict[str, tf.Tensor], mode: tf.estimator.ModeKeys ) -> Union[EstimatorSpec, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]: """抽象方法, 定义模型 Args: features (:obj:`Dict[str, tf.Tensor]`): 特征 mode (:obj:`str`): 训练模式, train/eval/predict等 Returns: 
Union[EstimatorSpec, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]], 可以是tuple, 包括(loss, label, predict), 也可以是EstimatorSpec """ raise NotImplementedError('generate_model() not Implemented') @abstractmethod def serving_input_receiver_fn(self) -> ServingInputReceiver: """Serving数据流, 训练数据流与Serving数据流或能不一样 Returns: ServingInputReceiver """ raise NotImplementedError('serving_input_receiver_fn() not Implemented') @property def _export_outputs(self): graph = tf.compat.v1.get_default_graph() if hasattr(graph, '__export_outputs'): return getattr(graph, '__export_outputs') else: setattr(graph, '__export_outputs', {}) return graph.__export_outputs def add_extra_output(self, name: str, outputs: Union[tf.Tensor, Dict[str, tf.Tensor]], head_name: str = None, head_type: str = None): """如果有出多输出, 可以用add_extra_output, 每个输出会成为Serving中的一个Signature Args: name (:obj:`str`): 签名的名称 outputs (:obj:`Union[tf.Tensor, Dict[str, tf.Tensor]]`): 输出, 可以是一个Tensor, 也可以是一个Dict[str, tf.Tensor] head_name (:obj:`str`): output对应的head的名称 head_name (:obj:`str`): output对应的head的类型, 如user, item, context等 """ add_to_collections('signature_name', name) if is_exporting(): exported_outputs = self._export_outputs if name not in exported_outputs: exported_outputs[name] = tf.estimator.export.PredictOutput(outputs) else: raise KeyError("key {name} exists!".format(name=name)) def add_training_hook(self, hook): if isinstance(hook, KafkaMetricHook): if any(isinstance(h, KafkaMetricHook) for h in self._training_hooks): return elif isinstance(hook, FileMetricHook): if any(isinstance(h, FileMetricHook) for h in self._training_hooks): return self._training_hooks.append(hook) def add_layout(self, name: str, slice_list: list, out_type: str, shape_list: list): if out_type == 'concat': out_conf = OutConfig(out_type=OutType.CONCAT) elif out_type == 'stack': out_conf = OutConfig(out_type=OutType.STACK) elif out_type == 'addn': out_conf = OutConfig(out_type=OutType.ADDN) else: out_conf = OutConfig(out_type=OutType.NONE) for feature_name, 
slice_conf in slice_list: slice_config = out_conf.slice_configs.add() slice_config.feature_name = feature_name slice_config.start = slice_conf.start slice_config.end = slice_conf.end for shape in shape_list: shape_dims = out_conf.shape.add() for i, dim in enumerate(shape): if i == 0: shape_dims.dims.append(-1) else: if isinstance(dim, int): shape_dims.dims.append(dim) else: assert hasattr(dim, 'value') shape_dims.dims.append(dim.value) self._layout_dict[name] = out_conf @property def layout_dict(self): return self._layout_dict @layout_dict.setter def layout_dict(self, layouts): self._layout_dict = layouts @monolith_export class MonolithModel(MonolithBaseModel): '''模型开发的基类 Args: params (:obj:`Params`): 配置参数, 默认为None ''' @classmethod def params(cls): p = super(MonolithModel, cls).params() p.define("feature_list", None, "The feature_list conf file.") return p def __init__(self, params=None): params = params or type(self).params() super(MonolithModel, self).__init__(params) dump_utils.enable = FLAGS.enable_model_dump def _get_fs_conf(self, shared_name: str, slot: int, occurrence_threshold: int, expire_time: int) -> FeatureSlotConfig: return FeatureSlotConfig( name=shared_name, has_bias=False, slot_id=slot, occurrence_threshold=occurrence_threshold, expire_time=expire_time, hashtable_config=entry.GpucucoHashTableConfig() if self.p.train.use_gpu_emb_table else entry.CuckooHashTableConfig()) def _embedding_slice_lookup(self, fc: Union[str, FeatureColumn], slice_name: str, slice_dim: int, initializer: Initializer, optimizer: Optimizer, compressor: Compressor, learning_rate_fn, slice_list: list) -> FeatureSlice: assert not self.is_fused_layout() if isinstance(fc, str): fc = self.fc_dict[fc] feature_slot = fc.feature_slot feature_name = self._share_slot_mapping.get( fc.feature_name, fc.feature_name) if feature_name in self.slice_dict: if slice_name in self.slice_dict[feature_name]: fc_slice = self.slice_dict[feature_name][slice_name] else: fc_slice = 
feature_slot.add_feature_slice(slice_dim, initializer, optimizer, compressor, learning_rate_fn) self.slice_dict[feature_name][slice_name] = fc_slice else: fc_slice = feature_slot.add_feature_slice(slice_dim, initializer, optimizer, compressor, learning_rate_fn) self.slice_dict[feature_name] = {slice_name: fc_slice} slice_list.append((fc.feature_name, fc_slice)) return fc.embedding_lookup(fc_slice) @dump_utils.record_feature def create_embedding_feature_column(self, feature_name, occurrence_threshold: int = None, expire_time: int = 36500, max_seq_length: int = 0, shared_name: str = None, combiner: str = None) -> FeatureColumn: """创建嵌入特征列(embedding feature column) Args: feature_name (:obj:`Any`): 特征列的名字 occurrence_threshold (:obj:`int`): 用于低频特征过滤, 如果出现次数小于`occurrence_threshold`, 则这个特征将大概率不会进入模型 expire_time (:obj:`int`): 特征过期时间, 如果一个特征在`expire_time`之内没有更新了, 则这个特征可能从hash表中移除 max_seq_length (:obj:`int`): 如果设为0, 表示非序列特征, 如果设为正数, 则表示序列特征的长度 shared_name (:obj:`str`): 共享embedding. 如果本feature与另一个feature共享embedding, 则可以将被共享feature设为`shared_name` Returns: FeatureColumn, 特征列 """ if combiner and isinstance(combiner, str): assert combiner in {'reduce_sum', 'reduce_mean', 'first_n'} if combiner == 'reduce_sum': combiner = FeatureColumn.reduce_sum() elif combiner == 'reduce_mean': combiner = FeatureColumn.reduce_mean() else: combiner = FeatureColumn.first_n() feature_name, slot = get_feature_name_and_slot(feature_name) if feature_name in self.fc_dict: return self.fc_dict[feature_name] else: if shared_name is not None and len(shared_name) > 0: self._share_slot_mapping[feature_name] = shared_name if shared_name in self.fs_dict: fs = self.fs_dict[shared_name] elif shared_name in self.fc_dict: fs = self.fc_dict[shared_name].feature_slot else: try: shared_name, shared_slot = get_feature_name_and_slot(shared_name) shared_fs = self.ctx.create_feature_slot( self._get_fs_conf(shared_name, shared_slot, occurrence_threshold, expire_time)) self.fs_dict[shared_name] = shared_fs fs = shared_fs 
except: raise Exception( f"{feature_name} shared embedding with {shared_name}, so {shared_name} should create first!" ) else: fs = self.ctx.create_feature_slot( self._get_fs_conf(feature_name, slot, occurrence_threshold, expire_time)) if combiner is None: if max_seq_length > 0: combiner = FeatureColumn.first_n(max_seq_length) else: combiner = FeatureColumn.reduce_sum() fc = FeatureColumn(fs, feature_name, combiner=combiner) self.fc_dict[feature_name] = fc return fc @dump_utils.record_slice def lookup_embedding_slice(self, features, slice_name, slice_dim=None, initializer: Initializer = None, optimizer: Optimizer = None, compressor: Compressor = None, learning_rate_fn=None, group_out_type: str = 'add_n', out_type: str = None) -> tf.Tensor: """Monolith中embedding是分切片的, 每个切片可以有独立的初始化器, 优化器, 压缩器, 学习率等. 切片的引入使Embedding更加强大. 如某些情况 下要共享Embedding, 另一些情况下要独立Embedding, 与一些域交叉要用一种Embedding, 与另一些域交叉用另一种Embedding等. 切片的引入可以方便 解上以上问题. 切片与完整Embedding的关系由Monolith自动维护, 对用户透明. Args: slice_name (:obj:`str`): 切片名称 features (:obj:`List[str], Dict[str, int]`): 支持三种形式 1) 特征名列表, 此时每个切片的长度相同, 由`slice_dim`确定, 不能为None 2) 特征 (特征名, 切片长度) 列表, 此时每个切片的长度可以不同, 全局的`slice_dim`必须为None 3) 特征字典, 特征名 -> 切片长度, 此时每个切片的长度可以不同, 全局的`slice_dim`必须为None slice_dim (:obj:`int`): 切片长度 initializer (:obj:`Initializer`): 切片的初始化器, Monolith中的初始化器, 不能是TF中的 optimizer (:obj:`Optimizer`): 切片的优化器, Monolith中的优化器, 不能是TF中的 compressor (:obj:`Compressor`): 切片的压缩器, 用于在Servering模型加载时将模型压缩 learning_rate_fn (:obj:`tf.Tensor`): 切片的学习率 """ concat = ",".join(sorted(map(str, features))) layout_name = f'{slice_name}_{hashlib.md5(concat.encode()).hexdigest()}' if self.is_fused_layout(): if isinstance(features, (list, tuple)) and isinstance(slice_dim, int): if all(isinstance(ele, (tuple, list)) for ele in features): raise ValueError("group pool is not support when fused_layout") return self.ctx.layout_factory.get_layout(layout_name) feature_embeddings, slice_list = [], [] if isinstance(features, dict): for fc_name, sdim in features.items(): 
fc_name, _ = get_feature_name_and_slot(fc_name) feature_embeddings.append( self._embedding_slice_lookup(fc_name, slice_name, sdim, initializer, optimizer, compressor, learning_rate_fn, slice_list)) elif isinstance(features, (list, tuple)) and isinstance(slice_dim, int): if all(isinstance(ele, (str, int, FeatureColumn)) for ele in features): # a list of feature with fixed dim for fc_name in features: fc_name, _ = get_feature_name_and_slot(fc_name) feature_embeddings.append( self._embedding_slice_lookup(fc_name, slice_name, slice_dim, initializer, optimizer, compressor, learning_rate_fn, slice_list)) elif all(isinstance(ele, (tuple, list)) for ele in features): assert group_out_type in {'concat', 'add_n'} for group_name in features: assert all(isinstance(ele, int) for ele in group_name) local_embeddings = [] for fc_name in group_name: fc_name, _ = get_feature_name_and_slot(fc_name) local_embeddings.append( self._embedding_slice_lookup(fc_name, slice_name, slice_dim, initializer, optimizer, compressor, learning_rate_fn, slice_list)) if group_out_type == 'add_n': feature_embeddings.append(tf.add_n(local_embeddings)) else: feature_embeddings.append(tf.concat(local_embeddings, axis=1)) else: raise ValueError("ValueError for features") elif isinstance(features, (list, tuple)): if all([ isinstance(ele, (tuple, list)) and len(ele) == 2 for ele in features ]): for fc_name, sdim in features: fc_name, _ = get_feature_name_and_slot(fc_name) feature_embeddings.append( self._embedding_slice_lookup(fc_name, slice_name, sdim, initializer, optimizer, compressor, learning_rate_fn, slice_list)) else: raise ValueError("ValueError for features") else: raise ValueError("ValueError for features") if out_type is None: shape_list = [emb.shape for emb in feature_embeddings] self.add_layout(layout_name, slice_list, out_type, shape_list) return feature_embeddings else: assert out_type in {'concat', 'stack', 'add_n', 'addn'} if out_type == 'concat': out = tf.concat(feature_embeddings, axis=1, 
def share_slot(self,
               features: Union[tf.Tensor, Dict[str, tf.RaggedTensor]] = None,
               share_meta: Dict[str, Tuple[bool, int]] = None,
               variant_type: str = 'example',
               suffix: str = 'share'):
  """Registers the slots described by `share_meta` and re-slots the features.

  Args:
    features: When this is a dict, it is treated as feature name -> tensor,
      updated in place and returned. For any other value (including None) a
      callable is returned instead of touching `features`.
    share_meta: feature name -> (inplace, slot). With `inplace` true the
      feature keeps its own name and is switched to `slot`; otherwise the
      result is stored under the derived name `f'{name}_{suffix}'`.
    variant_type: Forwarded unchanged to `switch_slot_batch`.
    suffix: Suffix used to derive the name of a non-inplace shared feature.

  Returns:
    The updated `features` dict, or a one-argument callable that applies
    `switch_slot_batch` with the same `share_meta`, `variant_type` and
    `suffix`.
  """
  # Phase 1: register every target name (own name if inplace, derived name
  # otherwise) before any tensor is rewritten.
  for name, (inplace, slot) in share_meta.items():
    target_name = name if inplace else f'{name}_{suffix}'
    register_slots({target_name: slot})

  if isinstance(features, dict):
    for name, (inplace, slot) in share_meta.items():
      # The destination key has to be derived from the *current* entry.
      # Reusing the name computed in the registration loop would write every
      # non-inplace feature under the key of the last entry in share_meta.
      target_name = name if inplace else f'{name}_{suffix}'
      features[target_name] = switch_slot(features[name], slot)
    return features

  return lambda tensor: switch_slot_batch(
      tensor, share_meta, variant_type=variant_type, suffix=suffix)
class NativeContext:
  """Bundles the factories and managers a NativeTask works with.

  A context holds either a `feature_factory` or a `layout_factory` (never
  both) plus an optional `async_function_mgr`; the helper methods below
  forward to whichever of them applies.
  """

  def __init__(self,
               feature_factory: feature.FeatureFactory = None,
               async_function_mgr: prefetch_queue.AsyncFunctionMgr = None,
               layout_factory: feature.EmbeddingLayoutFactory = None):
    self.feature_factory = feature_factory
    self.async_function_mgr = async_function_mgr
    self.layout_factory = layout_factory
    # The two factories are mutually exclusive.
    both_given = layout_factory and feature_factory
    if both_given:
      raise ValueError(
          "Cannot set feature_factory and layout_factory in the same time")

  # Convenience wrappers around the configured factory / manager.
  def create_feature_slot(
      self, config: feature.FeatureSlotConfig) -> feature.FeatureSlot:
    """Creates a feature slot through the active factory.

    The original implementation notes that no TensorFlow op is created by
    this call.
    """
    factory = (self.layout_factory
               if self.layout_factory else self.feature_factory)
    return factory.create_feature_slot(config)

  def apply_embedding_gradients(self,
                                grads_and_vars: Iterable[Tuple[tf.Tensor,
                                                               tf.Tensor]],
                                scale=1):
    """Applies gradients to embeddings.

    The vars must come from FeatureColumn's get_all_embeddings_concatenated.
    `scale` is only forwarded when the feature factory is in use; the layout
    factory is called without it.
    """
    if not self.layout_factory:
      return self.feature_factory.apply_gradients(grads_and_vars, scale=scale)
    return self.layout_factory.apply_gradients(grads_and_vars)

  def add_async_function(
      self,
      target: Callable,
      args: Tuple = None,
      kwargs: Dict = None,
      is_async: bool = None,
      queue_name: str = "async_queue") -> Union[tf.Operation, Any]:
    """Registers `target` with the async function manager.

    Returns an enqueue op when running asynchronously, otherwise the result
    of calling `target`.

    Args:
      is_async: when not specified, the default of async_function_mgr is used.

    Requirements:
      - `target` should return ops/tensors that can be passed to session.run.
      - Every tensor used by the async function should *ONLY* come from the
        arguments passed in; otherwise the async function may observe an
        updated value.

    TODO(leqi.zou): Adds a check for this.
    """
    manager = self.async_function_mgr
    return manager.add_async_function(target,
                                      args,
                                      kwargs,
                                      is_async=is_async,
                                      queue_name=queue_name)
) p.metrics.define("enable_file_metrics", False, "enable_file_metrics") p.metrics.define("file_base_name", '/vepfs/jaguar_deepinsight_results', "file_base_name") p.metrics.define("file_ext", 'txt', "file_ext") p.metrics.define("parse_fn", None, "parse_fn") p.metrics.define("key_fn", None, "key_fn") p.metrics.define("layout_fn", None, "layout_fn") p.metrics.define("dump_filename", '', "Dump filename") p.metrics.define('use_data_service', False, "use data service") p.train.define( 'max_pending_seconds_for_barrier', 30, 'Maximum waiting time for barrier block. Used for testing in most cases.' ) p.train.define( "slow_start_steps", 0, ("How many steps will worker wait before they start to train." " The formula of wait is `slow_start_steps * log(1 + index)`")) p.train.define( "sample_bias", 0., "Sample bias is a float scalar which acts as compensation for ads " "realtime training (FastEmit training instance).") p.train.define("use_gpu_emb_table", False, "Use GPU embedding table for sync training if enabled.") p.train.define("use_fountain", False, "Use fountain data service if enabled.") p.train.define("fountain_zk_host", "", "zk_host for fountain service.") p.train.define("fountain_model_name", "", "model_name for fountain service.") p.train.define("fountain_parse_on_server", False, "Parsing logic on fountain server.") p.train.define("fountain_precompute_value_rowids", False, "Parsing logic on fountain server.") p.define("serving", hyperparams.Params(), "Serving parameters.") p.serving.define( "export_with_gpu_allowed", False, "When true it allows cpu/gpu training to export model graph " "with specified gpu device placement contexts.") p.serving.define( "export_with_cleared_entry_devices", False, "When true it clears the devices in the exported model graph" "for entry only at DistributedExporter Mode.") p.serving.define( "export_when_saving", False, "When true, a valid create_serving_input_fn must be provided. The " "framework will do export when saving. 
") p.serving.define( "export_dir_base", "exported_models", "The base dir (either relative to model_dir or an absolute path) When " "exporting models.") p.serving.define("export_mode", ExportMode.DISTRIBUTED, "standalone or distributed.") p.serving.define( "shared_embedding", True, "If true, instead of exporting a hermetic SavedModel, we will use the " "embedding in checkpoints instead of copying it.") p.serving.define("with_remote_gpu", False, "If true, the whole dense will be put on the GPU.") return p def __init__(self, params): super().__init__(params) self._ctx = NativeContext() self.p = params @property def ctx(self) -> NativeContext: """Returns task ctx.""" return self._ctx @abc.abstractmethod def create_input_fn(self, mode): """ Same as BaseTask.create_input_fn """ @abc.abstractmethod def create_model_fn(self): """ For the child class, returned model_fn must follow the signature of (features, mode, config) -> SomeEstimatorSpec """ def create_serving_input_receiver_fn(self): """Returns a serving input fn for serving. See https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#export_saved_model for the possible return values for this method. By default, None is provided (which is invalid if we enable serving). """ return None ================================================ FILE: monolith/native_training/native_task_context.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@contextlib.contextmanager
def with_ctx(ctx: NativeTaskContext):
  """Temporarily installs `ctx` as the module-level task context.

  On exit the context that was active before entering is put back, even when
  the body raises.

  Args:
    ctx: the context to expose through the module-level `_CTX` while the
      `with` body runs.
  """
  global _CTX
  old_ctx = _CTX
  _CTX = ctx
  try:
    yield
  finally:
    # Restore unconditionally. Skipping the restore when the previous value
    # was None would leave `ctx` installed after the outermost `with` block
    # has exited.
    _CTX = old_ctx
class NestedTensors:
  """Flattens a nested structure into a tensor list and rebuilds it later.

  On construction every leaf of `nested` is replaced by the `id()` of the
  original object and the object itself is filed into one of three buckets:
  dense tensors, ragged tensors, or other pass-through objects.
  `get_tensors` exposes the tensor buckets as one flat list and
  `get_nested_result` rebuilds the structure from a list of the same layout.
  """

  def __init__(self, nested):
    self._nested = nested
    # id(obj) -> (bucket index, position inside that bucket).
    # Buckets: 0 = dense tensors, 1 = ragged tensors, 2 = other objects.
    self._id_mapping = {}
    self._ragged_tensors = []
    self._tensors = []
    self._other_objs = []
    self._nested = _iterate(self._nested, self._add_tensor)

  def _add_tensor(self, tensor):
    """Files `tensor` into its bucket and returns its id as a placeholder.

    Raises:
      ValueError: for a RaggedTensor whose ragged_rank is not 1, or for an
        object of an unsupported type.
    """
    obj_id = id(tensor)
    if obj_id not in self._id_mapping:
      if isinstance(tensor, tf.Tensor):
        self._id_mapping[obj_id] = (0, len(self._tensors))
        self._tensors.append(tensor)
      elif isinstance(tensor, tf.RaggedTensor):
        if tensor.ragged_rank != 1:
          raise ValueError("Nested tensor doesn't support nested RaggedTensor.")
        self._id_mapping[obj_id] = (1, len(self._ragged_tensors))
        self._ragged_tensors.append(tensor)
      # `type(None)` rather than `None`: isinstance() only accepts types in
      # its class tuple, and a bare None there raises TypeError for every
      # object that matches none of the earlier entries, hiding the
      # ValueError below.
      elif isinstance(tensor, (bool, int, str, tf.Variable, type(None))):
        # There are some cases we want to keep it as it is.
        self._id_mapping[obj_id] = (2, len(self._other_objs))
        self._other_objs.append(tensor)
      else:
        raise ValueError("Tensor is not supported. {}".format(tensor))
    return obj_id

  def get_tensors(self) -> List[tf.Tensor]:
    """Returns the dense tensors followed by the flattened ragged tensors.

    Each ragged tensor contributes two entries: its values and its
    row_splits.
    """
    flatten_ragged_tensors = self._ragged_to_flatten(self._ragged_tensors)
    return self._tensors + flatten_ragged_tensors

  def get_nested_result(self, tensors: List[tf.Tensor]):
    """Rebuilds the nested structure from `tensors`.

    Args:
      tensors: a list laid out like the result of `get_tensors`: first one
        entry per dense tensor, then a (values, row_splits) pair per ragged
        tensor.

    Returns:
      A copy of the original structure whose leaves are taken from `tensors`
      (pass-through objects are returned as stored).
    """
    flatten_ragged_tensors = tensors[len(self._tensors):]
    tensors = tensors[:len(self._tensors)]
    assert len(flatten_ragged_tensors) == len(self._ragged_tensors) * 2
    ragged_tensors = self._flatten_to_ragged(flatten_ragged_tensors)
    # Indexed by the bucket number stored in _id_mapping.
    tensor_tuple = (tensors, ragged_tensors, self._other_objs)
    # _iterate rewrites dicts in place, so work on a copy of the id structure.
    result = copy.deepcopy(self._nested)

    def action(obj_id):
      idx = self._id_mapping[obj_id]
      return tensor_tuple[idx[0]][idx[1]]

    return _iterate(result, action)

  @staticmethod
  def _convert_ragged_to_tensors(ragged):
    """Splits a ragged tensor into (values, row_splits)."""
    return ragged.values, ragged.row_splits

  @staticmethod
  def _convert_tensors_to_ragged(values, row_splits):
    """Builds a ragged tensor from values and row_splits without validation."""
    return tf.RaggedTensor.from_row_splits(values, row_splits, validate=False)

  def _ragged_to_flatten(self, ragged_tensors):
    """Returns [values0, row_splits0, values1, row_splits1, ...]."""
    return list(
        itertools.chain.from_iterable((self._convert_ragged_to_tensors(ragged)
                                       for ragged in ragged_tensors)))

  def _flatten_to_ragged(self, tensors):
    """Inverse of `_ragged_to_flatten`: pairs even/odd entries back up."""
    ragged_values = tensors[::2]
    ragged_row_splits = tensors[1::2]
    return [
        self._convert_tensors_to_ragged(*combined)
        for combined in zip(ragged_values, ragged_row_splits)
    ]
class NestedTensorTest(tf.test.TestCase):
  """Round-trip tests for nested_tensors.NestedTensors."""

  def testBasic(self):
    # Dense tensors nested in a dict/tuple are replaced and rebuilt in place.
    n = nested_tensors.NestedTensors({
        "a": tf.ones([]),
        "b": (tf.ones([]), tf.ones([])),
    })
    tensors = n.get_tensors()
    replaced = [tf.zeros_like(tensor) for tensor in tensors]
    result = n.get_nested_result(replaced)
    result = self.evaluate(result)
    self.assertDictEqual(result, {"a": 0, "b": (0, 0)})

  def testConstant(self):
    # Plain Python values produce no tensors and are passed through.
    n = nested_tensors.NestedTensors({"a": {"b": 2}})
    tensors = n.get_tensors()
    self.assertLen(tensors, 0)
    result = n.get_nested_result([])
    self.assertDictEqual(result, {"a": {"b": 2}})

  def testRaggedTensor(self):
    n = nested_tensors.NestedTensors(tf.ragged.constant([[], [1], [2, 3]]))
    tensors = n.get_tensors()
    result = n.get_nested_result(tensors)
    self.assertAllEqual(result, [[], [1], [2, 3]])

  def testRaggedTensorWithPlaceHolder(self):
    n = nested_tensors.NestedTensors(tf.ragged.constant([[], [1], [2, 3]]))
    tensors = n.get_tensors()
    phs = [tf.compat.v1.placeholder(dtype=t.dtype) for t in tensors]
    # Rebuild from the placeholders (not from the original tensors), which is
    # the case this test is named after.
    result = n.get_nested_result(phs)
    self.assertIsInstance(result, tf.RaggedTensor)
class NodeAliveChecker:
  """Probes a list of "host:port" addresses once, at construction time.

  Every address is tried with a TCP connect from a small pool of worker
  threads; the constructor returns after all workers have finished, and the
  getters then report which addresses were reachable.
  """

  def __init__(self, addrs: List, timeout: int = 1, num_thread: int = 10):
    self._addrs = addrs
    self._timeout = timeout
    self._num_thread = num_thread

    self._lock = threading.Lock()
    self._alive = set()
    self._dead = set()
    self._q = Queue()
    for pending in self._addrs:
      self._q.put(pending)
    self._start()

  def _ping(self, addr):
    """Tries one connect to `addr` and records it as alive or dead."""
    conn = None
    try:
      host, port_text = addr.rsplit(':', 1)
      # Bracketed IPv6 literals ("[::1]:80") carry brackets around the host.
      host = host.strip('[]')
      family = socket.AF_INET6 if is_ipv6_address(host) else socket.AF_INET
      conn = socket.socket(family, socket.SOCK_STREAM)
      conn.settimeout(self._timeout)
      conn.connect((host, int(port_text)))
      with self._lock:
        self._alive.add(addr)
    except Exception as err:
      # Any failure (bad format, refused, timeout, ...) marks the node dead.
      print("cannot connect to {}, because {}".format(addr, err))
      with self._lock:
        self._dead.add(addr)
    finally:
      if conn:
        conn.close()

  def _check_open(self):
    """Worker loop: drains the queue until it is empty."""
    while True:
      try:
        pending = self._q.get_nowait()
      except Empty:
        return
      self._ping(pending)

  def _start(self):
    """Runs `num_thread` workers and waits for all of them."""
    workers = [
        threading.Thread(target=self._check_open)
        for _ in range(self._num_thread)
    ]
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

  def all_nodes_alive(self):
    """True when no address was recorded as dead."""
    with self._lock:
      return not len(self._dead)

  def get_dead_nodes(self):
    """Returns the unreachable addresses as a new list."""
    with self._lock:
      return list(self._dead)

  def get_alive_nodes(self):
    """Returns the reachable addresses as a new list."""
    with self._lock:
      return list(self._alive)

  def get_addrs(self):
    """Returns the address list that was passed to the constructor."""
    with self._lock:
      return self._addrs
family=socket.AF_INET6)[0][4][0] def is_ipv4_supported(): return not is_ipv6_address(get_local_ip()) def get_local_server_addr(port: int): """Given a port. Returns an addr. In the machine that supports IPv4, it is equivalent to gethostbyname(gethostname()). """ return concat_ip_and_port(get_local_ip(), port) class AddressFamily(object): IPV4 = 'ipv4' IPV6 = 'ipv6' ================================================ FILE: monolith/native_training/net_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class socket:
  """Test double standing in for a socket object.

  `connect` sleeps for a random time of up to twice the configured timeout
  and fails when that time exceeds the timeout, recording the failure in the
  module-level `_FAILED_TIME` / `_DEAD_SET`.
  """
  AF_INET = -1
  SOCK_STREAM = -1

  def __init__(self, family=-1, stype=-1):
    self._family = family
    self._stype = stype
    self._timeout = 1

  def settimeout(self, timeout):
    self._timeout = timeout

  def connect(self, addr):
    global _FAILED_TIME
    global _DEAD_SET
    host, port = addr
    delay = random.uniform(0, 2 * self._timeout)
    time.sleep(delay)
    print('sleep {}, connect to {}:{}'.format(delay, host, port))
    if not delay > self._timeout:
      return
    # Simulated timeout: remember the address so the test can compare.
    _FAILED_TIME += 1
    _DEAD_SET.add(':'.join([host, str(port)]))
    raise RuntimeError('{}:{} connect error'.format(host, port))

  def close(self):
    pass

  @classmethod
  def socket(cls, family, stype):
    """Factory with the same call shape as the stdlib `socket.socket`."""
    return socket(family, stype)
unittest.main() ================================================ FILE: monolith/native_training/optimizers/BUILD ================================================ load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_custom_op_library", "tf_kernel_library") package( default_visibility = ["//visibility:public"], ) tf_kernel_library( name = "training_ops", srcs = [ "cc/kernels/training_op_helpers.h", "cc/kernels/training_ops.h", "cc/kernels/training_ops.cc", "cc/training_ops.cc", ], gpu_srcs = [ "cc/kernels/training_op_helpers.h", "cc/kernels/training_ops.h", "cc/kernels/training_ops_gpu.cu.cc", ], deps = [ "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:gpu_device_array_for_custom_op", ], alwayslink = 1, ) py_library( name = "adamom", srcs = ["adamom.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "adamom_test", srcs = ["adamom_test.py"], deps = [ ":adamom", ], ) py_library( name = "shampoo", srcs = ["shampoo.py"], deps = [ "//monolith:utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) py_library( name = "rmsprop", srcs = ["rmsprop.py"], deps = [ "//monolith/native_training/runtime/ops:gen_monolith_ops", ], ) py_test( name = "rmsprop_test", srcs = ["rmsprop_test.py"], deps = [ ":rmsprop", ], ) py_test( name = "rmspropv2_test", srcs = ["rmspropv2_test.py"], deps = [ ":rmsprop", ], ) ================================================ FILE: monolith/native_training/optimizers/adamom.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
class AdamomOptimizer(tf.compat.v1.train.Optimizer):
  """TF1-style optimizer backed by the `resource_apply_adamom` custom op.

  Keeps three zero-initialised slot variables per trained variable, named
  "m", "v" and "c", and hands them to the op together with the
  hyper-parameters given to the constructor.
  """

  # Slot names, in the order they are created and passed to the op.
  _SLOT_NAMES = ("m", "v", "c")

  def __init__(self,
               learning_rate=5e-6,
               ada_decay: float = 0.9999,
               mom_decay: float = 0.99,
               epsilon: float = 1e-6,
               weight_decay: float = 0.0,
               use_locking: bool = False,
               name="Adamom"):
    super().__init__(use_locking, name)
    self._learning_rate = learning_rate
    self._ada_decay = ada_decay
    self._mom_decay = mom_decay
    self._epsilon = epsilon
    self._weight_decay = weight_decay

    # Filled in by _prepare().
    self._learning_rate_tensor = None

  def _create_slots(self, var_list):
    """Creates the zero-initialised m/v/c slots for every variable."""
    for variable in var_list:
      for slot_name in self._SLOT_NAMES:
        self._zeros_slot(variable, slot_name, self._name + "/" + slot_name)

  def _prepare(self):
    """Materialises the (possibly callable) learning rate as a tensor."""
    resolved_rate = self._call_if_callable(self._learning_rate)
    self._learning_rate_tensor = tf.convert_to_tensor(resolved_rate,
                                                      name="learning_rate")

  def _resource_apply_dense(self, grad, var):
    """Returns the op that applies `grad` to the resource variable `var`."""
    m, v, c = (self.get_slot(var, slot_name) for slot_name in self._SLOT_NAMES)
    # The learning rate is cast to the gradient's base dtype for the op.
    learning_rate = tf.cast(self._learning_rate_tensor, grad.dtype.base_dtype)
    return training_ops.resource_apply_adamom(var.handle,
                                              m.handle,
                                              v.handle,
                                              c.handle,
                                              learning_rate,
                                              self._ada_decay,
                                              self._mom_decay,
                                              self._epsilon,
                                              self._weight_decay,
                                              grad,
                                              use_locking=self._use_locking)
class AdamomTest(tf.test.TestCase):
  """Checks one optimisation step of AdamomOptimizer against fixed values."""

  def testBasic(self):
    weight = tf.Variable([0.1], name="var")
    objective = 0.12 * weight
    optimizer = adamom.AdamomOptimizer(learning_rate=0.1,
                                       weight_decay=0.01,
                                       ada_decay=0.99,
                                       mom_decay=0.9)
    train_op = optimizer.minimize(objective)

    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(train_op)
      fetched = sess.run(
          {var.name: var for var in tf.compat.v1.all_variables()})

      tolerance = 1e-8
      checked = 0
      for var_name, value in fetched.items():
        # Slots are recognised by their name fragment; anything else must be
        # the trained variable itself.
        if "/m" in var_name:
          expected = 0.0121
        elif "/c" in var_name:
          expected = 1.0
        elif "/v" in var_name:
          expected = 0.014641
        else:
          expected = 0.090000336
        checked += 1
        self.assertNear(value, expected, tolerance)
      self.assertEqual(checked, 4)
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Copied from tensorflow/core/kernels/training_op_helpers.h #ifndef MONOLITH_NATIVE_TRAINING_OPTIMIZERS_CC_TRAINING_OP_HELPERS_H_ #define MONOLITH_NATIVE_TRAINING_OPTIMIZERS_CC_TRAINING_OP_HELPERS_H_ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/variant.h" namespace tensorflow { namespace monolith_tf { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL enum DenseUpdateType { ADD, SUB, ASSIGN }; namespace functor { template struct DenseUpdate { void operator()(const Device& d, typename TTypes::Flat params, typename TTypes::ConstFlat update); }; template struct DenseUpdate { void operator()(const CPUDevice& d, typename TTypes::Flat params, 
typename TTypes::ConstFlat update) { params.device(d) += update; } }; template struct DenseUpdate { void operator()(const CPUDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) -= update; } }; template struct DenseUpdate { void operator()(const CPUDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) = update; } }; #if GOOGLE_CUDA template struct DenseUpdate { void operator()(const GPUDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) = update; } }; template struct DenseUpdate { void operator()(const GPUDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) += update; } }; template struct DenseUpdate { void operator()(const GPUDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) -= update; } }; #endif #ifdef TENSORFLOW_USE_SYCL template struct DenseUpdate { void operator()(const SYCLDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) += update; } }; template struct DenseUpdate { void operator()(const SYCLDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) -= update; } }; template struct DenseUpdate { void operator()(const SYCLDevice& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) = update; } }; #endif // TENSORFLOW_USE_SYCL } // end namespace functor using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; // Must be called before performing a sparse operation on a variable. Ensures // that no concurrent dense operations can happen while holding the variable's // lock. 
template Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) { if (var->copy_on_read_mode.load()) { return Status::OK(); } mutex_lock ml(*var->mu()); // Once copy-on-read mode is True the refcount is guaranteed to be 1. This can // also happen if there are no concurrent reads of the variable and // copy-on-read mode is false. if (var->tensor()->RefCountIsOne()) { var->copy_on_read_mode.store(true); return Status::OK(); } PersistentTensor unused; Tensor* tmp; if (std::is_same::value) { AllocatorAttributes attr; attr.set_on_host(true); TF_RETURN_IF_ERROR(ctx->allocate_persistent( var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr)); const auto elements_in = var->tensor()->flat(); auto elements_out = tmp->flat(); for (int64 i = 0; i < elements_in.size(); ++i) { elements_out(i) = elements_in(i); } } else { AllocatorAttributes attr; attr.set_gpu_compatible(true); attr.set_nic_compatible(true); TF_RETURN_IF_ERROR(ctx->allocate_persistent( var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr)); functor::DenseUpdate copy_functor; copy_functor(ctx->eigen_device(), tmp->flat(), const_cast(var->tensor())->flat()); } *var->tensor() = *tmp; var->copy_on_read_mode.store(true); return Status::OK(); } // Utility structure that releases a sequence of borrowed mutexes when it is // deleted. struct VariableInputLockHolder { public: VariableInputLockHolder( std::vector vars, std::unique_ptr> locks, std::unique_ptr> shared_locks) : vars_(std::move(vars)), locks_(std::move(locks)), shared_locks_(std::move(shared_locks)) {} VariableInputLockHolder(VariableInputLockHolder&& other) : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)), shared_locks_(std::move(other.shared_locks_)) {} ~VariableInputLockHolder() { // Release the locks before unreffing the Vars, because each lock // is potentially borrowed from a Var in vars_. 
locks_.reset(); for (Var* var : vars_) { var->Unref(); } } private: std::vector vars_; // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly, // because a `std::vector` is not movable on all platforms. std::unique_ptr> locks_; std::unique_ptr> shared_locks_; }; // Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. // // If `input` corresponds to a `DT_RESOURCE`-type variable input, // `*maybe_resource` will be updated to contain the underlying resource, and the // caller will be responsible for calling `Unref()` on that resource. template mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse, Var** maybe_resource) { *maybe_resource = nullptr; if (ctx->input_dtype(input) == DT_RESOURCE) { if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { if (sparse) { EnsureSparseVariableAccess(ctx, *maybe_resource) .IgnoreError(); } return (*maybe_resource)->mu(); } else { ctx->CtxFailureWithWarning( errors::Internal("Invalid variable reference.")); return nullptr; } } return ctx->input_ref_mutex(input); } // MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes // in address order to mitigate deadlock. Returns a structure that, when // deleted, will release the acquired mutexes. Safe to pass duplicates - will // only lock each distinct mutex once. If sparse is true will ensure the // variable gets switched to copy-on-read mode before trying to acquire the // locks. If do_lock is false, returns immediately for reference variables. For // resource variables in copy-on-read-mode it will grab a shared lock if do_lock // is false, exclusive lock otherwise. Note that this silently doesn't lock // mutexes for invalid variable references; in all usages this is followed by // GetInputTensor which will signal a failure. 
template VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, bool sparse, const std::vector& input_ids) { bool any_resource = false; for (auto i : input_ids) { if (ctx->input_dtype(i) == DT_RESOURCE) { any_resource = true; break; } } if (!do_lock && !any_resource) { return VariableInputLockHolder({}, {}, {}); } std::vector vars; std::vector mutexes; std::vector acquire_order; for (auto input : input_ids) { Var* var; mutex* mutex = GetTrainingVariableMutex(ctx, input, sparse, &var); if (var) vars.push_back(var); // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { acquire_order.push_back(mutexes.size()); mutexes.push_back(mutex); } } std::sort(acquire_order.begin(), acquire_order.end(), [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); auto locks = absl::make_unique>(); auto shared_locks = absl::make_unique>(); locks->reserve(acquire_order.size()); for (auto input : acquire_order) { Var* var; mutex* mu = GetTrainingVariableMutex(ctx, input, sparse, &var); core::ScopedUnref scoped_unref(var); if (mu != nullptr) { if (!sparse || do_lock) { locks->emplace_back(*mu); } else { shared_locks->emplace_back(*mu); } } } return VariableInputLockHolder(std::move(vars), std::move(locks), std::move(shared_locks)); } inline void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, int output) { if (ctx->input_dtype(input) != DT_RESOURCE) { ctx->forward_ref_input_to_ref_output(input, output); } } // This is for use with ResourceVariables to ensure *tensor has a // reference count of 1 before you update it. // REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held. template Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor, bool copy_on_read_mode) { if (copy_on_read_mode || !tensor->RefCountIsOne()) { // Tensor's buffer is in use by some read, so we need to copy before // updating. 
PersistentTensor unused; Tensor* tmp; if (std::is_same::value) { AllocatorAttributes attr; attr.set_on_host(true); TF_RETURN_IF_ERROR(ctx->allocate_persistent( tensor->dtype(), tensor->shape(), &unused, &tmp, attr)); const auto elements_in = tensor->flat(); auto elements_out = tmp->flat(); for (int64 i = 0; i < elements_in.size(); ++i) { elements_out(i) = elements_in(i); } } else { AllocatorAttributes attr; attr.set_gpu_compatible(true); attr.set_nic_compatible(true); TF_RETURN_IF_ERROR(ctx->allocate_persistent( tensor->dtype(), tensor->shape(), &unused, &tmp, attr)); functor::DenseUpdate copy_functor; copy_functor(ctx->eigen_device(), tmp->flat(), const_cast(tensor)->flat()); } *tensor = *tmp; } return Status::OK(); } // This gives you `*out`, a tensor you can update, corresponding to a variable // passed as input index `input`. This handles the differences between // reference and resource variables. For reference variables we can just grab // the tensor, grabbing the lock if lock_held is False. // // For resource variables we, if sparse is true, ensure it's in copy-on-read // mode, and then, regardless of the value of sparse, ensure its refcount is 1 // (by potentially copying its contents). In this case lock_held is ignored. 
template Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, bool lock_held, bool sparse, Tensor* out) { if (ctx->input_dtype(input) == DT_RESOURCE) { core::RefCountPtr var; TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var)); if (sparse) { TF_RETURN_IF_ERROR(EnsureSparseVariableAccess(ctx, var.get())); *out = *var->tensor(); return Status::OK(); } TF_RETURN_IF_ERROR(PrepareToUpdateVariable( ctx, var->tensor(), var->copy_on_read_mode.load())); *out = *var->tensor(); return Status::OK(); } *out = ctx->mutable_input(input, lock_held); return Status::OK(); } } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_OPTIMIZERS_CC_TRAINING_OP_HELPERS_H_ ================================================ FILE: monolith/native_training/optimizers/cc/kernels/training_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS #include "monolith/native_training/optimizers/cc/kernels/training_ops.h" namespace tensorflow { namespace monolith_tf { template <> struct ApplyRmsprop { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots) { auto grad_after_decay = weight_decay() * var + grad; if (update_slots) { v.device(d) += (grad_after_decay.square() - v) * (1.0f - beta2()); m.device(d) = beta1() * m + (grad_after_decay * lr()) * (v + epsilon()).rsqrt(); var.device(d) -= m; } } }; template <> struct ApplyRmspropV2 { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots) { auto grad_after_decay = weight_decay() * var + grad; if (update_slots) { v.device(d) = beta2() * v + grad_after_decay.square(); // m.device(d) = beta1() * m + (grad_after_decay * lr()) * (v + // epsilon()).rsqrt(); m.device(d) = beta1() * m + (grad_after_decay * lr()) / (v.sqrt() + epsilon()); var.device(d) -= m; } } }; template <> struct ApplyAdamom { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::Flat c, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar ada_decay, typename TTypes::ConstScalar mom_decay, typename TTypes::ConstScalar epsilon, 
typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots) { auto grad_after_decay = weight_decay() * var + grad; if (update_slots) { m.device(d) = mom_decay() * m + (1.0f - mom_decay()) * grad_after_decay; v.device(d) = ada_decay() * v + grad_after_decay * grad_after_decay; c.device(d) = ada_decay() * c + 1.0f; } var.device(d) -= m * lr() * (v / c + epsilon()).rsqrt(); } }; template <> struct ApplyAdamomV2 { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::Flat c, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar ada_decay, typename TTypes::ConstScalar mom_decay, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots) { auto grad_after_decay = weight_decay() * var + grad; if (update_slots) { m.device(d) = mom_decay() * m + (1.0f - mom_decay()) * grad_after_decay; v.device(d) = ada_decay() * v + grad_after_decay * grad_after_decay; c.device(d) = ada_decay() * c + 1.0f; } var.device(d) -= m * lr() / ((v / c).sqrt() + epsilon()); } }; REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdamom") // .HostMemory("var") // .HostMemory("m") // .HostMemory("v") // .HostMemory("c") .Device(DEVICE_CPU), ApplyAdamomOp); REGISTER_KERNEL_BUILDER(Name("ResourceApplyRmsprop").Device(DEVICE_CPU), ApplyRmspropOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/optimizers/cc/kernels/training_ops.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_OPTIMIZERS_CC_TRAINING_OPS_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_OPTIMIZERS_CC_TRAINING_OPS_H_ #include "monolith/native_training/optimizers/cc/kernels/training_op_helpers.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace monolith_tf { template struct ApplyRmsprop { void operator()(const Device& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots); }; template struct ApplyRmspropV2 { void operator()(const Device& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots); }; template struct ApplyAdamom { void operator()(const Device& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::Flat c, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar ada_decay, typename TTypes::ConstScalar mom_decay, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots); }; template struct ApplyAdamomV2 { void operator()(const Device& d, typename 
TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::Flat c, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar ada_decay, typename TTypes::ConstScalar mom_decay, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots); }; template class ApplyRmspropOp : public OpKernel { public: explicit ApplyRmspropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("use_v2", &use_v2_)); } void Compute(OpKernelContext* ctx) override { const bool sparse = false; auto locks = MaybeLockVariableInputMutexesInOrder( ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( ctx, 1, use_exclusive_lock_, sparse, &m)); Tensor v; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( ctx, 2, use_exclusive_lock_, sparse, &v)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, m.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, v.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(2))); const Tensor& lr = ctx->input(3); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); const Tensor& beta1 = ctx->input(4); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), errors::InvalidArgument("beta1 is not a scalar: ", beta1.shape().DebugString())); const Tensor& beta2 = ctx->input(5); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), 
errors::InvalidArgument("beta2 is not a scalar: ", beta2.shape().DebugString())); const Tensor& epsilon = ctx->input(6); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), errors::InvalidArgument("epsilon is not a scalar: ", epsilon.shape().DebugString())); const Tensor& weight_decay = ctx->input(7); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(weight_decay.shape()), errors::InvalidArgument("weight_decay is not a scalar: ", weight_decay.shape().DebugString())); const Tensor& grad = ctx->input(8); OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), errors::InvalidArgument("var and m do not have the same shape", var.shape().DebugString(), " ", m.shape().DebugString())); OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), errors::InvalidArgument("var and v do not have the same shape", var.shape().DebugString(), " ", v.shape().DebugString())); OP_REQUIRES( ctx, var.shape().IsSameSize(grad.shape()), errors::InvalidArgument("var and grad do not have the same shape", var.shape().DebugString(), " ", grad.shape().DebugString())); const Device& device = ctx->eigen_device(); if (!use_v2_) { ApplyRmsprop()( device, var.flat(), m.flat(), v.flat(), lr.scalar(), beta1.scalar(), beta2.scalar(), epsilon.scalar(), weight_decay.scalar(), grad.flat(), update_slots_); } else { ApplyRmspropV2()( device, var.flat(), m.flat(), v.flat(), lr.scalar(), beta1.scalar(), beta2.scalar(), epsilon.scalar(), weight_decay.scalar(), grad.flat(), update_slots_); } MaybeForwardRefInputToRefOutput(ctx, 0, 0); } private: bool use_exclusive_lock_; bool update_slots_; bool use_v2_; }; template class ApplyAdamomOp : public OpKernel { public: explicit ApplyAdamomOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("use_v2", &use_v2_)); } void Compute(OpKernelContext* ctx) override { const bool sparse = false; auto locks = 
MaybeLockVariableInputMutexesInOrder( ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( ctx, 1, use_exclusive_lock_, sparse, &m)); Tensor v; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( ctx, 2, use_exclusive_lock_, sparse, &v)); Tensor c; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( ctx, 3, use_exclusive_lock_, sparse, &c)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(0))); OP_REQUIRES( ctx, m.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(1))); OP_REQUIRES( ctx, v.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(2))); OP_REQUIRES( ctx, c.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(3))); const Tensor& lr = ctx->input(4); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); const Tensor& ada_decay = ctx->input(5); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ada_decay.shape()), errors::InvalidArgument("ada_decay is not a scalar: ", ada_decay.shape().DebugString())); const Tensor& mom_decay = ctx->input(6); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(mom_decay.shape()), errors::InvalidArgument("mom_decay is not a scalar: ", mom_decay.shape().DebugString())); const Tensor& epsilon = ctx->input(7); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), errors::InvalidArgument("epsilon is not a scalar: ", epsilon.shape().DebugString())); const Tensor& weight_decay = ctx->input(8); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(weight_decay.shape()), errors::InvalidArgument("weight_decay is not a scalar: ", weight_decay.shape().DebugString())); const Tensor& grad = 
ctx->input(9); OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), errors::InvalidArgument( "var and accum do not have the same shape", var.shape().DebugString(), " ", m.shape().DebugString())); OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), errors::InvalidArgument( "var and accum do not have the same shape", var.shape().DebugString(), " ", v.shape().DebugString())); OP_REQUIRES(ctx, var.shape().IsSameSize(c.shape()), errors::InvalidArgument( "var and accum do not have the same shape", var.shape().DebugString(), " ", c.shape().DebugString())); OP_REQUIRES( ctx, var.shape().IsSameSize(grad.shape()), errors::InvalidArgument("var and grad do not have the same shape", var.shape().DebugString(), " ", grad.shape().DebugString())); const Device& device = ctx->eigen_device(); if (!use_v2_) { ApplyAdamom()( device, var.flat(), m.flat(), v.flat(), c.flat(), lr.scalar(), ada_decay.scalar(), mom_decay.scalar(), epsilon.scalar(), weight_decay.scalar(), grad.flat(), update_slots_); } else { ApplyAdamomV2()( device, var.flat(), m.flat(), v.flat(), c.flat(), lr.scalar(), ada_decay.scalar(), mom_decay.scalar(), epsilon.scalar(), weight_decay.scalar(), grad.flat(), update_slots_); } MaybeForwardRefInputToRefOutput(ctx, 0, 0); } private: bool use_exclusive_lock_; bool update_slots_; bool use_v2_; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_OPTIMIZERS_CC_TRAINING_OPS_H_ ================================================ FILE: monolith/native_training/optimizers/cc/kernels/training_ops_gpu.cu.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "monolith/native_training/optimizers/cc/kernels/training_ops.h" namespace tensorflow { namespace monolith_tf { typedef Eigen::GpuDevice GPUDevice; template <> struct ApplyRmsprop { void operator()(const GPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots) { Eigen::array::Tensor::Index, 1> bcast; bcast[0] = grad.dimension(0); Eigen::Sizes<1> single; auto grad_after_decay = weight_decay.reshape(single).broadcast(bcast) * var + grad; if (update_slots) { v.device(d) += (grad_after_decay.square() - v) * (beta2.constant(1.0f) - beta2).reshape(single).broadcast(bcast); 
m.device(d) = beta1.reshape(single).broadcast(bcast) * m + (grad_after_decay * lr.reshape(single).broadcast(bcast)) * (v + epsilon.reshape(single).broadcast(bcast)).rsqrt(); var.device(d) -= m; } } }; template <> struct ApplyRmspropV2 { void operator()(const GPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstScalar weight_decay, typename TTypes::ConstFlat grad, bool update_slots) { Eigen::array::Tensor::Index, 1> bcast; bcast[0] = grad.dimension(0); Eigen::Sizes<1> single; auto grad_after_decay = weight_decay.reshape(single).broadcast(bcast) * var + grad; if (update_slots) { v.device(d) = beta2.reshape(single).broadcast(bcast) * v + grad_after_decay.square(); // m.device(d) = beta1() * m + (grad_after_decay * lr()) * (v + // epsilon()).rsqrt(); m.device(d) = beta1.reshape(single).broadcast(bcast) * m + (grad_after_decay * lr.reshape(single).broadcast(bcast)) / (v.sqrt() + epsilon.reshape(single).broadcast(bcast)); var.device(d) -= m; } } }; REGISTER_KERNEL_BUILDER(Name("ResourceApplyRmsprop").Device(DEVICE_GPU), ApplyRmspropOp); } // namespace monolith_tf } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM ================================================ FILE: monolith/native_training/optimizers/cc/training_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS #include "monolith/native_training/optimizers/cc/kernels/training_op_helpers.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace monolith_tf { template ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) { auto* handle_data = c->input_handle_shapes_and_types(input); if (handle_data != nullptr && !handle_data->empty() && (*handle_data)[0].dtype != DT_INVALID) { return (*handle_data)[0].shape; } return c->input(input); } template <> ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) { auto* handle_data = c->input_handle_shapes_and_types(input); if (handle_data != nullptr && !handle_data->empty() && (*handle_data)[0].dtype != DT_INVALID) { return (*handle_data)[0].shape; } // If a resource input is missing shape information, we should return // UnknownShape rather than the shape of the input, which is a scalar // resource handle. 
return c->UnknownShape(); } // Handle the gradient and, if , indices inputs. // is an input+output parameter, containing the current known input shape to // the gradient. template static Status HandleGradAndIndicesInputs(InferenceContext* c, int grad_idx, ShapeHandle* s) { ShapeHandle grad = ShapeOrHandleShape(c, grad_idx); if (!is_sparse) { TF_RETURN_IF_ERROR(c->Merge(*s, grad, s)); return Status::OK(); } // Indices is a vector where indices.dim[0].rank == grad[0].rank. ShapeHandle indices; TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices)); DimensionHandle unused; TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused)); // Trailing part of grad matches trailing part of *s. ShapeHandle grad_unknown_first; TF_RETURN_IF_ERROR( c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first)); TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s)); return Status::OK(); } static Status ApplyAdamomShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m TF_RETURN_IF_ERROR( c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // v TF_RETURN_IF_ERROR( c->Merge(s, ShapeOrHandleShape(c, 3), &s)); // c TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // ada_decay TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // mom_decay TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // epsilon TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); // weight_decay TF_RETURN_IF_ERROR( HandleGradAndIndicesInputs( c, 9 /* grad_idx */, &s)); if (c->num_outputs() > 0) { c->set_output(0, s); } return Status::OK(); } REGISTER_OP("ResourceApplyAdamom") .Input("var: resource") .Input("m: resource") .Input("v: resource") .Input("c: resource") .Input("learning_rate: float") .Input("ada_decay: float") .Input("mom_decay: float") .Input("epsilon: float") 
.Input("weight_decay: float") .Input("grad: float") .Attr("use_locking: bool = false") .Attr("update_slots: bool = true") .Attr("use_v2: bool = false") .SetShapeFn(ApplyAdamomShapeFn); static Status ApplyRmspropShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( c->Merge(s, ShapeOrHandleShape(c, 1), &s)); // m TF_RETURN_IF_ERROR( c->Merge(s, ShapeOrHandleShape(c, 2), &s)); // v TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // beta1 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta2 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // epsilon TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // weight_decay TF_RETURN_IF_ERROR( HandleGradAndIndicesInputs( c, 8 /* grad_idx */, &s)); if (c->num_outputs() > 0) { c->set_output(0, s); } return Status::OK(); } REGISTER_OP("ResourceApplyRmsprop") .Input("var: resource") .Input("m: resource") .Input("v: resource") .Input("learning_rate: float") .Input("beta1: float") .Input("beta2: float") .Input("epsilon: float") .Input("weight_decay: float") .Input("grad: float") .Attr("use_locking: bool = false") .Attr("update_slots: bool = true") .Attr("use_v2: bool = false") .SetShapeFn(ApplyRmspropShapeFn); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/optimizers/rmsprop.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from typing import Union, Callable import tensorflow as tf import numpy as np from tensorflow.python.framework import ops from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.util.tf_export import keras_export from tensorflow.keras.initializers import Constant from monolith.native_training.runtime.ops import gen_monolith_ops training_ops = gen_monolith_ops class RmspropOptimizer(tf.compat.v1.train.Optimizer): """http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf""" def __init__(self, learning_rate=5e-6, beta1: float = 0.99, beta2: float = 0.999, epsilon: float = 1e-8, weight_decay: float = 0.0, use_locking: bool = False, use_v2: bool = False, name="Rmsprop"): super().__init__(use_locking, name) self._learning_rate = learning_rate self._beta1 = beta1 self._beta2 = beta2 self._epsilon = epsilon self._weight_decay = weight_decay self._use_v2 = use_v2 # Created in Initialize. self._learning_rate_tensor = None def _create_slots(self, var_list): # Create slots for the first and second moments. 
for v in var_list: self._zeros_slot(v, "m", self._name + "/m") self._zeros_slot(v, "v", self._name + "/v") def _prepare(self): learning_rate = self._call_if_callable(self._learning_rate) self._learning_rate_tensor = tf.convert_to_tensor(learning_rate, name="learning_rate") def _apply_dense(self, grad, var): raise NotImplementedError( "Please use tf.compat.v1.disable_eager_execution() instead of tf.compat.v1.disable_v2_behavior()" ) def _resource_apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") return training_ops.resource_apply_rmsprop(var.handle, m.handle, v.handle, tf.cast( self._learning_rate_tensor, grad.dtype.base_dtype), self._beta1, self._beta2, self._epsilon, self._weight_decay, grad, use_locking=self._use_locking, use_v2=self._use_v2) ================================================ FILE: monolith/native_training/optimizers/rmsprop_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def build_graph() -> tf.Operation:
  """Builds a one-variable graph and returns its RMSProp training op.

  The graph holds a single variable named "var" initialized to [0.1] and the
  linear loss `0.12 * var`, minimized by `rmsprop.RmspropOptimizer`.

  Returns:
    The op produced by `minimize` on that loss.
  """
  variable = tf.Variable([0.1], name="var")
  linear_loss = 0.12 * variable
  hyper_params = dict(learning_rate=0.1,
                      weight_decay=1,
                      beta1=0.9,
                      beta2=0.9,
                      epsilon=0.1)
  optimizer = rmsprop.RmspropOptimizer(**hyper_params)
  return optimizer.minimize(linear_loss)
class RmspropTest(tf.test.TestCase):
  """Single-step numeric checks for `RmspropOptimizer` built with use_v2=True.

  `testBasic` and `testWeightDecay` used to be verbatim copies of each other.
  Both now delegate to `_check_single_step`, so the expectations live in one
  place. They still exercise the same configuration from `build_graph`
  (weight_decay=1).
  """

  def _run_single_step(self, device_scope_fn, expected_device,
                       check_only_with_gpu):
    """Builds a fresh graph, runs one training step and fetches all variables.

    Args:
      device_scope_fn: Zero-argument callable returning the device context
        manager to build under (`test_util.use_gpu` or `test_util.force_cpu`).
        It is called inside the new graph's scope.
      expected_device: Device string every variable is asserted to be on.
      check_only_with_gpu: If True, the placement assertion is skipped when
        `tf.test.is_gpu_available()` reports no GPU.

    Returns:
      Dict mapping each variable name to its value after the step.
    """
    with tf.Graph().as_default(), device_scope_fn():
      train_op = build_graph()
      with self.session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(train_op)
        all_vars = tf.compat.v1.all_variables()
        if not check_only_with_gpu or tf.test.is_gpu_available():
          for var in all_vars:
            self.assertEqual(var.device, expected_device)
        return sess.run({var.name: var for var in all_vars})

  def _check_single_step(self):
    """Checks the values after one step and that CPU and GPU runs agree."""
    vars_map_maybe_on_gpu = self._run_single_step(test_util.use_gpu,
                                                  '/device:GPU:0', True)
    eps = 1e-8
    found_count = 0
    for name, val in vars_map_maybe_on_gpu.items():
      if name.find("/m") >= 0:
        found_count += 1
        self.assertNear(val, 0.068750, eps)
      elif name.find("/v") >= 0:
        found_count += 1
        self.assertNear(val, 0.0484, eps)
      else:
        found_count += 1
        # Must be variable
        self.assertNear(val, 0.031250, eps)
    # Exactly the variable plus its "m" and "v" slots are expected.
    self.assertEqual(found_count, 3)

    vars_map_on_cpu = self._run_single_step(test_util.force_cpu,
                                            '/device:CPU:0', False)
    self.assertEqual(vars_map_maybe_on_gpu, vars_map_on_cpu)

  def testBasic(self):
    self._check_single_step()

  def testWeightDecay(self):
    # NOTE(review): same graph and expectations as testBasic; no separate
    # weight-decay setting is varied here.
    self._check_single_step()
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf from tensorflow.python.ops import state_ops, io_ops @tf.function def eigen_inverse_root(mat, p, head, tail, damping=1e-3): alpha = -1.0 / p dim = mat.shape[0] eval, evec = tf.linalg.eigh(mat) non_zero = tf.where(tf.greater(eval, damping)) zeros = tf.cond( tf.greater(tf.size(non_zero), 0), lambda: tf.cast(tf.reduce_min(non_zero), dtype="int32"), lambda: tf.constant(0, dtype="int32")) # Count the number of zeros eval_p = tf.pow(tf.maximum(eval, damping), alpha) if tf.greater(head + tail, dim): zeros = 0 head = dim tail = 0 elif tf.greater(zeros + head + tail, dim): zeros = dim - head - tail eval_ht = tf.concat([eval_p[zeros:zeros + head], eval_p[dim - tail:]], 0) # selected eigenvalues evec_ht = tf.concat([evec[:, zeros:zeros + head], evec[:, dim - tail:]], 1) # selected eigenvectors if tf.equal(zeros + head + tail, dim): offset = 0.0 else: offset = tf.reduce_mean(eval[zeros + head:dim - tail]) return evec_ht, eval_ht - offset, offset def apply_sparse_precond(tensor, pvec, pval, offset): tensor_tmp_1 = tf.tensordot(tensor, pvec, axes=[[0], [0]]) tensor_tmp_2 = tf.multiply(tensor_tmp_1, pval) tensor_tmp_3 = tf.tensordot(tensor_tmp_2, pvec, axes=[[-1], [-1]]) rank = len(tensor.shape) tensor_transpose = tf.transpose(tensor, perm=list(range(1, rank)) + [0]) return tensor_tmp_3 + tensor_transpose * offset class ShampooOptimizer(tf.compat.v1.train.Optimizer): def __init__(self, learning_rate=0.03, beta_1: float = 0.9, beta_2: float = 1.0, warmup: int = 5000, tau_1: int = 200, tau_2: int = 20, eigen_head: int = 100, eigen_tail: int = 100, 
def _create_slots(self, var_list):
  """Creates the per-axis and per-variable slots for every variable.

  For axis `i` of size `dim`, with `eigens = min(dim, eigen_head + eigen_tail)`,
  the zero-initialized slots are:
    "s<i>" and "g<i>": [dim, dim]
    "pvec<i>":         [dim, eigens]
    "pval<i>":         [eigens]
    "o<i>":            scalar
  Each variable also gets variable-shaped zero slots "d", "m" and "pm".
  """
  for var in var_list:
    for axis, dim in enumerate(var.shape):
      eigens = min(dim, self._eigen_head + self._eigen_tail)
      # (slot prefix, slot shape), in the order the slots are created.
      axis_slots = (
          ("s", [dim, dim]),
          ("g", [dim, dim]),
          ("pvec", [dim, eigens]),
          ("pval", [eigens]),
          ("o", []),
      )
      for prefix, shape in axis_slots:
        slot_name = prefix + str(axis)
        self._get_or_make_slot(var, tf.zeros(shape), slot_name,
                               self._name + "/" + slot_name)
    for slot_name in ('d', 'm', 'pm'):
      self._zeros_slot(var, slot_name, self._name + "/" + slot_name)
momentum on statistics ops = [] rank = len(grad.shape) grad_precond = grad for i in range(rank): axes = list(range(i)) + list(range(i + 1, rank)) g = self.get_slot(var, 'g' + str(i)) g_t = tf.cond( if_update_stat, lambda: state_ops.assign( g, tf.tensordot(grad, grad, axes=[axes, axes])), lambda: tf.identity(g)) s = self.get_slot(var, 's' + str(i)) s_t = tf.cond( if_stat_momentum, lambda: state_ops.assign(s, beta_2 * s + (1 - beta_2) * g_t), lambda: state_ops.assign_add(s, g_t)) pvec = self.get_slot(var, 'pvec' + str(i)) pval = self.get_slot(var, 'pval' + str(i)) offset = self.get_slot(var, 'o' + str(i)) def update_precond(): pvec_t, pval_t, offset_t = eigen_inverse_root(s_t, 2 * rank, eigen_head, eigen_tail, damping_epsilon) return (state_ops.assign(pvec, pvec_t), state_ops.assign(pval, pval_t), state_ops.assign(offset, offset_t)) pvec_t, pval_t, offset_t = tf.cond( if_update_precond, lambda: update_precond(), lambda: (tf.identity(pvec), tf.identity(pval), tf.identity(offset))) grad_precond = apply_sparse_precond(grad_precond, pvec_t, pval_t, offset_t) ops += [ g_t, s_t, pvec_t, pval_t, offset_t, ] d = self.get_slot(var, 'd') d_t = state_ops.assign_add(d, grad * grad) m = self.get_slot(var, 'm') m_t = state_ops.assign( m, beta_1 * m + (1 - beta_1) * grad * tf.math.rsqrt(d_t + 1e-30)) pm = self.get_slot(var, 'pm') pm_t = state_ops.assign(pm, beta_1 * pm + (1.0 - beta_1) * grad_precond) update_diag = lr * m_t # AdaGrad gradient used in warmup steps update_second = lr * tf.norm(m_t) / (tf.norm(pm_t) + 1e-10) * pm_t # Shampoo gradient normalized by AdaGrad var_t = tf.cond( if_warmed_up, lambda: state_ops.assign_sub(var, ( 1.0 - warmup_rate) * update_diag + warmup_rate * update_second), lambda: state_ops.assign_sub(var, update_diag)) ops += [d_t, m_t, pm_t, var_t] return tf.group(*ops) def _resource_apply_sparse(self, grad, var): raise tf.no_op() ================================================ FILE: monolith/native_training/prefetch_queue.py 
class _GPUCompatiblePaddingFIFOQueue(data_flow_ops.QueueBase):
  """First-in first-out padding queue whose resource may live on CPU or GPU.

  It behaves like PaddingFIFOQueue, except that the queue resource can be placed
  on either device. The queue is not cross-device: enqueue and dequeue ops are
  colocated with the queue resource. Only single-element `enqueue` and `dequeue`
  are supported.
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes,
               names=None,
               shared_name=None,
               name="padding_fifo_queue"):
    """Creates the queue resource.

    Components may have dynamic shapes: `shapes` is mandatory, and a dimension
    set to None marks that dimension as variable-sized (the rank stays fixed).

    Args:
      capacity: An integer. Upper bound on the number of stored elements.
      dtypes: A list of `DType` objects, one per tensor in a queue element.
      shapes: A list of `TensorShape` objects with the same length as `dtypes`.
      names: (Optional.) A list of strings naming the components, with the same
        length as `dtypes`, or `None`. If given, the dequeue methods return a
        dictionary keyed by these names.
      shared_name: (Optional.) If non-empty, the queue is shared under this name
        across multiple sessions.
      name: Optional name for the queue operation.

    Raises:
      ValueError: If the number of dtypes and shapes differ.
    """
    component_types = data_flow_ops._as_type_list(dtypes)
    component_shapes = data_flow_ops._as_shape_list(shapes,
                                                    component_types,
                                                    unknown_dim_allowed=True)
    component_names = data_flow_ops._as_name_list(names, component_types)
    num_types, num_shapes = len(component_types), len(component_shapes)
    if num_types != num_shapes:
      raise ValueError("Shapes must be provided for all components, "
                       f"but received {num_types} dtypes and "
                       f"{num_shapes} shapes.")
    # init_scope() context required
    resource = gen_data_flow_ops.padding_fifo_queue_v2(
        component_types=component_types,
        shapes=component_shapes,
        capacity=capacity,
        shared_name=data_flow_ops._shared_name(shared_name),
        name=name)
    super().__init__(component_types, component_shapes, component_names,
                     resource)

  def enqueue_many(self, vals, name=None):
    """Batched enqueue is unsupported on this queue; always raises."""
    message = (
        "GPUCompatiblePaddingFIFOQueue does not support enqueue_many or dequeue_many, "
        "only enqueue and dequeue.")
    raise NotImplementedError(message)

  def dequeue_many(self, n, name=None):
    """Batched dequeue is unsupported on this queue; always raises."""
    message = (
        "GPUCompatiblePaddingFIFOQueue does not support enqueue_many or dequeue_many, "
        "only enqueue and dequeue.")
    raise NotImplementedError(message)
class _FIFOQueue(_QueueBase):
  """Single-device prefetch queue over a flat list of dense tensors.

  The underlying `_GPUCompatiblePaddingFIFOQueue` is created inside
  `tf.init_scope()`. The enqueue op is built in the caller's context.
  """

  def __init__(self,
               dense_list: Optional[List[tf.Tensor]] = None,
               capacity: int = 2,
               queue_name: str = "prefetch_queue"):
    """Creates the queue and its enqueue op.

    Args:
      dense_list: Tensors forming one queue element; their dtypes and static
        shapes define the queue components.
      capacity: Maximum number of elements held by the queue.
      queue_name: Name given to the queue op.

    Raises:
      ValueError: If `dense_list` is None.
      TypeError: If `dense_list` is not a list.
    """
    if dense_list is None:
      raise ValueError("Arguments `dense_list` should not be empty.")
    # A second `dense_list is None` branch used to follow; it was unreachable
    # after the raise above and has been removed.
    if not isinstance(dense_list, list):
      raise TypeError("dense_list should be a list of `tf.Tensor`s")

    self._dense_list = dense_list
    flatten_tensor_list = self._dense_list

    dtypes = [f.dtype for f in flatten_tensor_list]
    shapes = [f.shape for f in flatten_tensor_list]
    with tf.init_scope():
      self._queue = _GPUCompatiblePaddingFIFOQueue(capacity,
                                                   dtypes=dtypes,
                                                   shapes=shapes,
                                                   name=queue_name)
    # Built outside init_scope so control dependencies active at the call site
    # still apply to the enqueue op.
    self._enqueue_op = self._queue.enqueue(flatten_tensor_list)

  @property
  def queue(self):
    """The underlying queue object."""
    return self._queue

  @property
  def queues(self):
    """The underlying queue wrapped in a one-element list."""
    return [self._queue]

  @property
  def enqueue_op(self):
    """Op that enqueues the tensors given at construction."""
    return self._enqueue_op

  def dequeue(self):
    """Returns the dequeued tensors, always as a list."""
    with tf.init_scope():
      dequeue_tensor_list = self._queue.dequeue()
    if not isinstance(dequeue_tensor_list, list):
      # A single-component queue yields a bare tensor; wrap it in a list.
      assert len(self._dense_list) == 1
      return [dequeue_tensor_list]
    return dequeue_tensor_list
@property
def queue(self):
  """The single underlying queue, when only one device queue exists.

  Returns:
    The `queue` of the only entry in `self._qs`.

  Raises:
    NotImplementedError: If more than one device queue exists. The two message
      literals used to concatenate without a separating space
      ("...disabled.Check..."); a space is now included.
  """
  if len(self._qs) == 1:
    return self._qs[0].queue
  raise NotImplementedError(
      "When using multi-device queues, this interface is disabled. "
      "Check if a tensor to be enqueued is mistakenly placed on GPU.")
return self.queues[0].size() def _split_tensor_list_by_device(self, tensors): """List of tensors to (List of CPU Tensors, List of GPU Tensors).""" split_tensors = ([], []) # 0: CPU, 1: GPU self._split_tensors_indices = [] for f in tensors: tuple_i = 1 if "GPU" in f.device else 0 l = split_tensors[tuple_i] idx = (tuple_i, len(l)) l.append(f) self._split_tensors_indices.append(idx) return split_tensors def _merge_tensor_list_by_device(self, tensor_lists): return [tensor_lists[i][j] for i, j in self._split_tensors_indices] class MultiQueueRunner(tf.compat.v1.train.QueueRunner): def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None, cancel_op=None, queue_closed_exception_types=None): """Create a QueueRunner from arguments.""" if isinstance(queue, list): close_op = tf.group([q.close() for q in queue]) cancel_op = tf.group( [q.close(cancel_pending_enqueues=True) for q in queue]) # else fallback to original QueueRunner super()._init_from_args( queue=queue, enqueue_ops=enqueue_ops, close_op=close_op, cancel_op=cancel_op, queue_closed_exception_types=queue_closed_exception_types) @property def queue(self): if isinstance(self._queue, list): raise NotImplementedError( "When using multi-device queues, this interface is disabled.") return self._queue @property def name(self): if isinstance(self._queue, list): return self._queue[0].name return self._queue.name class EnqueueHook(tf.estimator.SessionRunHook): def __init__(self, q: _QueueBase): self._q_runner = MultiQueueRunner(q.queues, [q.enqueue_op]) self._threads = [] def after_create_session(self, session, coord): self._threads = self._q_runner.create_threads(session, coord=coord, start=True, daemon=True) def enqueue_dicts_with_queue_return( tensors, capacity: int = 1, queue_name: str = "prefetch_queue") -> Tuple[Any, Optional[_FIFOQueue]]: """tensors can be any nested structures (list, tuple, dict) with tensors""" if capacity == 0: return tensors, None nested = nested_tensors.NestedTensors(tensors) 
class AsyncFunctionMgr:
  """Registers functions whose inputs are pushed through a prefetch queue.

  For an async function the inputs are enqueued by the returned op, and an
  `AsyncPushHook` (exposed via `hooks`) runs the function on dequeued inputs.
  """

  def __init__(self, is_async: bool = True):
    """
    Args:
      is_async - default for `add_async_function`: whether added functions are
        executed asynchronously.
    """
    self._is_async = is_async
    self._hooks = []

  def add_async_function(
      self,
      target: Callable,
      args: Tuple = None,
      kwargs: Dict = None,
      is_async: bool = None,
      queue_name: str = "async_queue") -> Union[tf.Operation, Any]:
    """Adds `target`, to be run asynchronously or called right away.

    Args:
      target - the function to execute.
      args - positional arguments for target. A list is accepted as well as a
        tuple (previously a list only worked on the synchronous path).
      kwargs - keyword arguments for target.
      is_async - whether to execute target asynchronously. If None, will use
        default value in __init__.
      queue_name - name of the queue holding the inputs in async mode.

    Returns:
      In async mode, the enqueue op of the input queue. Otherwise the result of
      `target(*args, **kwargs)`.
    """
    if is_async is None:
      is_async = self._is_async
    if args is None:
      args = ()
    if kwargs is None:
      kwargs = {}

    if not is_async:
      return target(*args, **kwargs)

    # This prevents from using an empty input list.
    args = tuple(args) + (tf.constant(
        0, name="dummy_tensor_for_async_function"),)
    (args, kwargs), queue = enqueue_dicts_with_queue_return(
        (args, kwargs), queue_name=queue_name)
    dummy_op = tf.no_op(name="dummy_depended_op_for_async_function")
    # The last dequeued arg is the dummy tensor: it is only a control
    # dependency and is not passed to target.
    with tf.init_scope(), tf.control_dependencies([args[-1], dummy_op]):
      run_ops = target(*args[0:-1], **kwargs)
    self._hooks.append(AsyncPushHook(queue, run_ops))
    # Check ops dependence in async func.
    utils.check_ops_dependence(queue.enqueue_op.name, dummy_op.name)
    return queue.enqueue_op

  @property
  def hooks(self):
    """Hooks created so far, one per async function added."""
    return self._hooks
tf.constant(5) elems = [tf.constant(x) * _a for x in elems_numpy] with test_util.use_gpu(): # MemcpyH2D * 3 q = prefetch_queue._GPUCompatiblePaddingFIFOQueue(10, tf.int32, ((),)) # MemcpyD2H * 3 # Note that even though the below enqueue/dequeue are declared on CPU, # it still copys-to and holds the enqueued resources on GPU. # So to pin the tensors on CPU, we need to declare the queue itself on CPU. for x in elems: with tf.device("CPU:0"): self.evaluate(q.enqueue((x,))) with tf.device("CPU:0"): dequeued_tensor = q.dequeue() for i in range(len(elems)): self.assertEqual(elems[0].device, dequeued_tensor.device) with tf.device("CPU:0"): vals = self.evaluate(dequeued_tensor + 2) self.assertEqual([elems_numpy[i] * 5 + 2], vals) def testMultiEnqueueAndDequeue(self): with test_util.use_gpu(): q = prefetch_queue._GPUCompatiblePaddingFIFOQueue(10, (tf.int32, tf.float32), ((), ())) elems_numpy = [(5, 10.0), (10, 20.0), (15, 30.0)] elems = [(tf.constant(x), tf.constant(y)) for x, y in elems_numpy] for x, y in elems: self.evaluate(q.enqueue((x, y))) dequeued_tensor = q.dequeue() print(dequeued_tensor[0].device) for i in range(len(elems)): self.assertEqual(elems[i][0].device, dequeued_tensor[0].device) x_val, y_val = self.evaluate(dequeued_tensor) x, y = elems_numpy[i] self.assertEqual(x, x_val) self.assertEqual(y, y_val) def testIdentityHelper(self): with tf.device("CPU:0"): a = tf.constant(1) b = a + 1 with test_util.use_gpu(): c = tf.identity(b) q = prefetch_queue._GPUCompatiblePaddingFIFOQueue(1, tf.int32, ((),)) self.evaluate(q.enqueue(c)) # MemcpyH2D: CPU b to GPU c pinned in queue self.assertAllEqual(self.evaluate(q.dequeue()), 2) # MemcpyD2H: return class FIFOQueueTest(tf.test.TestCase): def test_fifo_queue_data(self): dense_tensors = [tf.constant(2.), tf.constant([[3], [4]])] ragged_tensors = [ tf.ragged.constant([[2], [-1, 3]]), tf.ragged.constant([[1.], []]) ] nested = nested_tensors.NestedTensors((dense_tensors, ragged_tensors)) flatten_tensors = nested.get_tensors() 
queue = prefetch_queue._FIFOQueue(flatten_tensors) dequeued = queue.dequeue() dequeued_dense, dequeued_ragged = nested.get_nested_result(dequeued) with self.session() as sess: sess.run(queue.enqueue_op) dequeued_dense, dequeued_ragged = sess.run( [dequeued_dense, dequeued_ragged]) self.assertAllClose(dequeued_dense[0], 2.) self.assertAllEqual(dequeued_dense[1], [[3], [4]]) self.assertAllEqual(dequeued_ragged[0], [[2], [-1, 3]]) self.assertAllClose(dequeued_ragged[1], [[1.], []]) def test_fifo_queue_capacity(self): dense_tensors = [tf.constant([2])] queue = prefetch_queue._FIFOQueue(dense_tensors, capacity=4) dequeue_result = queue.dequeue() with self.session() as sess: for _ in range(4): sess.run(queue.enqueue_op) for _ in range(4): result = sess.run(dequeue_result) self.assertAllEqual(result[0], [2]) class PrefetchTest(tf.test.TestCase): def test_enqueue_dicts_with_queue_return(self): dense_dicts = [{ "dense_0_0": tf.constant(2.), "dense_0_1": tf.constant([[3], [4]]) }, { "dense_1_0": tf.constant([0]) }] ragged_dicts = [{ "ragged_0_0": tf.ragged.constant([[2], [-1, 3]]), "ragged_0_1": tf.ragged.constant([[1.], []]) }, { "ragged_1_0": tf.ragged.constant([[0, 0], [1]]) }] with test_util.use_gpu(): dense_dicts[0]["dense_0_0"] += 1.0 result = prefetch_queue.enqueue_dicts_with_queue_return( (dense_dicts, ragged_dicts), capacity=3) (dequeue_dense_dicts, dequeue_ragged_dicts), queue = result with self.session() as sess: for _ in range(5): sess.run(queue.enqueue_op) dense_dicts_result = sess.run(dequeue_dense_dicts) self.assertAllClose(dense_dicts_result[0]["dense_0_0"], 2. 
def test_enqueue_dicts_with_non_tensor_values(self):
  """Non-tensor leaves (str, None) pass through; tensors go through the queue.

  This test was previously also named `test_enqueue_dicts_with_queue_return`.
  Because it was defined second in the class, it replaced the earlier test of
  that name, which therefore never ran. Renaming it lets both run.
  """
  tensors = ([{
      "a": tf.constant(1.0),
      "b": "abc",
      "c": None,
      "d": None,
  }], {
      "a": tf.Variable(0.5),
      "b": tf.ragged.constant([[1.0]])
  })
  dequeued_tensors, q = prefetch_queue.enqueue_dicts_with_queue_return(
      tensors)
  first = dequeued_tensors[0][0]
  self.assertAllEqual(first["b"], "abc")
  del first["b"]
  self.assertEqual(first["c"], None)
  del first["c"]
  self.assertEqual(first["d"], None)
  del first["d"]
  # Only tensor leaves remain, so the structure can be passed to sess.run.
  with tf.compat.v1.train.SingularMonitoredSession() as sess:
    sess.run(q.enqueue_op)
    tensors = sess.run(dequeued_tensors)
  self.assertAllEqual(tensors[0][0]["a"], 1.0)
  self.assertAllEqual(tensors[1]["a"], 0.5)
  self.assertAllEqual(tensors[1]["b"], [[1.0]])
sess.run(dequeue_dense_dicts) dequeue_ragged_dicts = sess.run(dequeue_ragged_dicts) self.assertAllEqual(dequeue_dense_dicts[0]["dense"], [0]) self.assertAllEqual(dequeue_ragged_dicts[0]["ragged"], [[0, 0], [1]]) def test_estimator_prefetch(self): def input_fn(): return tf.data.Dataset.range(0, 20).map( lambda x: {"rag": tf.ragged.constant([[0], []], dtype=tf.int64) + x}) def model_fn(features, mode): ragged = features["rag"] ragged_dicts = [{"ragged": ragged}] dequeue_raggeds, queue = prefetch_queue.enqueue_dicts_with_queue_return( ragged_dicts) predictions = dequeue_raggeds[0]["ragged"].values enqueue_hook = prefetch_queue.EnqueueHook(queue) global_step = tf.compat.v1.train.get_or_create_global_step() train_op = tf.compat.v1.assign_add(global_step, 1) return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, prediction_hooks=(enqueue_hook,), train_op=train_op) estimator = tf.estimator.Estimator(model_fn) predicts = estimator.predict(input_fn) self.assertAllEqual(list(range(20)), list(predicts)) class AsyncManagerTest(tf.test.TestCase): def testBasic(self): x = tf.Variable(0.0) def add(y): return x.assign_add(y) mgr = prefetch_queue.AsyncFunctionMgr() op = mgr.add_async_function(add, (tf.constant(1.0),)) with tf.compat.v1.train.SingularMonitoredSession(hooks=mgr.hooks) as sess: sess.run(op) # Make push happen. sess.run(op) x_value = sess.run(x) # Since it is async pushed, the value will be 1. 
self.assertAllEqual(x_value, 1.0) def testSync(self): x = tf.Variable(0.0) def add(y): return x.assign_add(y) mgr = prefetch_queue.AsyncFunctionMgr(is_async=False) op = mgr.add_async_function(add, (tf.constant(1.0),)) with tf.compat.v1.train.SingularMonitoredSession(hooks=mgr.hooks) as sess: sess.run(op) x_value = sess.run(x) self.assertAllEqual(x_value, 1.0) def testEmptyInput(self): x = tf.Variable(0) def add(): return x.assign_add(1) mgr = prefetch_queue.AsyncFunctionMgr() op = mgr.add_async_function(add) with tf.compat.v1.train.SingularMonitoredSession(hooks=mgr.hooks) as sess: sess.run(op) # Make push happen. sess.run(op) x_value = sess.run(x) # Since it is async pushed, the value will be 1. self.assertAllEqual(x_value, 1) if __name__ == "__main__": tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/proto/BUILD ================================================ load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") package(default_visibility = ["//visibility:public"]) proto_library( name = "primus_am_service_proto", srcs = ["primus_am_service.proto"], deps = ["@com_google_protobuf//:wrappers_proto"], ) py_proto_library( name = "primus_am_service_py_proto", deps = [ ":primus_am_service_proto", ], ) py_grpc_library( name = "primus_am_service_py_proto_grpc", srcs = ["primus_am_service_proto"], deps = [":primus_am_service_py_proto"], ) proto_library( name = "debugging_info_proto", srcs = ["debugging_info.proto"], ) py_proto_library( name = "debugging_info_py_proto", deps = [ ":debugging_info_proto", ], ) proto_library( name = "ckpt_info_proto", srcs = ["ckpt_info.proto"], ) py_proto_library( name = "ckpt_info_py_proto", deps = [ ":ckpt_info_proto", ], ) ================================================ FILE: monolith/native_training/proto/ckpt_info.proto 
syntax = "proto3";

package monolith;

// Per-checkpoint count tables: one keyed map named `slot_counts` and one
// named `feature_counts`.
//
// NOTE(review): both `map` fields below are missing the `<KeyType, ValueType>`
// arguments that the proto3 grammar requires for map fields; they appear to
// have been stripped from this copy. Restore them from the canonical file
// before compiling -- the types cannot be inferred from this text.
message CkptInfo {
  map slot_counts = 1;
  map feature_counts = 2;
}
syntax = "proto3";

// Top-level debugging record: the cluster layout, the worker count and the
// per-feature configuration entries.
// Note: this file declares no `package`, so the messages live in the default
// (empty) proto package.
message DebuggingInfo {
  Cluster cluster = 1;
  uint32 num_workers = 2;
  repeated FeatureNameConfig feature_name_configs = 3;
}

// Addresses of the cluster members: one chief address and a list of PS
// addresses (address format is not specified here).
message Cluster {
  string chief_addr = 1;
  repeated string ps_addrs = 2;
}

// Pairs a feature name with its configuration string.
// NOTE(review): the encoding of `config_str` is not visible here -- confirm
// with the producer of this message.
message FeatureNameConfig {
  string feature_name = 1;
  string config_str = 2;
}
syntax = "proto3";

option java_package = "com.bytedance.primus.proto";
option java_generate_equals_and_hash = true;

import "google/protobuf/wrappers.proto";

package primus;

// State values reported by PonyHeartbeatResponse.
enum PonyState {
  PS_REGISTERING = 0;
  PS_REGISTERED = 1;
  RUNNING = 2;
  FINISH = 3;
}

message PonyHeartbeatRequest {
}

message PonyHeartbeatResponse {
  PonyState state = 1;
}

// One parameter-server entry returned by ponyGetPSInfo.
message PSInfo {
  string ip = 1;
  repeated int32 ports = 2;
  int32 shard_id = 3;
  string name = 4;
}

message PonyGetPSInfoRequest {
}

message PonyGetPSInfoResponse {
  repeated PSInfo ps_info = 1;
}

message PonyStartWorkerRequest {
}

message PonyStartWorkerResponse {
}

message SucceedRequest {
  int32 exit_code = 1;
  string diagnose = 2;
  // Wrapper message rather than a plain int64, so "not set" can be told apart
  // from an explicit 0.
  google.protobuf.Int64Value graceful_shutdown_timeout_ms = 3;
}

message SucceedResponse {
}

message KillRequest {
  int32 exit_code = 1;
  string diagnose = 2;
  // Wrapper message rather than a plain int64, so "not set" can be told apart
  // from an explicit 0.
  google.protobuf.Int64Value graceful_shutdown_timeout_ms = 3;
}

message KillResponse {
}

message SuspendRequest {
  int32 snapshot_id = 1;
}

message SuspendResponse {
}

message SuspendStatusRequest {
}

message SuspendStatusResponse {
  bool succeed = 1;
  string message = 2;
}

message ResumeRequest {
}

message ResumeResponse {
}

message GetSnapshotRequest {
  int32 snapshot_id = 1;
}

message GetSnapshotResponse {
  bool available = 1;
  string dir = 2;
}

message ProgressRequest {
}

message ProgressResponse {
  float progress = 1;
}

message StarvingRequest {
}

message StarvingResponse {
  bool starving = 1;
}

message StatusRequest {
}

message StatusResponse {
  string app_id = 1;
  string final_status = 2;
  string track_url = 3;
}

message TaskTimePointRequest {
}

message TaskTimePointResponse {
  string time_point = 1;
}

message CreateSavepointRequest {
  string savepoint_dir = 1;
}

message CreateSavepointResponse {
  int32 code = 1;
  string message = 2;
  string savepoint_id = 3;
}

message CreateSavepointStatusRequest {
  string savepoint_restore_id = 1;
}

message CreateSavepointStatusResponse {
  // Nested enum: its values are scoped to this message, so RUNNING here does
  // not clash with the top-level PonyState.RUNNING.
  enum CreateSavepointState {
    PENDING = 0;
    RUNNING = 1;
    SUCCEEDED = 2;
    FAILED = 3;
  }
  int32 code = 1;
  string message = 2;
  CreateSavepointState create_savepoint_state = 3;
}

// RPC surface of the application master. Every method is unary and takes a
// dedicated request message and returns a dedicated response message.
service AppMasterService {
  rpc ponyHeartbeat (PonyHeartbeatRequest) returns (PonyHeartbeatResponse);
  rpc ponyGetPSInfo (PonyGetPSInfoRequest) returns (PonyGetPSInfoResponse);
  rpc ponyStartWorker (PonyStartWorkerRequest) returns (PonyStartWorkerResponse);
  rpc succeed (SucceedRequest) returns (SucceedResponse);
  rpc kill (KillRequest) returns (KillResponse);
  rpc suspend (SuspendRequest) returns (SuspendResponse);
  rpc suspendStatus (SuspendStatusRequest) returns (SuspendStatusResponse);
  rpc resume (ResumeRequest) returns (ResumeResponse);
  rpc getSnapshot (GetSnapshotRequest) returns (GetSnapshotResponse);
  rpc progress (ProgressRequest) returns (ProgressResponse);
  rpc isStarving (StarvingRequest) returns (StarvingResponse);
  rpc status (StatusRequest) returns (StatusResponse);
  rpc getTaskTimePoint (TaskTimePointRequest) returns (TaskTimePointResponse);
  rpc createSavepoint(CreateSavepointRequest) returns (CreateSavepointResponse);
  rpc createSavepointStatus(CreateSavepointStatusRequest) returns (CreateSavepointStatusResponse);
}
@dataclasses.dataclass
class BenchmarkConfig:
  """Settings shared by the PS benchmark task and its worker hook."""

  # Candidate PS addresses. The worker hook takes the part before ":" of each
  # entry as its ip, and rewrites this list in place with the selected PS
  # once the benchmark has finished.
  ps_list: List
  # How many PS entries are kept in `ps_list` after selection.
  num_ps_required: int
  # Number of workers taking part; sizes the hook's ready/done flag vectors.
  num_workers: int
  # This worker's position in those flag vectors; the worker with index 0
  # computes and publishes the selection.
  index: int
  # Time budget, in seconds, of each per-PS measuring loop.
  benchmark_secs: float = dataclasses.field(default=60.0)
  # When non-empty, the benchmark is skipped and this value is used as the
  # result instead. Entries are separated by `,`.
  #
  # Example: `127.0.0.1:1,127.0.0.1:2`
  ps_str_overridden: str = dataclasses.field(default="")
def before_run(self, run_context):
  """Publishes an overridden result if configured, and stops once a result exists.

  When `ps_str_overridden` is non-empty it is written into the shared result
  variable first. The result variable is then read back; any non-empty value
  means a selection has already been published, so an `OutOfRangeError` is
  raised instead of letting another benchmark step run.
  """
  session = run_context.session
  override = self._config.ps_str_overridden
  if override:
    session.run(self._result_assign,
                feed_dict={self._result_placeholder: override})
  if session.run(self._result):
    raise tf.errors.OutOfRangeError(None, None, "Benchmark is done already.")
class _DummyCheckpointSaverHook(tf.estimator.CheckpointSaverHook):
  """A `CheckpointSaverHook` subclass that never writes a checkpoint.

  Every session callback (`begin`, `after_create_session`, `before_run`,
  `after_run`, `end`) and `_save` is overridden with a no-op, so neither the
  first save (which normally happens in `after_create_session`) nor any later
  save is performed. The benchmark's model_fn registers it as a training
  chief hook.
  """

  def __init__(self, checkpoint_dir=None, save_steps=10240, **kwargs):
    """Initializes the base hook with placeholder settings.

    Args:
      checkpoint_dir: Directory handed to the base class. When falsy it
        defaults to `$HOME/tmp`, or `/tmp` if `HOME` is not set.
      save_steps: Step interval handed to the base class. It has no practical
        effect because every callback is a no-op.
      **kwargs: Accepted for call-site compatibility and ignored.
    """
    del kwargs  # Intentionally unused.
    if not checkpoint_dir:
      checkpoint_dir = os.path.join(os.environ.get('HOME', "/"), 'tmp')
    # Pass `save_steps` by keyword: the base class's second positional
    # parameter is `save_secs`, so a positional call would silently configure
    # a time-based interval instead of a step-based one.
    super(_DummyCheckpointSaverHook, self).__init__(checkpoint_dir,
                                                    save_steps=save_steps)
    logging.info("Create DummyCheckpointSaverHook.")

  def begin(self):
    # No-op: skips the base class's setup.
    return

  def after_create_session(self, session, coord):
    # No-op: this is where the base class would perform the first save.
    return

  def before_run(self, run_context):
    # No-op: returning None requests no extra fetches.
    return None

  def after_run(self, run_context, run_values):
    # No-op: never triggers a periodic save.
    return

  def end(self, session):
    # No-op: no final save when the session ends.
    return

  def _save(self, session, step: int) -> bool:
    # Never saves; always returns False.
    return False
tf.device(utils.ps_device(ps_i)): var = tf.Variable(initial_value=[[0.0] * 256] * 256, trainable=True) with tf.control_dependencies([features]): ts_before = tf.timestamp() i = tf.constant(0) grad = tf.reshape(tf.tile(features, [16384]), [256, 256]) def while_body(i): nonlocal var nonlocal grad with tf.control_dependencies([i]): new_grads = tf.split(grad, [64, 64, 64, 64], axis=1) output_grads = [] for ii in range(4): sum_grads = [] for jj in range(10): a, b, c, d = tf.split(new_grads[ii] + tf.cast(jj / 10, dtype=tf.float32), [16, 16, 16, 16], axis=1) for _ in range(10): sum_grads.append(tf.math.sqrt(tf.math.sqrt(a * b) * c + d)) output_grads.append(tf.math.add_n(sum_grads)) concat_grads = tf.concat(output_grads, -1) var_fetched = tf.identity(var) with tf.control_dependencies([var_fetched, concat_grads]): return i + 1 def cond(i): nonlocal ts_before with tf.control_dependencies([i]): ts_now = tf.timestamp() return ts_now - ts_before <= bm_config.benchmark_secs with tf.device(utils.ps_device(ps_i)): (i,) = tf.while_loop(cond, while_body, [i]) j = tf.identity(i) with tf.control_dependencies([j]): ts_now = tf.timestamp() throughput = tf.cast(j, tf.float32) / tf.cast( ts_now - ts_before, tf.float32) throughputs.append(throughput) mean_throughput, update_op = tf.compat.v1.metrics.mean_tensor( tf.stack(throughputs)) hook = _BenchmarkWorkerHook(bm_config, mean_throughput) saver_hook = _DummyCheckpointSaverHook() inc_global_step = global_step.assign_add(1) if mode == tf.estimator.ModeKeys.PREDICT: return tf.estimator.EstimatorSpec(mode=mode, predictions=tf.constant(0.0)) return tf.estimator.EstimatorSpec(mode=mode, loss=tf.constant(0.0), train_op=tf.group( update_op, inc_global_step), training_hooks=[hook], training_chief_hooks=[saver_hook]) return model_fn ================================================ FILE: monolith/native_training/ps_benchmark_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
class PsBenchmarkTest(tf.test.TestCase):
  """Runs the PS benchmark task locally with two PS and a single worker."""

  def _run_benchmark(self, subdir, **config_overrides):
    """Trains the benchmark task once and returns its (mutated) BenchmarkConfig."""
    params = ps_benchmark.PsBenchMarkTask.params()
    params.bm_config = ps_benchmark.BenchmarkConfig(ps_list=["ps0", "ps1"],
                                                    num_ps_required=1,
                                                    num_workers=1,
                                                    index=0,
                                                    benchmark_secs=1.0,
                                                    **config_overrides)
    cpu_training.local_train(params,
                             num_ps=2,
                             model_dir=utils.get_test_tmp_dir() + subdir)
    return params.bm_config

  def testBasic(self):
    # One PS is required, so the two-entry list must shrink to one entry.
    bm_config = self._run_benchmark("/basic")
    self.assertEqual(len(bm_config.ps_list), 1)

  def testSkipBenchmark(self):
    # With an override string the selected PS is the override itself.
    bm_config = self._run_benchmark("/skip_benchmark",
                                    ps_str_overridden="overridden")
    self.assertEqual(bm_config.ps_list[0], "overridden")
def fused_value_rowids(rt: tf.RaggedTensor):
  """Returns the value row ids of `rt`, computed by a single fused op.

  Meant as a cheaper stand-in for `rt.value_rowids()`. The result is stored
  on `rt` itself under the `monolith_fused_value_rowids` attribute, so calling
  this again with the same object returns the very same tensor.

  Raises:
    ValueError: if `rt` is not a `tf.RaggedTensor`.
  """
  if not isinstance(rt, tf.RaggedTensor):
    raise ValueError("rt must be RaggedTensor")
  try:
    # Fast path: a previous call already attached the result to this object.
    return rt.monolith_fused_value_rowids
  except AttributeError:
    rowids = ops.monolith_fused_value_rowids(rt.row_splits)
    rt.monolith_fused_value_rowids = rowids
    return rowids
class RaggedUtilsTestCase(tf.test.TestCase):
  """Checks the result and the per-object caching of `fused_value_rowids`."""

  def test_basic(self):
    ragged = tf.ragged.constant([[], [1], [2, 3]])
    first_call, second_call = (
        ragged_utils.fused_value_rowids(ragged) for _ in range(2))
    # The second call must hand back the cached tensor, not a new one.
    self.assertIs(first_call, second_call)
    # Row 0 is empty, row 1 holds one value, row 2 holds two.
    self.assertAllEqual(first_call, [1, 2, 2])
def _generate_config(servers, job_name=utils.PS_JOB_NAME):
  """Builds a session ConfigProto describing `servers` as one cluster job.

  Each server becomes task `i` of a single job named `job_name`; the task
  address is the server target with its leading `grpc://` removed. Session
  state sharing under clusterspec propagation is switched on.
  """
  scheme_len = len('grpc://')
  cluster = tf.train.ClusterDef()
  ps_job = cluster.job.add()
  ps_job.name = job_name
  for task_index, srv in enumerate(servers):
    ps_job.tasks[task_index] = srv.target[scheme_len:]

  config = tf.compat.v1.ConfigProto(cluster_def=cluster)
  experimental = config.experimental
  experimental.share_session_state_in_clusterspec_propagation = True
  return config
keep_checkpoint_every_n_hours=2, ps_monitor=ps_monitor) tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.SAVERS, saver) saver_listener = hash_table_ops.HashTableCheckpointSaverListener(basename) saver_hook = save_utils.NoFirstSaveCheckpointSaverHook( os.path.dirname(basename), save_steps=1, saver=saver, listeners=[saver_listener]) restore_listener = hash_table_ops.HashTableCheckpointRestorerListener( basename, ps_monitor) restore_hook = basic_restore_hook.CheckpointRestorerHook( listeners=[restore_listener]) server1 = tf.distribute.Server.create_local_server() server2 = tf.distribute.Server.create_local_server() config = _generate_config([server1, server2]) # save checkpoint at first session with tf.compat.v1.train.SingularMonitoredSession( hooks=[restore_hook, saver_hook], master=server1.target, config=config, checkpoint_dir=os.path.dirname(basename)) as mon_sess: sess = mon_sess.raw_session() sess.run(train_op) v0_val = sess.run(v0) v1_val = sess.run(v1) embedding0 = sess.run(lookup0) embedding1 = sess.run(lookup1) self.assertAllEqual(v0_val, 1) self.assertAllEqual(v1_val, 1) self.assertAllEqual(embedding0, [[1]]) self.assertAllEqual(embedding1, [[1]]) # change variables at second session with tf.compat.v1.Session(server1.target, config=config) as sess: sess.run(train_op) v0_val = sess.run(v0) v1_val = sess.run(v1) embedding0 = sess.run(lookup0) embedding1 = sess.run(lookup1) self.assertAllEqual(v0_val, 2) self.assertAllEqual(v1_val, 2) self.assertAllEqual(embedding0, [[2]]) self.assertAllEqual(embedding1, [[2]]) server3 = tf.distribute.Server.create_local_server() server4 = tf.distribute.Server.create_local_server() config = _generate_config([server3, server4]) # restore all variables at third session with tf.compat.v1.train.SingularMonitoredSession( hooks=[restore_hook, saver_hook], master=server3.target, config=config, checkpoint_dir=os.path.dirname(basename)) as mon_sess: sess = mon_sess.raw_session() v0_val = sess.run(v0) v1_val = sess.run(v1) 
def test_restore_without_ps_monitor(self):
  """Restores from a checkpoint when the saver has no PsMonitor.

  Three sessions run against the same two local servers:
    1. a monitored session runs `train_op` once; variables and embeddings
       read 1;
    2. a plain session runs `train_op` again; everything reads 2;
    3. a new monitored session over the same `checkpoint_dir` reads the
       values without training and expects 1 again everywhere.
  """
  basename = os.path.join(os.environ["TEST_TMPDIR"],
                          "test_restore_without_ps_monitor", "model.ckpt")
  with tf.compat.v1.Graph().as_default():
    train_op, v0, v1, lookup0, lookup1 = self.build_graph()
    # Unlike the ps_monitor variant, no `ps_monitor` is passed to the saver or
    # to the restorer listener.
    saver = save_utils.PartialRecoverySaver(sharded=True,
                                            max_to_keep=10,
                                            keep_checkpoint_every_n_hours=2)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.SAVERS, saver)
    saver_listener = hash_table_ops.HashTableCheckpointSaverListener(basename)
    saver_hook = save_utils.NoFirstSaveCheckpointSaverHook(
        os.path.dirname(basename),
        save_steps=1,
        saver=saver,
        listeners=[saver_listener])
    restore_listener = hash_table_ops.HashTableCheckpointRestorerListener(
        basename)
    restore_hook = basic_restore_hook.CheckpointRestorerHook(
        listeners=[restore_listener])

    server1 = tf.distribute.Server.create_local_server()
    server2 = tf.distribute.Server.create_local_server()
    config = _generate_config([server1, server2])
    # save checkpoint at first session
    # NOTE(review): train_op is run on the raw session, so the checkpoint is
    # presumably written by the saver hook when this block exits -- confirm.
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[restore_hook, saver_hook],
        master=server1.target,
        config=config,
        checkpoint_dir=os.path.dirname(basename)) as mon_sess:
      sess = mon_sess.raw_session()
      sess.run(train_op)
      v0_val = sess.run(v0)
      v1_val = sess.run(v1)
      embedding0 = sess.run(lookup0)
      embedding1 = sess.run(lookup1)
      self.assertAllEqual(v0_val, 1)
      self.assertAllEqual(v1_val, 1)
      self.assertAllEqual(embedding0, [[1]])
      self.assertAllEqual(embedding1, [[1]])

    # change variables at second session
    # A plain Session has no hooks, so this step is not checkpointed.
    with tf.compat.v1.Session(server1.target, config=config) as sess:
      sess.run(train_op)
      v0_val = sess.run(v0)
      v1_val = sess.run(v1)
      embedding0 = sess.run(lookup0)
      embedding1 = sess.run(lookup1)
      self.assertAllEqual(v0_val, 2)
      self.assertAllEqual(v1_val, 2)
      self.assertAllEqual(embedding0, [[2]])
      self.assertAllEqual(embedding1, [[2]])

    # restore all variables at third session
    # Same servers as before: both PS are expected to go back to the
    # checkpointed value 1 rather than keep the live value 2.
    with tf.compat.v1.train.SingularMonitoredSession(
        hooks=[restore_hook, saver_hook],
        master=server1.target,
        config=config,
        checkpoint_dir=os.path.dirname(basename)) as mon_sess:
      sess = mon_sess.raw_session()
      v0_val = sess.run(v0)
      v1_val = sess.run(v1)
      embedding0 = sess.run(lookup0)
      embedding1 = sess.run(lookup1)
      self.assertAllEqual(v0_val, 1)
      self.assertAllEqual(v1_val, 1)
      self.assertAllEqual(embedding0, [[1]])
      self.assertAllEqual(embedding1, [[1]])
def isabs(path: str):
  """Return True if `path` is absolute, treating `hdfs:/...` as absolute.

  This function replaces `os.path.isabs` for the whole process, so it has to
  accept everything the standard function accepts: `str`, `bytes` and
  `os.PathLike` objects. Anything that is not an HDFS path is delegated to the
  original `os.path.isabs`.

  Args:
    path: a path as `str`, `bytes` or `os.PathLike`.

  Returns:
    True for paths starting with `hdfs:/`, otherwise whatever the original
    `os.path.isabs` returns.

  Raises:
    TypeError: if `path` is not a str, bytes or os.PathLike object.
  """
  # Normalise PathLike objects first; str/bytes are returned unchanged.
  fs_path = os.fspath(path)
  # bytes.startswith() rejects a str prefix, so pick the matching literal.
  hdfs_prefix = b'hdfs:/' if isinstance(fs_path, bytes) else 'hdfs:/'
  if fs_path.startswith(hdfs_prefix):
    return True
  return old_isabs(path)
coord_checkpoint_filename) and cur_cnt < max_retry: checkpoint_state = old_get_checkpoint_state(checkpoint_dir, latest_filename) cur_cnt += 1 if cur_cnt >= max_retry: raise Exception("read ckpt error!") try: if FLAGS.restore_ckpt is not None: if latest_filename != 'checkpoint' or checkpoint_state is None: return checkpoint_state dirname_from_ckpt_state = os.path.dirname( checkpoint_state.model_checkpoint_path) restore_ckpt = os.path.join(dirname_from_ckpt_state, os.path.basename(FLAGS.restore_ckpt)) restore_ckpt_file = os.path.join(checkpoint_dir, 'restore_ckpt') if restore_ckpt != checkpoint_state.model_checkpoint_path and restore_ckpt in checkpoint_state.all_model_checkpoint_paths: if FLAGS.mode == tf.estimator.ModeKeys.TRAIN: if not tf.io.gfile.exists(restore_ckpt_file): checkpoint_state.model_checkpoint_path = restore_ckpt else: logging.info( f'mode is {FLAGS.mode} and {restore_ckpt_file} file exists, keep {checkpoint_state.model_checkpoint_path}' ) else: logging.info(f'mode is {FLAGS.mode}, ignore {restore_ckpt_file}') checkpoint_state.model_checkpoint_path = restore_ckpt else: if restore_ckpt == checkpoint_state.model_checkpoint_path: logging.warning( f"model_checkpoint_path and {FLAGS.restore_ckpt} are identity" ) else: logging.warning( f"checkpoint {FLAGS.restore_ckpt} not exists in {checkpoint_dir}" ) if FLAGS.mode == tf.estimator.ModeKeys.TRAIN: if not tf.io.gfile.exists(restore_ckpt_file): checkpoint_state.model_checkpoint_path = restore_ckpt with tf.io.gfile.GFile(restore_ckpt_file, 'w') as gfile: gfile.write(restore_ckpt) checkpoint_filename = os.path.join(checkpoint_dir, latest_filename) file_io.atomic_write_string_to_file( checkpoint_filename, text_format.MessageToString(checkpoint_state)) logging.info( f'mode is {FLAGS.mode} and no {restore_ckpt_file} file exists, apply {restore_ckpt}' ) except flags._exceptions.UnparsedFlagAccessError as e: pass except Exception as e: logging.info(f"get_checkpoint_state: {e}") exc_type, exc_value, exc_traceback_obj 
class ContainerType(Enum):
  """Enumeration of container types; members carry the integer values 1 and 2."""

  DOCKER = 1
  NATIVE = 2
:param feature_list: The feature list name :param lagrangex_header: Whether has lagrangex_header :param sort_id: Whether has sort_id :param kafka_dump: Whether has kafka_dump :param kafka_dump_prefix: Whether has kafka_dump_prefix :param restore_dir: The directory where the model restore. :param restore_ckpt: The directory where the model restore. :param deep_insight_target: Deep insight target name, if there are multi target, use comma split. :param deep_insight_sample_ratio: Deep insight sample ratio. :param unified_serving: Whether serving cluster is deployed in unified mode :param use_estimator: Whether use estimator to run a model :param kafka_topics: kafka topics for streaming, when no forier and flink :param kafka_group_id: kafka group_id for streaming, when no forier and flink :param kafka_servers: kafka servers for streaming, when no forier and flink :param input_path: The input hdfs path for training/eval. :param wildcard: Wildcard for filter input files. :param start_date: The start date of training/eval, include. :param end_date: The end date of training/eval, exclude. :param start_hour: The start hour of training, include. :param end_hour: The end hour of training, exclude. :param is_hourly: Whether the input data is hourly partitioned. :param enable_dynamic_sharding: Whether switch on dynamic_sharding :param max_task_num_per_worker: Number of data reader task per worker, the same as primus setting :param disable_native_metrics: Whether disable tensorflow native metrics, such as auc, mse. 
def __post_init__(self):
  """Finish initialisation after the dataclass fields are set.

  Steps, in order:
    1. Run `mlp_pass()` and `add_mpi_exception_hook()`.
    2. Best-effort `gflags_utils.update(self)`; a failure is only logged.
    3. For GPU partial-sync training, give a worker without a positive index
       the rank found in `OMPI_COMM_WORLD_RANK` (0 when unset).
    4. Mirror the kafka settings into the global `FLAGS`.
    5. Validate `zk_watch_address_family`.
    6. Propagate `restore_ckpt` into `FLAGS.restore_ckpt` when the flag is
       still unset.
    7. If `restore_dir` is set: the chief copies the checkpoint metadata into
       `model_dir`; every other task blocks, polling every 30 seconds, until
       the monolith checkpoint state file exists in `model_dir`.

  Raises:
    AssertionError: if `zk_watch_address_family` is neither IPV4 nor IPV6.
  """
  mlp_pass()
  add_mpi_exception_hook()
  try:
    gflags_utils.update(self)
  except Exception:  # deliberately best effort, but let Ctrl-C/SystemExit through
    logging.info("update RunnerConfig failed")

  if self.enable_gpu_training and self.enable_partial_sync_training:
    # Test for None first: `None <= 0` raises TypeError on Python 3.
    index_unset = self.index is None or self.index <= 0
    if index_unset and self.server_type == 'worker':
      self.index = int(os.environ.get('OMPI_COMM_WORLD_RANK') or '0')

  if self.kafka_topics:
    # Accept a comma separated string and keep it as a list on the config.
    if isinstance(self.kafka_topics, str):
      self.kafka_topics = self.kafka_topics.split(',')
    FLAGS.kafka_topics = ','.join(self.kafka_topics)
  if self.kafka_group_id:
    FLAGS.kafka_group_id = self.kafka_group_id
  if self.kafka_servers:
    FLAGS.kafka_servers = self.kafka_servers

  assert self.zk_watch_address_family in [
      AddressFamily.IPV4, AddressFamily.IPV6
  ]

  try:
    # An explicitly provided flag value wins over the config value.
    if self.restore_ckpt != FLAGS.restore_ckpt and FLAGS.restore_ckpt is None:
      FLAGS.restore_ckpt = self.restore_ckpt
  except flags._exceptions.UnparsedFlagAccessError:
    # Flags have not been parsed yet; nothing to synchronise.
    pass

  is_chief = self.is_local or (self.server_type == "worker" and
                               self.index == 0)
  if self.restore_dir is not None and len(self.restore_dir) > 0:
    if is_chief:
      self._copy_ckpt_file()
    else:
      monolith_checkpoint_filename = os.path.join(
          self.model_dir, save_utils.MONOLITH_CKPT_STATE_FILE_NAME)
      # Wait until the state file shows up in model_dir.
      while not tf.io.gfile.exists(monolith_checkpoint_filename):
        logging.info("Waiting for chief setting up restore_dir...")
        time.sleep(30)
def get_discovery(runner_conf: RunnerConfig, psm: str = None):
  """Create the service discovery object selected by `runner_conf`.

  Args:
    runner_conf: the runner configuration. For the PRIMUS discovery type its
      `server_type` and `index` are overwritten with the values parsed from
      `tf_config`.
    psm: name passed to `ConsulServiceDiscovery`; only used for the CONSUL
      discovery type.

  Returns:
    None for local runs, otherwise a TfConfig/Consul/MLP service discovery
    instance according to `discovery_type`; any other discovery type yields a
    `ZKServiceDiscovery`.

  Raises:
    AssertionError: for the PRIMUS discovery type when `tf_config` is None.
  """
  if runner_conf.is_local:
    return None

  discovery_type = runner_conf.discovery_type
  if discovery_type == ServiceDiscoveryType.PRIMUS:
    assert runner_conf.tf_config is not None
    tf_config = json.loads(runner_conf.tf_config)
    discovery = TfConfigServiceDiscovery(tf_config)
    # The cluster spec is authoritative for the role of this task.
    runner_conf.server_type = discovery.server_type
    runner_conf.index = discovery.index
    return discovery
  if discovery_type == ServiceDiscoveryType.CONSUL:
    # For async training, PS discovery is inside the process.
    return ConsulServiceDiscovery(psm)
  if discovery_type == ServiceDiscoveryType.MLP:
    # For async training, PS discovery is inside the process.
    return MLPServiceDiscovery()
  return ZKServiceDiscovery(runner_conf.deep_insight_name,
                            runner_conf.zk_server)


@contextmanager
def monolith_discovery(runner_conf: RunnerConfig):
  """Context manager yielding the service discovery for `runner_conf`.

  Yields None for local runs. Otherwise a PSM is generated from
  `runner_conf.uuid`, the discovery is created through `get_discovery`, and it
  is closed when the `with` block exits, whether normally or by exception.
  Exceptions raised inside the `with` block propagate unchanged.
  """
  discovery = None
  try:
    if runner_conf.is_local:
      yield None
    else:
      # Imported lazily so local runs do not need env_utils.
      from monolith.native_training import env_utils
      psm = env_utils.generate_psm_from_uuid(runner_conf.uuid)
      discovery = get_discovery(runner_conf, psm)
      logging.info('enter monolith_discovery!')
      yield discovery
  finally:
    if discovery is not None:
      discovery.close()
      logging.info('exit monolith_discovery!')
class RunnerUtilsTest(tf.test.TestCase):
  """Tests for `get_discovery` and the restore_dir handling of `RunnerConfig`."""

  def test_get_discovery_local(self):
    """A local run must not create any service discovery."""
    config = RunnerConfig(is_local=True)
    discovery = get_discovery(config)
    config.is_local = False
    self.assertIsNone(discovery)

  def test_get_discovery_primus(self):
    """The PRIMUS discovery type builds a TfConfigServiceDiscovery."""
    tf_config = {
        'cluster': {
            'ps': ['localhost:1111', 'localhost:1112'],
            'worker': ['localhost:1113', 'localhost:1114'],
            'chief': ['localhost:1115']
        },
        'task': {
            'type': 'chief',
            'index': 0
        }
    }
    config = RunnerConfig(is_local=False,
                          tf_config=json.dumps(tf_config),
                          discovery_type=ServiceDiscoveryType.PRIMUS)
    discovery = get_discovery(config)
    self.assertIsInstance(discovery, TfConfigServiceDiscovery)

  def test_get_discovery_consul(self):
    """The CONSUL discovery type builds a ConsulServiceDiscovery."""
    psm = 'data.monolith.123456'
    config = RunnerConfig(is_local=False,
                          discovery_type=ServiceDiscoveryType.CONSUL)
    discovery = get_discovery(config, psm)
    self.assertIsInstance(discovery, ConsulServiceDiscovery)

  def test_get_discovery_zk(self):
    """The ZK discovery type builds a ZKServiceDiscovery.

    The configured ZK address is not expected to be reachable, so a kazoo
    timeout is tolerated and only logged.
    """
    config = RunnerConfig(is_local=False,
                          discovery_type=ServiceDiscoveryType.ZK,
                          zk_server="127.0.0.1:0")
    try:
      discovery = get_discovery(config)
      self.assertIsInstance(discovery, ZKServiceDiscovery)
    except KazooTimeoutError as e:
      logging.info('kazoo example: {}'.format(e))

  def test_copy_ckpt(self):
    """Constructing a chief config with restore_dir prepares model_dir."""
    restore_dir = os.path.join(os.environ["TEST_TMPDIR"], "runner_utils_test",
                               "restore_dir")
    if not tf.io.gfile.exists(restore_dir):
      tf.io.gfile.makedirs(restore_dir)

    # Fake a checkpoint state file listing three checkpoints.
    ckpt = CheckpointState(model_checkpoint_path='model.ckpt-61')
    ckpt.all_model_checkpoint_paths.extend(
        ['model.ckpt-61', 'model.ckpt-30', 'model.ckpt-0'])
    file_io.atomic_write_string_to_file(
        os.path.join(restore_dir, 'checkpoint'),
        text_format.MessageToString(ckpt))

    model_dir = os.path.join(os.environ["TEST_TMPDIR"], "runner_utils_test",
                             "model_dir")
    if not tf.io.gfile.exists(model_dir):
      tf.io.gfile.makedirs(model_dir)

    # The copy happens as a side effect of constructing the config.
    config = RunnerConfig(is_local=True,
                          restore_dir=restore_dir,
                          model_dir=model_dir,
                          restore_ckpt='model.ckpt-30')
    ckpt2 = tf.train.get_checkpoint_state(model_dir)
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'monolith_checkpoint')))
    self.assertTrue(tf.io.gfile.exists(os.path.join(model_dir, 'restore_ckpt')))
    self.assertEqual(os.path.basename(ckpt2.model_checkpoint_path),
                     'model.ckpt-30')

    # Make sure other workers can go through once chief init the dir
    config = RunnerConfig(server_type="worker",
                          index=2,
                          restore_dir=restore_dir,
                          model_dir=model_dir)
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "monolith/native_training/runtime/allocator/block_allocator.h" #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" namespace monolith { namespace allocator { const int BlockAllocator::kStartBlcokSize = 1024; void* BlockAllocator::Allocate(size_t cl) { size_t size = Align(cl); if (size <= free_) { void* ptr = reinterpret_cast(free_ptr_); free_ptr_ += size; free_ -= size; return ptr; } else { const size_t block_size = std::max(current_block_size_, size); if (current_block_size_ < max_block_size_) { current_block_size_ *= 2; } allocated_size_ += block_size; blocks_.push_back(std::make_unique(block_size)); char* block_ptr = blocks_.back().get(); free_ptr_ = block_ptr + size; free_ = block_size - size; return reinterpret_cast(block_ptr); } } void BlockAllocator::DeallocateAll() { blocks_.clear(); free_ = 0; allocated_size_ = 0; } BlockAllocator* GetThreadLocalAllocator(size_t key) { thread_local absl::flat_hash_map> m; auto it = m.find(key); if (it == m.end()) { auto it2 = m.insert({key, std::make_unique()}); return it2.first->second.get(); } else { return it->second.get(); } } } // namespace allocator } // namespace monolith ================================================ FILE: monolith/native_training/runtime/allocator/block_allocator.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
//-------------------------------------------------------------------
// Allocator that hands out memory carved from a list of blocks and frees
// everything at once (DeallocateAll or destruction); there is no per-pointer
// free.
//
// It is not thread safe!
//-------------------------------------------------------------------
class BlockAllocator {
 public:
  // Size in bytes of the first block.
  static const int kStartBlcokSize;

  BlockAllocator()
      : current_block_size_(kStartBlcokSize),
        allocated_size_(0),
        free_ptr_(nullptr),
        free_(0) {}
  BlockAllocator(const BlockAllocator &) = delete;
  BlockAllocator &operator=(const BlockAllocator &) = delete;
  ~BlockAllocator() {}

  // Returns a region of at least `cl` bytes (rounded up to kAlign).
  // BlockAllocator owns the pointer.
  void *Allocate(size_t cl);

  // Releases all blocks; every pointer returned by Allocate becomes invalid.
  void DeallocateAll();

  // Total size in bytes of the blocks currently held.
  size_t AllocatedSize() const { return allocated_size_; }

 private:
  // Rounds `size` up to the next multiple of kAlign.
  size_t Align(size_t size) const {
    return (size + (kAlign - 1)) & ~(kAlign - 1);
  }

  // This must be the power of 2.
  static const size_t kAlign = 8;

  // Owns every block handed out so far.
  std::vector<std::unique_ptr<char[]>> blocks_;
  // Size used for the next block (unless a single request is larger).
  size_t current_block_size_;
  // current_block_size_ stops doubling once it reaches this value.
  size_t max_block_size_ = 1 * 1024 * 1024;
  size_t allocated_size_;
  // First unused byte of the newest block and the number of bytes left in it.
  char *free_ptr_;
  size_t free_;
};
// This defines an address space for EmbeddingHashTable's RawEntry, it supports
// up to 2^32 entries.
struct EntryAddress {
  // 2^3 = 8 shards per thread-safe embedding block allocator
  uint32_t shard_id : 3;

  // No more than 2^17 = 131072 blocks per embedding allocator
  uint32_t block_id : 17;

  // 2^12 = 4096 entries per block
  uint32_t entry_id : 12;
};

// Allocator for fixed-size entries, addressed by (block_id, entry_id) instead
// of raw pointers. Entries are packed into blocks of kMaxEntryNum entries.
//
// Thread compatible
class EmbeddingBlockAllocator {
 public:
  // This must be the power of 2.
  static const size_t kAlign = 8;
  static const size_t kMaxBlockNum = 1 << 17;
  static const size_t kMaxEntryNum = 1 << 12;

  // `entry_byte_size` is rounded up to a multiple of kAlign.
  explicit EmbeddingBlockAllocator(size_t entry_byte_size)
      : entry_byte_size_aligned_(Align(entry_byte_size)),
        block_size_(Align(entry_byte_size) * kMaxEntryNum) {
    Reset();
  }

  ~EmbeddingBlockAllocator() { FreeBlocks(); }

  EmbeddingBlockAllocator(const EmbeddingBlockAllocator &) = delete;
  EmbeddingBlockAllocator &operator=(const EmbeddingBlockAllocator &) = delete;

  // Translates an address returned by AllocateOne() into a pointer. Only the
  // block_id and entry_id fields are used.
  void *GetEntryPointer(EntryAddress entry_address) const {
    return cur_block_head_.load(
               std::memory_order_relaxed)[entry_address.block_id] +
           entry_byte_size_aligned_ * entry_address.entry_id;
  }

  // Reserves the next entry and returns its address; shard_id is always 0.
  // Throws std::bad_alloc when kMaxBlockNum blocks are already in use or when
  // memory for a new block cannot be obtained.
  EntryAddress AllocateOne() {
    // Value-initialise so that shard_id is a defined 0, not an indeterminate
    // value that callers would copy around.
    EntryAddress addr{};
    if (entry_id_ >= kMaxEntryNum) {
      // Current block is full (or no block exists yet).
      AppendBlock();
    }
    addr.block_id = static_cast<uint32_t>(blocks_->size() - 1);
    addr.entry_id = static_cast<uint32_t>(entry_id_);
    entry_id_ += 1;
    return addr;
  }

  // Frees every block; all addresses handed out so far become invalid.
  void DeallocateAll() { Reset(); }

  // Total size in bytes of the blocks currently held.
  size_t AllocatedSize() const { return allocated_size_; }

 private:
  // Rounds `size` up to the next multiple of kAlign.
  size_t Align(size_t size) const {
    return (size + (kAlign - 1)) & ~(kAlign - 1);
  }

  // Adds one block and makes it the current one (entry_id_ becomes 0).
  void AppendBlock() {
    if (blocks_->size() == kMaxBlockNum) {
      throw std::bad_alloc();
    }
    if (blocks_->size() == blocks_->capacity()) {
      // Grow the pointer array into a fresh vector instead of letting
      // push_back reallocate: the old array may still be in use through
      // cur_block_head_, so it is parked in blocks_snapshots_.
      auto new_blocks = std::make_unique<std::vector<char *>>();
      new_blocks->reserve(blocks_->capacity() * 2);
      new_blocks->insert(new_blocks->begin(), blocks_->begin(),
                         blocks_->end());
      // Park the old array first and publish the new one last: if parking
      // throws, blocks_ and cur_block_head_ still describe the old array.
      blocks_snapshots_.push_back(std::move(blocks_));
      blocks_ = std::move(new_blocks);
      cur_block_head_.store(blocks_->data());
    }
    char *block = new char[block_size_];
    // Capacity was ensured above, so this never reallocates (and never
    // invalidates cur_block_head_).
    blocks_->push_back(block);
    // Account for the block only once it really exists.
    allocated_size_ += block_size_;
    entry_id_ = 0;
  }

  // Drops all blocks and returns to the freshly-constructed state.
  void Reset() {
    FreeBlocks();
    blocks_snapshots_.clear();
    blocks_snapshots_.shrink_to_fit();
    blocks_ = std::make_unique<std::vector<char *>>();
    blocks_->reserve(1);
    allocated_size_ = 0;
    // Forces the first AllocateOne() to create a block.
    entry_id_ = kMaxEntryNum;
    cur_block_head_.store(blocks_->data());
  }

  // Frees the blocks referenced by blocks_. The snapshots hold copies of the
  // same pointers and therefore must not free them again.
  void FreeBlocks() {
    if (blocks_) {
      for (char *block : *blocks_) {
        delete[] block;
      }
      blocks_ = nullptr;
    }
  }

  // Pointers to the blocks; the blocks themselves are owned by this class.
  std::unique_ptr<std::vector<char *>> blocks_;

  // Stores blocks_.data(). Should be always valid.
  std::atomic<char **> cur_block_head_;

  // Used to save blocks snapshots, used for lock-free looking up
  std::vector<std::unique_ptr<std::vector<char *>>> blocks_snapshots_;

  size_t entry_byte_size_aligned_;
  size_t block_size_;
  size_t allocated_size_;
  // Index of the next free entry in the newest block.
  size_t entry_id_;
};
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/allocator/block_allocator.h" #include #include #include #include "gmock/gmock.h" #include "gtest/gtest.h" namespace monolith { namespace allocator { namespace { TEST(BlockAllocatorTest, Basic) { std::unique_ptr allocator = std::make_unique(); size_t size1 = 10; char* ptr1 = reinterpret_cast(allocator->Allocate(size1)); EXPECT_NE(ptr1, nullptr); ptr1[size1 - 1] = 'x'; EXPECT_EQ(allocator->AllocatedSize(), BlockAllocator::kStartBlcokSize); size_t size2 = 12; char* ptr2 = reinterpret_cast(allocator->Allocate(size2)); // Test Align. EXPECT_EQ(ptr2 - ptr1, 16); // Can't fit into current block anymore due to align. size_t size3 = BlockAllocator::kStartBlcokSize - size1 - size2; char* ptr3 = reinterpret_cast(allocator->Allocate(size3)); EXPECT_NE(ptr3, nullptr); // Next block will be doubled. 
EXPECT_EQ(allocator->AllocatedSize(), BlockAllocator::kStartBlcokSize * 3); allocator->DeallocateAll(); EXPECT_EQ(allocator->AllocatedSize(), 0); } TEST(BlockAllocatorTest, AllocateLarge) { size_t block_size = 1 << 20; auto allocator = std::make_unique(); char* p = reinterpret_cast(allocator->Allocate(block_size)); std::memset(p, 0, block_size); } TEST(BlockAllocatorTest, TSBlockAllocator) { TSBlockAllocator alloc; auto func = [&alloc]() { for (int i = 0; i < 100; ++i) { char* p = reinterpret_cast(alloc.Allocate(16)); std::memset(p, 0, 16); } }; std::vector ths; for (int i = 0; i < 15; ++i) { ths.push_back(std::thread(func)); } for (auto& th : ths) { th.join(); } alloc.DeallocateAll(); EXPECT_THAT(alloc.AllocatedSize(), 0); } TEST(EmbeddingBlockAllocatorTest, EmbeddingBlockAllocatorAllocateMany) { EmbeddingBlockAllocator alloc(8); for (int i = 0; i < EmbeddingBlockAllocator::kMaxEntryNum * 10; ++i) { auto addr = alloc.AllocateOne(); void* real_addr = alloc.GetEntryPointer(addr); std::memset(real_addr, 0, 8); } } TEST(EmbeddingBlockAllocatorTest, TSEmbeddingBlockAllocator) { TSEmbeddingBlockAllocator alloc(16); auto func = [&alloc]() { for (int i = 0; i < 100; ++i) { EntryAddress p = alloc.AllocateOne(); std::memset(static_cast(alloc.GetEntryPointer(p)), 0, 16); } }; std::vector threads; for (int i = 0; i < 15; ++i) { threads.emplace_back(func); } for (auto& t : threads) { t.join(); } alloc.DeallocateAll(); EXPECT_THAT(alloc.AllocatedSize(), 0); } } // namespace } // namespace allocator } // namespace monolith ================================================ FILE: monolith/native_training/runtime/common/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") package(default_visibility = ["//monolith/native_training/runtime:__subpackages__"]) cc_library( name = "cpu_info", srcs = ["cpu_info.cc"], hdrs = ["cpu_info.h"], deps = [ ], ) cc_library( name = "metrics_internal_deps", ) cc_library( name = "metrics", srcs 
= ["metrics.cc"], hdrs = ["metrics.h"], visibility = ["//visibility:public"], deps = [ ":metrics_internal_deps", "@com_google_glog//:glog", ], ) cc_library( name = "linalg_utils", hdrs = ["linalg_utils.h"], visibility = ["//visibility:public"], ) cc_test( name = "linalg_utils_test", srcs = [ "linalg_utils_test.cc", ], deps = [ ":linalg_utils", "@com_google_glog//:glog", "@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/common/cpu_info.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "monolith/native_training/runtime/common/cpu_info.h" #include #include #include #if defined(__x86_64__) || defined(__i386__) #define GETCPUID(a, b, c, d, a_inp, c_inp) \ asm("mov %%rbx, %%rdi\n" \ "cpuid\n" \ "xchg %%rdi, %%rbx\n" \ : "=a"(a), "=D"(b), "=c"(c), "=d"(d) \ : "a"(a_inp), "2"(c_inp)) #endif namespace monolith { class CPUIDInfo; void InitCPUIDInfo(); CPUIDInfo *cpuid = nullptr; #if defined(__x86_64__) || defined(__i386__) int GetXCR0EAX() { int eax, edx; asm("XGETBV" : "=a"(eax), "=d"(edx) : "c"(0)); return eax; } // Structure for basic CPUID info class CPUIDInfo { public: CPUIDInfo() : have_adx_(0), have_aes_(0), have_avx_(0), have_avx2_(0), have_avx512f_(0), have_avx512cd_(0), have_avx512er_(0), have_avx512pf_(0), have_avx512vl_(0), have_avx512bw_(0), have_avx512dq_(0), have_avx512vbmi_(0), have_avx512ifma_(0), have_avx512_4vnniw_(0), have_avx512_4fmaps_(0), have_bmi1_(0), have_bmi2_(0), have_cmov_(0), have_cmpxchg16b_(0), have_cmpxchg8b_(0), have_f16c_(0), have_fma_(0), have_mmx_(0), have_pclmulqdq_(0), have_popcnt_(0), have_prefetchw_(0), have_prefetchwt1_(0), have_rdrand_(0), have_rdseed_(0), have_smap_(0), have_sse_(0), have_sse2_(0), have_sse3_(0), have_sse4_1_(0), have_sse4_2_(0), have_ssse3_(0), have_hypervisor_(0) {} static void Initialize() { // Initialize cpuid struct cpuid = new CPUIDInfo; uint32_t eax, ebx, ecx, edx; // Get vendor string (issue CPUID with eax = 0) GETCPUID(eax, ebx, ecx, edx, 0, 0); cpuid->vendor_str_.append(reinterpret_cast(&ebx), 4); cpuid->vendor_str_.append(reinterpret_cast(&edx), 4); cpuid->vendor_str_.append(reinterpret_cast(&ecx), 4); // To get general information and extended features we send eax = 1 and // ecx = 0 to cpuid. The response is returned in eax, ebx, ecx and edx. // (See Intel 64 and IA-32 Architectures Software Developer's Manual // Volume 2A: Instruction Set Reference, A-M CPUID). 
GETCPUID(eax, ebx, ecx, edx, 1, 0); cpuid->model_num_ = static_cast((eax >> 4) & 0xf); cpuid->family_ = static_cast((eax >> 8) & 0xf); cpuid->have_aes_ = (ecx >> 25) & 0x1; cpuid->have_cmov_ = (edx >> 15) & 0x1; cpuid->have_cmpxchg16b_ = (ecx >> 13) & 0x1; cpuid->have_cmpxchg8b_ = (edx >> 8) & 0x1; cpuid->have_mmx_ = (edx >> 23) & 0x1; cpuid->have_pclmulqdq_ = (ecx >> 1) & 0x1; cpuid->have_popcnt_ = (ecx >> 23) & 0x1; cpuid->have_rdrand_ = (ecx >> 30) & 0x1; cpuid->have_sse2_ = (edx >> 26) & 0x1; cpuid->have_sse3_ = ecx & 0x1; cpuid->have_sse4_1_ = (ecx >> 19) & 0x1; cpuid->have_sse4_2_ = (ecx >> 20) & 0x1; cpuid->have_sse_ = (edx >> 25) & 0x1; cpuid->have_ssse3_ = (ecx >> 9) & 0x1; cpuid->have_hypervisor_ = (ecx >> 31) & 1; const uint64_t xcr0_xmm_mask = 0x2; const uint64_t xcr0_ymm_mask = 0x4; const uint64_t xcr0_maskreg_mask = 0x20; const uint64_t xcr0_zmm0_15_mask = 0x40; const uint64_t xcr0_zmm16_31_mask = 0x80; const uint64_t xcr0_avx_mask = xcr0_xmm_mask | xcr0_ymm_mask; const uint64_t xcr0_avx512_mask = xcr0_avx_mask | xcr0_maskreg_mask | xcr0_zmm0_15_mask | xcr0_zmm16_31_mask; const bool have_avx = // Does the OS support XGETBV instruction use by applications? ((ecx >> 27) & 0x1) && // Does the OS save/restore XMM and YMM state? ((GetXCR0EAX() & xcr0_avx_mask) == xcr0_avx_mask) && // Is AVX supported in hardware? ((ecx >> 28) & 0x1); const bool have_avx512 = // Does the OS support XGETBV instruction use by applications? ((ecx >> 27) & 0x1) && // Does the OS save/restore ZMM state? ((GetXCR0EAX() & xcr0_avx512_mask) == xcr0_avx512_mask); cpuid->have_avx_ = have_avx; cpuid->have_fma_ = have_avx && ((ecx >> 12) & 0x1); cpuid->have_f16c_ = have_avx && ((ecx >> 29) & 0x1); // Get standard level 7 structured extension features (issue CPUID with // eax = 7 and ecx= 0), which is required to check for AVX2 support as // well as other Haswell (and beyond) features. 
(See Intel 64 and IA-32 // Architectures Software Developer's Manual Volume 2A: Instruction Set // Reference, A-M CPUID). GETCPUID(eax, ebx, ecx, edx, 7, 0); cpuid->have_adx_ = (ebx >> 19) & 0x1; cpuid->have_avx2_ = have_avx && ((ebx >> 5) & 0x1); cpuid->have_bmi1_ = (ebx >> 3) & 0x1; cpuid->have_bmi2_ = (ebx >> 8) & 0x1; cpuid->have_prefetchwt1_ = ecx & 0x1; cpuid->have_rdseed_ = (ebx >> 18) & 0x1; cpuid->have_smap_ = (ebx >> 20) & 0x1; cpuid->have_avx512f_ = have_avx512 && ((ebx >> 16) & 0x1); cpuid->have_avx512cd_ = have_avx512 && ((ebx >> 28) & 0x1); cpuid->have_avx512er_ = have_avx512 && ((ebx >> 27) & 0x1); cpuid->have_avx512pf_ = have_avx512 && ((ebx >> 26) & 0x1); cpuid->have_avx512vl_ = have_avx512 && ((ebx >> 31) & 0x1); cpuid->have_avx512bw_ = have_avx512 && ((ebx >> 30) & 0x1); cpuid->have_avx512dq_ = have_avx512 && ((ebx >> 17) & 0x1); cpuid->have_avx512vbmi_ = have_avx512 && ((ecx >> 1) & 0x1); cpuid->have_avx512ifma_ = have_avx512 && ((ebx >> 21) & 0x1); cpuid->have_avx512_4vnniw_ = have_avx512 && ((edx >> 2) & 0x1); cpuid->have_avx512_4fmaps_ = have_avx512 && ((edx >> 3) & 0x1); } static bool TestFeature(CPUFeature feature) { InitCPUIDInfo(); // clang-format off switch (feature) { case ADX: return cpuid->have_adx_; case AES: return cpuid->have_aes_; case AVX2: return cpuid->have_avx2_; case AVX: return cpuid->have_avx_; case AVX512F: return cpuid->have_avx512f_; case AVX512CD: return cpuid->have_avx512cd_; case AVX512PF: return cpuid->have_avx512pf_; case AVX512ER: return cpuid->have_avx512er_; case AVX512VL: return cpuid->have_avx512vl_; case AVX512BW: return cpuid->have_avx512bw_; case AVX512DQ: return cpuid->have_avx512dq_; case AVX512VBMI: return cpuid->have_avx512vbmi_; case AVX512IFMA: return cpuid->have_avx512ifma_; case AVX512_4VNNIW: return cpuid->have_avx512_4vnniw_; case AVX512_4FMAPS: return cpuid->have_avx512_4fmaps_; case BMI1: return cpuid->have_bmi1_; case BMI2: return cpuid->have_bmi2_; case CMOV: return cpuid->have_cmov_; case 
CMPXCHG16B: return cpuid->have_cmpxchg16b_; case CMPXCHG8B: return cpuid->have_cmpxchg8b_; case F16C: return cpuid->have_f16c_; case FMA: return cpuid->have_fma_; case MMX: return cpuid->have_mmx_; case PCLMULQDQ: return cpuid->have_pclmulqdq_; case POPCNT: return cpuid->have_popcnt_; case PREFETCHW: return cpuid->have_prefetchw_; case PREFETCHWT1: return cpuid->have_prefetchwt1_; case RDRAND: return cpuid->have_rdrand_; case RDSEED: return cpuid->have_rdseed_; case SMAP: return cpuid->have_smap_; case SSE2: return cpuid->have_sse2_; case SSE3: return cpuid->have_sse3_; case SSE4_1: return cpuid->have_sse4_1_; case SSE4_2: return cpuid->have_sse4_2_; case SSE: return cpuid->have_sse_; case SSSE3: return cpuid->have_ssse3_; case HYPERVISOR: return cpuid->have_hypervisor_; default:break; } // clang-format on return false; } std::string vendor_str() const { return vendor_str_; } int family() const { return family_; } int model_num() { return model_num_; } private: int have_adx_ : 1; int have_aes_ : 1; int have_avx_ : 1; int have_avx2_ : 1; int have_avx512f_ : 1; int have_avx512cd_ : 1; int have_avx512er_ : 1; int have_avx512pf_ : 1; int have_avx512vl_ : 1; int have_avx512bw_ : 1; int have_avx512dq_ : 1; int have_avx512vbmi_ : 1; int have_avx512ifma_ : 1; int have_avx512_4vnniw_ : 1; int have_avx512_4fmaps_ : 1; int have_bmi1_ : 1; int have_bmi2_ : 1; int have_cmov_ : 1; int have_cmpxchg16b_ : 1; int have_cmpxchg8b_ : 1; int have_f16c_ : 1; int have_fma_ : 1; int have_mmx_ : 1; int have_pclmulqdq_ : 1; int have_popcnt_ : 1; int have_prefetchw_ : 1; int have_prefetchwt1_ : 1; int have_rdrand_ : 1; int have_rdseed_ : 1; int have_smap_ : 1; int have_sse_ : 1; int have_sse2_ : 1; int have_sse3_ : 1; int have_sse4_1_ : 1; int have_sse4_2_ : 1; int have_ssse3_ : 1; int have_hypervisor_ : 1; std::string vendor_str_; int family_; int model_num_; }; std::once_flag cpuid_once_flag; void InitCPUIDInfo() { // This ensures that CPUIDInfo::Initialize() is called exactly // once 
regardless of how many threads concurrently call us std::call_once(cpuid_once_flag, CPUIDInfo::Initialize); } #endif bool TestCPUFeature(CPUFeature feature) { #if defined(__x86_64__) || defined(__i386__) return CPUIDInfo::TestFeature(feature); #else return false; #endif } std::string CPUVendorIDString() { #if defined(__x86_64__) || defined(__i386__) InitCPUIDInfo(); return cpuid->vendor_str(); #else return ""; #endif } int CPUFamily() { #if defined(__x86_64__) || defined(__i386__) InitCPUIDInfo(); return cpuid->family(); #else return 0; #endif } int CPUModelNum() { #if defined(__x86_64__) || defined(__i386__) InitCPUIDInfo(); return cpuid->model_num(); #else return 0; #endif } int CPUIDNumSMT() { // https://software.intel.com/en-us/articles/intel-64-architecture-processor-topology-enumeration // https://software.intel.com/en-us/articles/intel-sdm (Vol 3A) // Section: Detecting Hardware Multi-threads Support and Topology // Uses CPUID Leaf 11 to enumerate system topology on Intel x86 architectures // Other cases not supported #if defined(__x86_64__) || defined(__i386__) uint32_t eax, ebx, ecx, edx; // Check if system supports Leaf 11 GETCPUID(eax, ebx, ecx, edx, 0, 0); if (eax >= 11) { // 1) Leaf 11 available? CPUID.(EAX=11, ECX=0):EBX != 0 // 2) SMT_Mask_Width = CPUID.(EAX=11, ECX=0):EAX[4:0] if CPUID.(EAX=11, // ECX=0):ECX[15:8] is 1 GETCPUID(eax, ebx, ecx, edx, 11, 0); if (ebx != 0 && ((ecx & 0xff00) >> 8) == 1) { return 1 << (eax & 0x1f); // 2 ^ SMT_Mask_Width } } return 0; #else return 0; #endif } // If the CPU feature isn't present, log a fatal error. 
void CheckFeatureOrDie(monolith::CPUFeature feature, const std::string &feature_name) { if (!monolith::TestCPUFeature(feature)) { std::cerr << "The library was compiled to use " << feature_name << " instructions, but these aren't available on your machine."; std::abort(); } } void RunCPUGuard() { #if defined(_ENABLE_AVX) && defined(__AVX__) CheckFeatureOrDie(monolith::CPUFeature::AVX, "AVX"); #endif } } // namespace monolith ================================================ FILE: monolith/native_training/runtime/common/cpu_info.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/

#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_CPU_INFO_H_
#define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_CPU_INFO_H_

#include <string>

namespace monolith {

// Mostly ISA related features that we care about
enum CPUFeature {
  // Do not change numeric assignments.
  MMX = 0,
  SSE = 1,
  SSE2 = 2,
  SSE3 = 3,
  SSSE3 = 4,
  SSE4_1 = 5,
  SSE4_2 = 6,
  CMOV = 7,
  CMPXCHG8B = 8,
  CMPXCHG16B = 9,
  POPCNT = 10,
  AES = 11,
  AVX = 12,
  RDRAND = 13,
  AVX2 = 14,
  FMA = 15,
  F16C = 16,
  PCLMULQDQ = 17,
  RDSEED = 18,
  ADX = 19,
  SMAP = 20,

  // Prefetch Vector Data Into Caches with Intent to Write and T1 Hint
  // http://www.felixcloutier.com/x86/PREFETCHWT1.html.
  // You probably want PREFETCHW instead.
  PREFETCHWT1 = 21,

  BMI1 = 22,
  BMI2 = 23,
  HYPERVISOR = 25,  // 0 when on a real CPU, 1 on (well-behaved) hypervisor.

  // Prefetch Data into Caches in Anticipation of a Write (3D Now!).
  // http://www.felixcloutier.com/x86/PREFETCHW.html
  PREFETCHW = 26,

  // AVX-512: 512-bit vectors (plus masking, etc.) in Knights Landing, Skylake
  // Xeon, etc.; each of these entries is a different subset of instructions,
  // various combinations of which occur on various CPU types.
  AVX512F = 27,        // Foundation
  AVX512CD = 28,       // Conflict detection
  AVX512ER = 29,       // Exponential and reciprocal
  AVX512PF = 30,       // Prefetching
  AVX512VL = 31,       // Shorter vector lengths
  AVX512BW = 32,       // Byte and word
  AVX512DQ = 33,       // Dword and qword
  AVX512VBMI = 34,     // Bit manipulation
  AVX512IFMA = 35,     // Integer multiply-add
  AVX512_4VNNIW = 36,  // Integer neural network
  AVX512_4FMAPS = 37,  // Floating point neural network
};

// Checks whether the current processor supports one of the features above.
// Checks CPU registers to return hardware capabilities.
bool TestCPUFeature(CPUFeature feature);

// Returns CPU Vendor string (i.e. 'GenuineIntel', 'AuthenticAMD', etc.)
std::string CPUVendorIDString();

// Returns CPU family.
int CPUFamily();

// Returns CPU model number.
int CPUModelNum();

// Returns nominal core processor cycles per second of each processor.
// NOTE(review): no definition of this function appears in cpu_info.cc;
// confirm one exists elsewhere before calling it.
double NominalCPUFrequency();

// Returns num of hyperthreads per physical core
int CPUIDNumSMT();

// Aborts the process with a message naming `feature_name` when `feature` is
// not supported by the current processor.
void CheckFeatureOrDie(monolith::CPUFeature feature,
                       const std::string &feature_name);

// Startup check: verifies the CPU features the binary was compiled to use.
void RunCPUGuard();

}  // namespace monolith

#endif  // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_CPU_INFO_H_

================================================
FILE: monolith/native_training/runtime/common/linalg_utils.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_LINALG_UTILS_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_LINALG_UTILS_H_ #include #include #include #include namespace monolith { namespace common { template typename std::enable_if::is_integer, bool>::type IsAlmostEqual(T x, T y, int ulp = 2) { // the machine epsilon has to be scaled to the magnitude of the values used // and multiplied by the desired precision in ULPs (units in the last place) return std::abs(x - y) <= std::numeric_limits::epsilon() * std::abs(x + y) * ulp // unless the result is subnormal || std::abs(x - y) < std::numeric_limits::min(); } template typename std::enable_if::is_integer, T>::type L2NormSquare(const T* data, size_t length) { T sum = 0; for (size_t i = 0; i < length; ++i) { sum += data[i] * data[i]; } return sum; } } // namespace common } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_LINALG_UTILS_H_ ================================================ FILE: monolith/native_training/runtime/common/linalg_utils_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/common/linalg_utils.h"

#include <vector>

#include "glog/logging.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace monolith {
namespace common {

TEST(LinalgUtils, IsAlmostEqual) {
  EXPECT_TRUE(IsAlmostEqual(0.f, 0.f));
  EXPECT_FALSE(IsAlmostEqual(0.f, 1e-6f));
}

// Covers the empty, single-element and multi-element cases.
TEST(LinalgUtils, L2NormSquare) {
  std::vector<float> vec1 = {};
  EXPECT_TRUE(IsAlmostEqual(L2NormSquare(vec1.data(), vec1.size()), 0.f));

  std::vector<float> vec2 = {1};
  EXPECT_TRUE(IsAlmostEqual(L2NormSquare(vec2.data(), vec2.size()), 1.f));

  std::vector<float> vec3 = {1, 2, 3, 4};
  EXPECT_TRUE(IsAlmostEqual(L2NormSquare(vec3.data(), vec3.size()), 30.f));
}

}  // namespace common
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/common/metrics.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "monolith/native_training/runtime/common/metrics.h"

// NOTE(review): one standard header is included here, but its name is not
// recoverable from this dump; <string> is a placeholder - confirm.
#include <string>

#include "glog/logging.h"

namespace monolith {

// Returns the process-wide MetricCollector. The instance is created on the
// first call and never deleted, so every call returns the same pointer.
cpputil::metrics2::MetricCollector *GetMetrics() {
  static auto *metrics = new cpputil::metrics2::MetricCollector();
  return metrics;
}

}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/common/metrics.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_METRICS_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_METRICS_H_ #include #include namespace cpputil { namespace metrics2 { // This is a dummy implementation // Will be replaced by a unified interface class MetricCollector { public: typedef std::vector> TagkvList; MetricCollector() = default; virtual ~MetricCollector() = default; template int init(const T& conf) { return 0; } int define_tagk(const std::string& tagk) { return 0; } int define_tagkv(const std::string& tagk, const std::vector& tagv_list) { return 0; } int define_counter(const std::string& name) { return 0; } int define_counter(const std::string& name, const std::string& ) { return 0; } int define_rate_counter(const std::string& name) { return 0; } int define_rate_counter(const std::string& name, const std::string& ) { return 0; } int define_meter(const std::string& name) { return 0; } int define_meter(const std::string& name, const std::string& ) { return 0; } int define_timer(const std::string& name) { return 0; } int define_timer(const std::string& name, const std::string& ) { return 0; } int define_store(const std::string& name) { return 0; } int define_store(const std::string& name, const std::string& ) { return 0; } int define_ts_store(const std::string& name) { return 0; } int define_ts_store(const std::string& name, const std::string& ) { return 0; } int emit_counter(const std::string& name, double 
value) const { return 0; } int emit_counter(const std::string& name, double value, std::string tagkv) const { return 0; } int emit_counter(const std::string& name, double value, const TagkvList& tagkv_list) const { return 0; } int emit_rate_counter(const std::string& name, double value) const { return 0; } int emit_rate_counter(const std::string& name, double value, const std::string& tagkv) const { return 0; } int emit_rate_counter(const std::string& name, double value, const TagkvList& tagkv_list) { return 0; } int emit_meter(const std::string& name, double value) const { return 0; } int emit_meter(const std::string& name, double value, const std::string& tagkv) const { return 0; } int emit_meter(const std::string& name, double value, const TagkvList& tagkv_list) { return 0; } int emit_timer(const std::string& name, double value) const { return 0; } int emit_timer(const std::string& name, double value, std::string tagkv) const { return 0; } int emit_timer(const std::string& name, double value, const TagkvList& tagkv_list) const { return 0; } int emit_store(const std::string& name, double value) const { return 0; } int emit_store(const std::string& name, double value, std::string tagkv) const { return 0; } int emit_store(const std::string& name, double value, const TagkvList& tagkv_list) const { return 0; } int emit_ts_store(const std::string& name, double value, time_t ts) const { return 0; } int emit_ts_store(const std::string& name, double value, time_t ts, std::string tagkv) const { return 0; } int emit_ts_store(const std::string& name, double value, time_t ts, const TagkvList& tagkv_list) const { return 0; } int reset_counter(const std::string& name) const { return 0; } int reset_counter(const std::string& name, std::string tagkv) const { return 0; } int reset_counter(const std::string& name, const TagkvList& tagkv_list) const { return 0; } int reset_rate_counter(const std::string& name) const { return 0; } int reset_rate_counter(const std::string& name, 
const std::string& tagkv) { return 0; } int reset_rate_counter(const std::string& name, const TagkvList& tagkv_list) { return 0; } int reset_timer(const std::string& name) const { return 0; } int reset_timer(const std::string& name, std::string tagkv) const { return 0; } int reset_timer(const std::string& name, const TagkvList& tagkv_list) const { return 0; } int reset_store(const std::string& name) const { return 0; } int reset_store(const std::string& name, std::string tagkv) const { return 0; } int reset_store(const std::string& name, const TagkvList& tagkv_list) const { return 0; } int reset_ts_store(const std::string& name) const { return 0; } int reset_ts_store(const std::string& name, std::string tagkv) const { return 0; } int reset_ts_store(const std::string& name, const TagkvList& tagkv_list) const { return 0; } // deprecated static int start_flush_thread() { return 1; } // deprecated static int start_listening_thread() { return 1; } static std::string make_tagkv(const TagkvList& tagkv_list) { return ""; } }; } // namespace metrics2 } // namespace cpputil namespace monolith { cpputil::metrics2::MetricCollector *GetMetrics(); } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_COMMON_METRICS_H_ ================================================ FILE: monolith/native_training/runtime/common/metrics_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include "glog/logging.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/common/metrics.h" namespace monolith { TEST(MetricsTest, Default) { putenv(const_cast("TCE_PSM=data.tob.test")); static cpputil::metrics2::MetricCollector *metrics1 = monolith::GetMetrics(); static cpputil::metrics2::MetricCollector *metrics2 = monolith::GetMetrics(); EXPECT_EQ(metrics1, metrics2); } } // namespace monolith ================================================ FILE: monolith/native_training/runtime/concurrency/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") package(default_visibility = ["//monolith/native_training/runtime:__subpackages__"]) cc_library( name = "thread_pool", srcs = ["thread_pool.cc"], hdrs = ["thread_pool.h"], deps = [ "@com_google_absl//absl/synchronization", ], ) cc_library( name = "queue", hdrs = ["queue.h"], deps = [], visibility = ["//visibility:public"], ) cc_test( name = "queue_test", srcs = ["queue_test.cc"], deps = [ ":queue", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "sleeper", hdrs = ["sleeper.h"], deps = [], ) cc_library( name = "micro_one_bit_spin_lock", hdrs = ["micro_one_bit_spin_lock.h"], deps = [ ":sleeper", "@com_google_glog//:glog" ], ) cc_library( name = "xorshift", hdrs = ["xorshift.h"], srcs = ["xorshift.cc"], deps = [], ) cc_binary( name = "xorshift_test", srcs = ["xorshift_test.cc"], deps = [ ":xorshift", "@com_google_googletest//:gtest_main", ], ) cc_binary( name = "random_number_generator_benchmark", srcs = ["random_number_generator_benchmark.cc"], deps = [ ":xorshift", "@com_google_absl//absl/random", "@com_github_google_benchmark//:benchmark", ], ) ================================================ FILE: 
monolith/native_training/runtime/concurrency/micro_one_bit_spin_lock.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_MICRO_ONE_BIT_SPIN_LOCK_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_MICRO_ONE_BIT_SPIN_LOCK_H_ #include "monolith/native_training/runtime/concurrency/sleeper.h" #include "glog/logging.h" namespace monolith { namespace concurrency { // ensure never modify the other bits of lock_ when are not holding the lock struct MicroOneBitSpinLock { private: static const uint8_t MASK = 0x1; public: enum { FREE = 0, LOCKED = MASK }; // lock_ can't be std::atomic<> to preserve POD-ness. mutable uint8_t lock_; // Initialize this MSL. It is unnecessary to call this if you // zero-initialize the MicroSpinLock. 
void Init() { Payload()->store(Payload()->load() & ~MASK); } bool TryLock() { uint8_t val = Payload()->load(); return CompareAndSwap(val & ~MASK, val | MASK); } void Lock() { Sleeper sleeper; do { while ((Payload()->load() & MASK) != FREE) { sleeper.Wait(); } } while (!TryLock()); DCHECK((Payload()->load() & MASK) == LOCKED); } void Unlock() { uint8_t val = Payload()->load(); CHECK((val & MASK) == LOCKED); Payload()->store(val & ~MASK, std::memory_order_release); } uint8_t Value() const { return Payload()->load() >> 1; } void Set(uint8_t val) { Payload()->store((val << 1) + (Payload()->load() & MASK)); } private: std::atomic* Payload() const { return reinterpret_cast*>(&this->lock_); } bool CompareAndSwap(uint8_t compare, uint8_t newVal) { return std::atomic_compare_exchange_strong_explicit(Payload(), &compare, newVal, std::memory_order_acquire, std::memory_order_relaxed); } }; static_assert(std::is_pod::value, "MicroOneBitSpinLock must be kept a POD type."); } // namespace concurrency } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_MICRO_ONE_BIT_SPIN_LOCK_H_ ================================================ FILE: monolith/native_training/runtime/concurrency/queue.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_QUEUE_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_QUEUE_H_ #include #include #include #include #include namespace monolith { namespace concurrency { template class Queue { public: // Create a queue object with a given maximum size(default: max_size=1). // If max_size is 0, the queue size is infinite. explicit Queue(size_t max_size = 1) : max_size_(max_size == 0 ? std::numeric_limits::max() : max_size) {} Queue(const Queue&) = delete; // disable copying Queue& operator=(const Queue&) = delete; // disable assignment // Return the front item of the queue, it blocks if no item was available. T front() { std::unique_lock lock(mutex_); while (queue_.empty()) { enqueue_cond_.wait(lock); } return _top(); } // Remove and return an item from the queue, it blocks if no item was // available. T pop() { std::unique_lock lock(mutex_); while (queue_.empty()) { enqueue_cond_.wait(lock); } auto val = _top(); queue_.pop(); lock.unlock(); dequeue_cond_.notify_one(); return val; } // Remove an item(and assign to T& item) from the queue, it blocks if // no item was available. void pop(T& item) { // NOLINT std::unique_lock lock(mutex_); while (queue_.empty()) { enqueue_cond_.wait(lock); } item = _top(); queue_.pop(); lock.unlock(); dequeue_cond_.notify_one(); } // Try to remove an item(and assign to T& item) from the queue, it blocks // at most 'timeout' duration and return false if no item was available // within that time. template bool try_pop(T& item, std::chrono::duration timeout) { // NOLINT std::unique_lock lock(mutex_); if (!enqueue_cond_.wait_for(lock, timeout, [this] { return !queue_.empty(); })) { return false; } item = _top(); queue_.pop(); lock.unlock(); dequeue_cond_.notify_one(); return true; } // Try to push an item into the queue, it blocks at most 'timeout' // duration and return false if no free slot was available within // that time. 
template bool try_push(T item, std::chrono::duration timeout) { std::unique_lock lock(mutex_); if (!dequeue_cond_.wait_for(lock, timeout, [this] { return queue_.size() < max_size_; })) { return false; } queue_.push(std::move(item)); lock.unlock(); enqueue_cond_.notify_one(); return true; } // Put an item into the queue, it blocks if no free slot was available. void push(T item) { std::unique_lock lock(mutex_); while (queue_.size() >= max_size_) { dequeue_cond_.wait(lock); } queue_.push(std::move(item)); lock.unlock(); enqueue_cond_.notify_one(); } // Return true if the queue is empty, false otherwise (not reliable). bool empty() { std::unique_lock lock(mutex_); return queue_.empty(); } private: template inline typename std::enable_if::type _top() { return queue_.top(); } template inline typename std::enable_if::type _top() { return queue_.front(); } private: size_t max_size_; typename std::conditional, std::queue>::type queue_; std::mutex mutex_; std::condition_variable enqueue_cond_; std::condition_variable dequeue_cond_; }; } // namespace concurrency } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_QUEUE_H_ ================================================ FILE: monolith/native_training/runtime/concurrency/queue_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/concurrency/queue.h"

// NOTE(review): the system header names were lost in extraction; these are the
// ones the test body needs.
#include <atomic>
#include <chrono>
#include <thread>
#include <vector>

#include "gtest/gtest.h"

using std::chrono::duration_cast;
using std::chrono::high_resolution_clock;
using std::chrono::microseconds;
using std::chrono::milliseconds;

namespace monolith {
namespace concurrency {
namespace {

// Fills a size-1 queue, then measures how long a try_push that must time out
// actually blocks. Returns the elapsed time in milliseconds.
float PushTimeout(int timeout /* milliseconds */) {
  monolith::concurrency::Queue<int> queue(1);
  queue.push(1);
  auto start = high_resolution_clock::now();
  EXPECT_FALSE(queue.try_push(2, milliseconds(timeout)));
  auto elapsed = high_resolution_clock::now() - start;
  return duration_cast<microseconds>(elapsed).count() / 1000.f;
}

// Empties a queue, then measures how long a try_pop that must time out
// actually blocks. Returns the elapsed time in milliseconds.
float PopTimeout(int timeout /* milliseconds */) {
  monolith::concurrency::Queue<int> queue(1);
  queue.push(1);
  int item = queue.pop();
  EXPECT_EQ(item, 1);
  auto start = high_resolution_clock::now();
  EXPECT_FALSE(queue.try_pop(item, milliseconds(timeout)));
  auto elapsed = high_resolution_clock::now() - start;
  return duration_cast<microseconds>(elapsed).count() / 1000.f;
}

// Many producers and consumers hammer one bounded queue; every pushed item
// must be popped exactly once.
TEST(QueueTest, Basic) {
  std::atomic_int producer_count(0);
  std::atomic_int consumer_count(0);
  std::atomic<bool> done(false);

  monolith::concurrency::Queue<int> queue(128);

  const int iterations = 10 * 10000;
  const int producer_thread_count = 10;
  const int consumer_thread_count = 10;
  const std::chrono::microseconds timeout(10);

  auto producer = [&]() {
    for (int i = 0; i != iterations; ++i) {
      int value = ++producer_count;
      // Retry until a slot frees up.
      while (!queue.try_push(value, timeout)) {
      }
    }
  };

  auto consumer = [&]() {
    int value;
    while (!done) {
      while (queue.try_pop(value, timeout)) ++consumer_count;
    }
    // Producers are finished; drain whatever is still queued.
    while (queue.try_pop(value, timeout)) ++consumer_count;
  };

  std::vector<std::thread> producer_threads, consumer_threads;
  for (int i = 0; i != producer_thread_count; ++i) {
    producer_threads.emplace_back(producer);
  }
  for (int i = 0; i != consumer_thread_count; ++i) {
    consumer_threads.emplace_back(consumer);
  }

  for (auto& t : producer_threads) {
    if (t.joinable()) {
      t.join();
    }
  }
  done = true;
  for (auto& t : consumer_threads) {
    if (t.joinable()) {
      t.join();
    }
  }

  const int total_items = iterations * producer_thread_count;
  EXPECT_EQ(producer_count, total_items);
  // The number of items consumed equals the number produced; it does not
  // depend on how many consumer threads there are.
  EXPECT_EQ(consumer_count, total_items);
}

// try_push / try_pop must block for roughly the requested timeout.
TEST(QueueTest, Timeout) {
  EXPECT_NEAR(PushTimeout(1), 1.f, 0.5);
  EXPECT_NEAR(PushTimeout(10), 10.f, 2);
  EXPECT_NEAR(PushTimeout(1000), 1000.f, 20);

  EXPECT_NEAR(PopTimeout(1), 1.f, 0.5);
  EXPECT_NEAR(PopTimeout(10), 10.f, 2);
  EXPECT_NEAR(PopTimeout(1000), 1000.f, 20);
}

}  // namespace
}  // namespace concurrency
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/concurrency/random_number_generator_benchmark.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "absl/random/random.h" #include "benchmark/benchmark.h" #include "monolith/native_training/runtime/concurrency/xorshift.h" namespace monolith { namespace concurrency { namespace { const int NUM = 1000000; void BM_STL(benchmark::State& state) { // NOLINT std::random_device random_device; std::mt19937 engine{random_device()}; std::uniform_int_distribution dist( 0, std::numeric_limits::max()); for (auto _ : state) { for (int i = 0; i < NUM; ++i) { dist(engine); } } } void BM_Absl(benchmark::State& state) { // NOLINT absl::BitGen bit_gen; for (auto _ : state) { for (int i = 0; i < NUM; ++i) { absl::Uniform(bit_gen, 0u, std::numeric_limits::max()); } } } void BM_XorShift(benchmark::State& state) { // NOLINT for (auto _ : state) { for (int i = 0; i < NUM; ++i) { XorShift::Rand32ThreadSafe(); } } } // Run on (96 X 3900 MHz CPU s) // CPU Caches: // L1 Data 32K (x48) // L1 Instruction 32K (x48) // L2 Unified 1024K (x48) // L3 Unified 36608K (x2) // Load Average: 10.64, 12.83, 14.44 // ------------------------------------------------------ // Benchmark Time CPU Iterations // ------------------------------------------------------ // BM_STL 5849400 ns 5849341 ns 117 // BM_Absl 5574646 ns 5574647 ns 126 // BM_XorShift 3250932 ns 3244318 ns 216 BENCHMARK(BM_STL); BENCHMARK(BM_Absl); BENCHMARK(BM_XorShift); } // namespace } // namespace concurrency } // namespace monolith BENCHMARK_MAIN(); ================================================ FILE: monolith/native_training/runtime/concurrency/sleeper.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_SLEEPER_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_SLEEPER_H_ namespace monolith { namespace concurrency { // detection for 64 bit #if defined(__x86_64__) || defined(_M_X64) #define FOLLY_X64 1 #else #define FOLLY_X64 0 #endif #if defined(__aarch64__) #define FOLLY_AARCH64 1 #else #define FOLLY_AARCH64 0 #endif inline void asm_volatile_pause() { #if defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64)) ::_mm_pause(); #elif defined(__i386__) || FOLLY_X64 asm volatile("pause"); #elif FOLLY_AARCH64 || defined(__arm__) asm volatile("yield"); #elif FOLLY_PPC64 asm volatile("or 27,27,27"); #endif } /* * A helper object for the contended case. Starts off with eager * spinning, and falls back to sleeping for small quantums. 
*/ class Sleeper { static const uint32_t kMaxActiveSpin = 4000; uint32_t spinCount; public: Sleeper() : spinCount(0) {} void Wait() { if (spinCount < kMaxActiveSpin) { ++spinCount; asm_volatile_pause(); } else { /* * Always sleep 0.5ms, assuming this will make the kernel put * us down for whatever its minimum timer resolution is (in * linux this varies by kernel version from 1ms to 10ms). */ struct timespec ts = {0, 500000}; nanosleep(&ts, nullptr); } } }; } // namespace concurrency } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_SLEEPER_H_ ================================================ FILE: monolith/native_training/runtime/concurrency/thread_pool.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "thread_pool.h" namespace monolith { namespace concurrency { ThreadPool::~ThreadPool() { { absl::MutexLock l(&mu_); for (size_t i = 0; i < threads_.size(); i++) { queue_.push(nullptr); // Shutdown signal. } } for (auto &t : threads_) { t.join(); } } void ThreadPool::Schedule(std::function func) { assert(func != nullptr); absl::MutexLock l(&mu_); queue_.push(std::move(func)); } void ThreadPool::WorkLoop() { while (true) { std::function func; { absl::MutexLock l(&mu_); mu_.Await(absl::Condition(this, &ThreadPool::WorkAvailable)); func = std::move(queue_.front()); queue_.pop(); } if (func == nullptr) { // Shutdown signal. break; } func(); } } } // namespace concurrency } // namespace monolith ================================================ FILE: monolith/native_training/runtime/concurrency/thread_pool.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_THREAD_POOL_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_THREAD_POOL_H_ #include #include #include #include #include #include #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" namespace monolith { namespace concurrency { /** * A simple ThreadPool implementation for tests/benchmarks. */ class ThreadPool { public: explicit ThreadPool(int num_threads) { for (int i = 0; i < num_threads; ++i) { threads_.emplace_back(&ThreadPool::WorkLoop, this); } } ThreadPool(const ThreadPool &) = delete; ThreadPool &operator=(const ThreadPool &) = delete; ~ThreadPool(); // Schedule a function to be run on a ThreadPool thread immediately. void Schedule(std::function func); private: bool WorkAvailable() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return !queue_.empty(); } void WorkLoop(); absl::Mutex mu_; std::queue> queue_ ABSL_GUARDED_BY(mu_); std::vector threads_; }; } // namespace concurrency } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_THREAD_POOL_H_ ================================================ FILE: monolith/native_training/runtime/concurrency/xorshift.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/concurrency/xorshift.h" namespace monolith { namespace concurrency { uint64_t XorShift::XorShift1024Star() { uint64_t s0 = s[p]; uint64_t s1 = s[p = (p + 1) & 15]; s1 ^= s1 << 31; // a s1 ^= s1 >> 11; // b s0 ^= s0 >> 30; // c return (s[p] = s0 ^ s1) * UINT64_C(1181783497276652981); } uint64_t XorShift::XorShift128Plus() { uint64_t x = s[0]; uint64_t const y = s[1]; s[0] = y; x ^= x << 23; // a x ^= x >> 17; // b x ^= y ^ (y >> 26); // c s[1] = x; return x + y; } uint64_t XorShift::XorShift64Star() { x ^= x >> 12; // a x ^= x << 25; // b x ^= x >> 27; // c return x * UINT64_C(2685821657736338717); } uint32_t XorShift::Rand32ThreadSafe() { static thread_local XorShift xor_shift; return xor_shift.Rand32(); } } // namespace concurrency } // namespace monolith ================================================ FILE: monolith/native_training/runtime/concurrency/xorshift.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_XORSHIFT_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_XORSHIFT_H_ #include #include #include #include #include namespace monolith { namespace concurrency { class XorShift { public: XorShift() : p(0) { srand(time(0)); x = (uint64_t)std::rand() * RAND_MAX + std::rand(); for (uint64_t& i : s) { i = XorShift64Star(); } } uint32_t Rand32() { return (uint32_t)XorShift128Plus(); } static uint32_t Rand32ThreadSafe(); private: uint64_t XorShift1024Star(); uint64_t XorShift128Plus(); uint64_t XorShift64Star(); private: uint64_t s[16]; int p; uint64_t x; /* The state must be seeded with a nonzero value. */ }; } // namespace concurrency } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_CONCURRENCY_XORSHIFT_H_ ================================================ FILE: monolith/native_training/runtime/concurrency/xorshift_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/concurrency/xorshift.h"

// NOTE(review): system header names and template arguments were stripped in
// extraction. int32_t for RADIUS is inferred from RADIUS * RADIUS having to
// fit in int64_t; the std::function signature is a guess that accepts both
// generator lambdas used below.
#include <cstdint>
#include <functional>
#include <limits>
#include <random>
#include <thread>
#include <vector>

#include "gtest/gtest.h"

namespace monolith {
namespace concurrency {
namespace {

const int64_t NUM = 1000000, RADIUS = std::numeric_limits<int32_t>::max() / 2;
const int64_t RADIUS_SQUARE = RADIUS * RADIUS;
const float PI = 3.14f, eps = 0.01f;

// Estimates pi by sampling NUM points in a RADIUS x RADIUS square and counting
// how many fall inside the quarter circle of that radius.
float EstimatingPiWithMonteCarlo(const std::function<int64_t()>& generator) {
  int count = 0;
  for (int i = 0; i < NUM; ++i) {
    auto x = generator() % RADIUS;
    auto y = generator() % RADIUS;
    if (x * x + y * y < RADIUS_SQUARE) {
      ++count;
    }
  }

  return static_cast<float>(count * 4) / NUM;
}

// A single XorShift instance should be uniform enough to estimate pi, with
// std::mt19937 as the reference.
TEST(XorShift, SingleThread) {
  std::random_device random_device;
  std::mt19937 engine{random_device()};
  std::uniform_int_distribution<> dist(0, RADIUS);
  float pi1 = EstimatingPiWithMonteCarlo([&]() { return dist(engine); });
  EXPECT_NEAR(pi1, PI, eps);

  XorShift generator;
  float pi2 = EstimatingPiWithMonteCarlo([&]() { return generator.Rand32(); });
  EXPECT_NEAR(pi2, PI, eps);
}

// The per-thread generator must give a usable estimate in every thread.
TEST(XorShift, MultiThread) {
  std::random_device random_device;
  std::mt19937 engine{random_device()};
  std::uniform_int_distribution<> dist(0, RADIUS);
  float pi = EstimatingPiWithMonteCarlo([&]() { return dist(engine); });
  EXPECT_NEAR(pi, PI, eps);

  int thread_num = 10;
  std::vector<std::thread> threads;
  for (int i = 0; i < thread_num; ++i) {
    threads.emplace_back([&]() {
      float thread_pi = EstimatingPiWithMonteCarlo(
          [&]() { return XorShift::Rand32ThreadSafe(); });
      EXPECT_NEAR(thread_pi, PI, eps);
    });
  }

  for (auto& t : threads) {
    t.join();
  }
}

}  // namespace
}  // namespace concurrency
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/deep_insight/BUILD
================================================
load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_custom_op_library")

package(default_visibility = ["//monolith/native_training:__subpackages__"])

cc_library(
    name = "deep_insight_internal_deps",
)
cc_library( name = "deep_insight", srcs = ["deep_insight.cc"], hdrs = ["deep_insight.h"], deps = [ ":deep_insight_internal_deps", "//monolith/native_training/runtime/common:metrics", "//third_party/nlohmann:json", "@com_google_glog//:glog", ], ) cc_test( name = "deep_insight_test", srcs = ["deep_insight_test.cc"], deps = [ ":deep_insight", "//monolith/native_training/runtime/common:metrics", "@com_google_glog//:glog", "@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/deep_insight/deep_insight.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "glog/logging.h" #include "monolith/native_training/runtime/common/metrics.h" namespace monolith { namespace deep_insight { } // namespace deep_insight } // namespace monolith ================================================ FILE: monolith/native_training/runtime/deep_insight/deep_insight.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_DEEP_INSIGHT #define MONOLITH_NATIVE_TRAINING_RUNTIME_DEEP_INSIGHT #include #include #include #include "gflags/gflags.h" #include "glog/logging.h" #include "third_party/nlohmann/json.hpp" namespace monolith { namespace deep_insight { class ExtraField { public: explicit ExtraField(const std::string& k) : key_(k) {} virtual void add_to(nlohmann::json* j) = 0; const std::string& key() { return key_; } private: std::string key_; }; class FloatExtraField : public ExtraField { public: explicit FloatExtraField(const std::string& k, const float& v) : ExtraField(k), value_(v) {} void add_to(nlohmann::json* j) { (*j)["extra_float"][key()] = value_; } private: float value_; }; class Int64ExtraField : public ExtraField { public: explicit Int64ExtraField(const std::string& k, const int64_t& v) : ExtraField(k), value_(v) {} void add_to(nlohmann::json* j) { (*j)["extra_int"][key()] = value_; } private: int64_t value_; }; class StringExtraField : public ExtraField { public: explicit StringExtraField(const std::string& k, const std::string& v) : ExtraField(k), value_(v) {} void add_to(nlohmann::json* j) { (*j)["extra_str"][key()] = value_; } private: std::string value_; }; class DeepInsight { public: template explicit DeepInsight(Args...) {} template std::string SendV2(Args...) { return ""; } template bool HitSampleRatio(Args...) 
{ return false; } int64_t GenerateTrainingTime() { return 0; } uint64_t GetTotalSendCounter() { return 0; } }; } // namespace deep_insight } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_DEEP_INSIGHT ================================================ FILE: monolith/native_training/runtime/deep_insight/deep_insight_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/deep_insight/deep_insight.h" #include #include #include "glog/logging.h" #include "gtest/gtest.h" using monolith::deep_insight::ExtraField; using monolith::deep_insight::FloatExtraField; using monolith::deep_insight::Int64ExtraField; using monolith::deep_insight::StringExtraField; using json = nlohmann::json; namespace monolith { namespace deep_insight { } // namespace deep_insight } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_filter/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") package(default_visibility = ["//monolith/native_training/runtime:__subpackages__"]) cc_library( name = "types", hdrs = ["types.h"], ) cc_library( name = "filter", hdrs = ["filter.h"], deps = [ ":types", "//monolith/native_training/runtime/hash_table:embedding_hash_table_cc_proto", 
"//monolith/native_training/runtime/hash_table:embedding_hash_table_factory", ], ) cc_library( name = "hash_filter", hdrs = ["hash_filter.h"], srcs = ["hash_filter.cc"], deps = [ ":filter", ":types", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/hash", "@com_google_absl//absl/types:span", "@com_google_glog//:glog", ], ) cc_test( name = "hash_filter_test", srcs = ["hash_filter_test.cc"], deps = [ ":hash_filter", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "sliding_hash_filter", srcs = ["sliding_hash_filter.cc"], hdrs = ["sliding_hash_filter.h"], deps = [ ":hash_filter", "//monolith/native_training/runtime/hash_table:embedding_hash_table_cc_proto", "@com_google_absl//absl/strings:str_format", ], ) cc_library( name = "dummy_hash_filter", hdrs = ["dummy_hash_filter.h"], deps = [ ":hash_filter", ], ) cc_test( name = "sliding_hash_filter_test", srcs = ["sliding_hash_filter_test.cc"], deps = [ ":sliding_hash_filter", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "probabilistic_filter", hdrs = ["probabilistic_filter.h"], srcs = ["probabilistic_filter.cc"], deps = [ ":filter", ":hash_filter", "//monolith/native_training/runtime/concurrency:xorshift", ], ) cc_binary( name = "probabilistic_filter_test", srcs = ["probabilistic_filter_test.cc"], deps = [ ":probabilistic_filter", "@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/hash_filter/dummy_hash_filter.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /** * Implement a dummy hash filter which has no real hash filter logic inside. * It will always return HashFilter::max_count() so that all FIDs can * pass the hash filter check. **/ #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_DUMMY_HASH_FILTER_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_DUMMY_HASH_FILTER_H_ #include #include "monolith/native_training/runtime/hash_filter/filter.h" #include "monolith/native_training/runtime/hash_filter/hash_filter.h" namespace monolith { namespace hash_filter { class DummyHashFilter : public Filter { public: DummyHashFilter() = default; DummyHashFilter(const DummyHashFilter& other) {} uint32_t add(FID fid, uint32_t count) { return HashFilter::max_count(); } uint32_t get(FID fid) const { return HashFilter::max_count(); } uint32_t size_mb() const { return 0; } size_t estimated_total_element() const { return 0; } size_t failure_count() const { return 0; } size_t split_num() const { return 0; } bool ShouldBeFiltered( int64_t fid, int64_t count, int64_t slot_occurrence_threshold, const monolith::hash_table::EmbeddingHashTableInterface* table) override { return false; } DummyHashFilter& operator=(DummyHashFilter const&) = delete; DummyHashFilter* clone() const { return new DummyHashFilter(*this); } }; } // namespace hash_filter } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_DUMMY_HASH_FILTER_H_ ================================================ FILE: monolith/native_training/runtime/hash_filter/filter.h ================================================ // Copyright 2022 ByteDance 
and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_FILTER_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_FILTER_H_ #include #include "monolith/native_training/runtime/hash_filter/types.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_interface.h" namespace monolith { namespace hash_filter { class Filter { public: Filter() : capacity_(0) {} virtual uint32_t add(FID fid, uint32_t count) = 0; virtual uint32_t get(FID fid) const = 0; virtual uint32_t size_mb() const = 0; virtual size_t estimated_total_element() const = 0; virtual size_t failure_count() const = 0; virtual size_t capacity() const { return capacity_; } virtual size_t split_num() const = 0; virtual bool exceed_limit() const { return false; } virtual void set_name(const std::string& name) { name_ = name; } virtual Filter* clone() const = 0; virtual bool ShouldBeFiltered( int64_t fid, int64_t count, int64_t slot_occurrence_threshold, const monolith::hash_table::EmbeddingHashTableInterface* table) = 0; virtual void Save( int split_idx, std::function write_meta_fn, std::function write_data_fn) const {} virtual void Restore( int split_idx, std::function get_meta_fn, std::function get_data_fn) {} virtual ~Filter() {} constexpr static unsigned char count_bit = 4; constexpr static uint32_t max_count() { return (1 << count_bit) - 
1; } protected: size_t capacity_; std::string name_; }; } // namespace hash_filter } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_FILTER_H_ ================================================ FILE: monolith/native_training/runtime/hash_filter/hash_filter.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_filter/hash_filter.h" namespace monolith { namespace hash_filter { namespace proto2 = google::protobuf; using ::monolith::hash_table::SlidingHashFilterMetaDump; using ::monolith::hash_table::HashFilterSplitMetaDump; using ::monolith::hash_table::HashFilterSplitDataDump; const int kMaxNumPerTfRecord = 10000; template <> void HashFilter::Save( const SlidingHashFilterMetaDump& sliding_hash_filter_meta_dump, std::function write_meta_fn, std::function write_data_fn) const { // Write meta part with one tf-record with HashFilterSplitMetaDump type. HashFilterSplitMetaDump meta_dump; meta_dump.set_failure_count(failure_count_); meta_dump.set_total_size(total_size_); meta_dump.set_num_elements(num_elements_); meta_dump.set_fill_rate(fill_rate_); *(meta_dump.mutable_sliding_hash_filter_meta()) = sliding_hash_filter_meta_dump; write_meta_fn(std::move(meta_dump)); // Write data part with multiple tf-records of HashFilterSplitDataDump type. 
int tf_record_num = (map_.size() + kMaxNumPerTfRecord - 1) / kMaxNumPerTfRecord; for (int record_idx = 0; record_idx < tf_record_num; record_idx++) { int start = record_idx * kMaxNumPerTfRecord; int end = std::min(start + kMaxNumPerTfRecord, static_cast(map_.size())); HashFilterSplitDataDump data_dump; data_dump.set_offset(start); for (int i = start; i < end; i++) { data_dump.add_data(map_[i]); } write_data_fn(std::move(data_dump)); } } template <> void HashFilter::Restore( HashFilterSplitMetaDump meta_dump, std::function get_data_fn) { // Restore hash filter meta. failure_count_ = meta_dump.failure_count(); total_size_ = meta_dump.total_size(); num_elements_ = meta_dump.num_elements(); fill_rate_ = meta_dump.fill_rate(); // Restore hash filter data. map_.resize(total_size_ + MAX_STEP, 0); HashFilterSplitDataDump data_dump; while (get_data_fn(&data_dump)) { int offset = data_dump.offset(); for (int i = 0; i < data_dump.data_size(); i++) { map_[offset + i] = (uint16_t)(data_dump.data(i)); } } } } // namespace hash_filter } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_filter/hash_filter.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_HASH_FILTER_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_HASH_FILTER_H_ #include #include #include #include "absl/algorithm/container.h" #include "absl/hash/hash.h" #include "absl/types/span.h" #include "monolith/native_training/runtime/hash_filter/filter.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h" namespace monolith { namespace hash_filter { template class HashFilter; template class HashFilterIterator { friend class HashFilter; public: HashFilterIterator() : filter_(NULL), pvalue_(NULL), sign_(0) {} uint32_t add(uint32_t add_count) { assert(valid() && "check validation before add"); if (add_count > HashFilter::max_count()) add_count = HashFilter::max_count(); if (*pvalue_ == 0) { ++filter_->num_elements_; if (filter_->num_elements_ > filter_->capacity_) { filter_->num_elements_ = filter_->capacity_; /* InvalidOperation io; io.why = "total elements exceeds filter capacity"; throw io; */ } *pvalue_ = (sign_ << HashFilter::count_bit) + add_count; return 0; } unsigned char count = *pvalue_ & HashFilter::max_count(); if (count + add_count >= HashFilter::max_count()) *pvalue_ |= HashFilter::max_count(); else *pvalue_ += add_count; return count; } uint32_t get() const { if (!pvalue_) return HashFilter::max_count(); return *pvalue_ & HashFilter::max_count(); } bool valid() const { return pvalue_ != NULL; } bool empty() const { assert(valid() && "check validation before empty"); return *pvalue_ == 0; } private: explicit HashFilterIterator(HashFilter* filter, DATA* pvalue, DATA sign) : filter_(filter), pvalue_(pvalue), sign_(sign) {} HashFilter* filter_; DATA* pvalue_; DATA sign_; }; template class HashFilter : public Filter { friend class HashFilterIterator; public: explicit HashFilter(size_t capacity, double fill_rate = 1.5) : failure_count_(0), total_size_(capacity * fill_rate), num_elements_(0), fill_rate_(fill_rate) { capacity_ = capacity; map_.resize(total_size_ + 
MAX_STEP, 0); } uint32_t add(FID fid, uint32_t count) override { HashFilterIterator iter = find(fid, MAX_STEP); if (iter.valid()) { return iter.add(count); } failure_count_ += 1; return max_count(); } uint32_t get(FID fid) const override { return const_cast(this)->find(fid, MAX_STEP).get(); } HashFilterIterator find(FID fid, int max_step) { assert(max_step <= MAX_STEP && "illegal max_step"); DATA sign = signature(fid); int step = 0; size_t hash_value = hash(fid) % total_size_; DATA* pvalue = reinterpret_cast(&map_[hash_value]); do { if (*pvalue == 0 || (*pvalue >> count_bit) == sign) { return HashFilterIterator(this, pvalue, sign); } ++pvalue; if (pvalue == &(*map_.end())) { pvalue = &map_[0]; } } while (++step < max_step); return HashFilterIterator(this, NULL, sign); } bool full() const { return num_elements_ >= capacity_ - 1; } // TODO make this async void async_clear() { fill(map_.begin(), map_.end(), 0); num_elements_ = 0; failure_count_ = 0; } uint32_t size_mb() const override { return map_.size() * sizeof(DATA) / 1024.0 / 1024.0; } static size_t size_byte(size_t capacity, double fill_rate = 1.5) { return capacity * sizeof(DATA) * fill_rate + MAX_STEP; } size_t failure_count() const override { return failure_count_; } size_t split_num() const override { return 0; } DATA signature(FID fid) const { return (fid >> 17 | fid << 15) & sign_mask; } size_t estimated_total_element() const override { return num_elements_; } bool exceed_limit() const override { return num_elements_ >= capacity_; } HashFilter* clone() const override { return new HashFilter(*this); } bool ShouldBeFiltered( int64_t fid, int64_t count, int64_t slot_occurrence_threshold, const monolith::hash_table::EmbeddingHashTableInterface* table) override { if (slot_occurrence_threshold <= 0) { return false; } return add(fid, count) < slot_occurrence_threshold; } bool operator==(const HashFilter& other) const { return total_size_ == other.total_size_ && num_elements_ == other.num_elements_ && capacity_ == 
other.capacity_ && fill_rate_ == other.fill_rate_ && map_ == other.map_; } void Save(const ::monolith::hash_table::SlidingHashFilterMetaDump& sliding_hash_filter_meta_dump, std::function write_meta_fn, std::function write_data_fn) const; void Restore( ::monolith::hash_table::HashFilterSplitMetaDump dump, std::function get_data_fn); constexpr static DATA sign_mask = ((1 << (sizeof(DATA) * 8)) - 1) >> count_bit; private: constexpr static int DUMP_VALUE_SIZE = 1024 * 1024 * 20; // 10-20MB constexpr static int MAX_STEP = 64; size_t hash(FID fid) const { return absl::Hash()(fid); } std::vector map_; uint64_t failure_count_; uint64_t total_size_; uint64_t num_elements_; double fill_rate_; }; } // namespace hash_filter } // namespace monolith /* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */ #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_HASH_FILTER_H_ ================================================ FILE: monolith/native_training/runtime/hash_filter/hash_filter_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// NOTE(review): in this dump every `<...>` span (system include targets and
// template argument lists) is missing from the code below. The tokens are
// kept exactly as they appear; restore them from the repository before
// building.
#include
#include
#include

#include "monolith/native_training/runtime/hash_filter/hash_filter.h"

namespace monolith {
namespace hash_filter {

// Adds the same key repeatedly and checks that add() returns the previous
// count while get() returns the new count, saturating at max_count().
template
void test_simple_bloom(size_t key_num) {
  HashFilter counter(key_num);
  for (uint32_t i = 0; i <= HashFilter::max_count(); ++i) {
    uint32_t expect = std::min(i + 1, HashFilter::max_count());
    EXPECT_EQ(i, counter.add(1, 1));
    EXPECT_EQ(expect, counter.get(1));
  }
  // Once saturated, both calls keep reporting max_count().
  ASSERT_EQ(HashFilter::max_count(), counter.add(1, 1));
  ASSERT_EQ(HashFilter::max_count(), counter.get(1));
}

TEST(HashFilterTest, test_simple) {
  test_simple_bloom(1);
  test_simple_bloom(3);
  test_simple_bloom(100);
  test_simple_bloom(1);
  test_simple_bloom(3);
  test_simple_bloom(100);
}

// Checks estimated_total_element() after inserting random keys, after
// re-adding a single key, and that a clone compares equal to its source.
template
void test_count() {
  HashFilter filter(1000000);
  // Seeded from the clock, so the inserted keys differ between runs.
  std::srand(std::time(NULL));
  filter.add(std::rand(), 2);
  EXPECT_EQ(1llu, filter.estimated_total_element());
  int key_number = 10;
  for (int i = 0; i != key_number; ++i) {
    filter.add(std::rand(), 2);
  }
  EXPECT_EQ(size_t(key_number + 1), filter.estimated_total_element());
  HashFilter filter2(1000000);
  // Repeated adds of one key must count as a single element.
  filter2.add(10000002961562801052lu, 1);
  filter2.add(10000002961562801052lu, 20);
  filter2.add(10000002961562801052lu, 1);
  EXPECT_EQ(1llu, filter2.estimated_total_element());
  std::unique_ptr> filter3(filter2.clone());
  EXPECT_TRUE(filter2 == *filter3);
}

TEST(HashFilterTest, test_count) {
  test_count();
  test_count();
}

// Feeds the same random keys to the filter and to an exact unordered_map and
// checks that the fraction of keys whose counts disagree is within a factor
// of two of `expected_rate`, with no insertion failures.
template
void compare_to_unordered_map(int key_number, double expected_rate,
                              double fill_rate) {
  HashFilter filter(key_number, fill_rate);
  std::unordered_map map_counter;
  std::srand(std::time(NULL));
  for (int i = 0; i != key_number; ++i) {
    int num = std::rand();
    // Stop adding before the exact counter would pass the filter's maximum.
    if (map_counter[num] < HashFilter::max_count() - 1) {
      map_counter[num] += 2;
      filter.add(num, 2);
    }
  }
  ASSERT_LE(filter.estimated_total_element(), map_counter.size());
  int error_counter = 0;
  for (auto iter = map_counter.begin(); iter != map_counter.end(); ++iter) {
    if (filter.get(iter->first) != iter->second) error_counter += 1;
  }
  double error_rate = error_counter / double(map_counter.size());
  double diff = error_rate / expected_rate;
  std::cout << "conflict rate diff " << diff << std::endl;
  EXPECT_GT(diff, 0.5);
  EXPECT_LT(diff, 2);
  EXPECT_EQ(0llu, filter.failure_count());
}

TEST(HashFilterTest, compare_to_unordered_map) {
  compare_to_unordered_map(1000000, 0.0208, 4);
  compare_to_unordered_map(1000000, 0.000242, 2);
}

// Checks that a zero occurrence threshold bypasses the filter entirely: such
// FIDs are neither filtered nor recorded.
template
void TestSkipZeroThresholdFeatures() {
  HashFilter filter(1000000, 10);
  for (uint32_t i = 0; i < 5; ++i) {
    int64_t fid_with_zero_threshold = i;
    EXPECT_FALSE(filter.ShouldBeFiltered(fid_with_zero_threshold, 1, 0,
                                         nullptr /* table */));
    int64_t normal_fid = i * 2;
    EXPECT_TRUE(filter.ShouldBeFiltered(normal_fid, 1, 1, nullptr /* table */));
  }
  // Only the normal_fids can be added to the filter.
  EXPECT_EQ(5llu, filter.estimated_total_element());
}

TEST(HashFilterTest, SkipZeroThresholdFeatures) {
  TestSkipZeroThresholdFeatures();
  TestSkipZeroThresholdFeatures();
}

/* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */
}  // namespace hash_filter
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_filter/probabilistic_filter.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/hash_filter/probabilistic_filter.h"

#include <cmath>
#include <cstdint>
#include <limits>

namespace monolith {
namespace hash_filter {

using ::monolith::concurrency::XorShift;

// NOTE(review): the template arguments of std::numeric_limits and static_cast
// were lost in this dump. uint32_t / float are assumed because the limit is
// compared against XorShift::Rand32ThreadSafe() and `p` is a float -- confirm
// against the repository.

// Returns true with probability count / slot_occurrence_threshold: a random
// draw r is accepted when r * threshold < max * count.
bool ProbabilisticFilter::InsertedIntoHashTableUnequalProbability(
    int64_t count, int64_t slot_occurrence_threshold) {
  return XorShift::Rand32ThreadSafe() * slot_occurrence_threshold <
         std::numeric_limits<uint32_t>::max() * count;
}

// Returns true with probability 1 - (1 - p)^count, where the per-occurrence
// probability p is chosen so that (1 - p)^slot_occurrence_threshold equals
// epsilon (0.05).
bool ProbabilisticFilter::InsertedIntoHashTableEqualProbability(
    int64_t count, int64_t slot_occurrence_threshold) {
  float epsilon = 0.05;
  float p = 1 - std::pow(epsilon,
                         1.f / static_cast<float>(slot_occurrence_threshold));
  return XorShift::Rand32ThreadSafe() <
         std::numeric_limits<uint32_t>::max() *
             (1.f - std::pow(1.f - p, count));
}

// Decides whether `fid` must be kept out of `table`.
// Returns false (not filtered) when the threshold is non-positive, when no
// table is given, or when the table already contains `fid`; otherwise the FID
// is filtered unless the configured random admission test succeeds.
bool ProbabilisticFilter::ShouldBeFiltered(
    int64_t fid, int64_t count, int64_t slot_occurrence_threshold,
    const monolith::hash_table::EmbeddingHashTableInterface* table) {
  // Same convention as HashFilter / SlidingHashFilter: a non-positive
  // threshold disables filtering. Without this guard a negative threshold
  // makes the equal-probability test reject every FID.
  if (slot_occurrence_threshold <= 0) {
    return false;
  }
  if (table && !table->Contains(fid)) {
    if (equal_probability_) {
      return !InsertedIntoHashTableEqualProbability(count,
                                                    slot_occurrence_threshold);
    } else {
      return !InsertedIntoHashTableUnequalProbability(
          count, slot_occurrence_threshold);
    }
  }
  return false;
}

}  // namespace hash_filter
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_filter/probabilistic_filter.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_PROBABILISTIC_FILTER_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_PROBABILISTIC_FILTER_H_ #include "monolith/native_training/runtime/concurrency/xorshift.h" #include "monolith/native_training/runtime/hash_filter/filter.h" #include "monolith/native_training/runtime/hash_filter/hash_filter.h" namespace monolith { namespace hash_filter { class ProbabilisticFilter : public Filter { public: explicit ProbabilisticFilter(bool equal_probability = false) : equal_probability_(equal_probability) {} uint32_t add(FID fid, uint32_t count) override { return HashFilter::max_count(); } uint32_t get(FID fid) const override { return HashFilter::max_count(); } uint32_t size_mb() const override { return 0; } size_t estimated_total_element() const override { return 0; } size_t failure_count() const override { return 0; } size_t split_num() const override { return 0; } Filter* clone() const override { return new ProbabilisticFilter(*this); } bool InsertedIntoHashTableUnequalProbability( int64_t count, int64_t slot_occurrence_threshold); bool InsertedIntoHashTableEqualProbability(int64_t count, int64_t slot_occurrence_threshold); bool ShouldBeFiltered( int64_t fid, int64_t count, int64_t slot_occurrence_threshold, const monolith::hash_table::EmbeddingHashTableInterface* table) override; private: bool equal_probability_; }; } // namespace hash_filter } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_PROBABILISTIC_FILTER_H_ ================================================ FILE: monolith/native_training/runtime/hash_filter/probabilistic_filter_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): template argument lists (`<...>`) are missing from the code
// below in this dump, e.g. after static_cast and std::unique_ptr. The tokens
// are kept exactly as they appear; restore them from the repository before
// building.
#include "monolith/native_training/runtime/hash_filter/probabilistic_filter.h"

#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/embedding_hash_table_factory.h"

namespace monolith {
namespace hash_filter {
namespace {

using ::monolith::hash_table::EmbeddingHashTableConfig;
using ::monolith::hash_table::EmbeddingHashTableInterface;
namespace proto2 = google::protobuf;

// The admission rate should be close to count / threshold.
TEST(ProbabilisticFilterTest, UnequalProbability) {
  ProbabilisticFilter filter;
  int fid_num = 10000, slot_occurrence_threshold = 7;
  int count1 = 0, count2 = 0;
  float tolerance_ratio = 0.01f;
  for (int i = 0; i < fid_num; ++i) {
    if (filter.InsertedIntoHashTableUnequalProbability(
            1, slot_occurrence_threshold)) {
      ++count1;
    }
    if (filter.InsertedIntoHashTableUnequalProbability(
            2, slot_occurrence_threshold)) {
      ++count2;
    }
  }
  EXPECT_NEAR(fid_num / slot_occurrence_threshold, count1,
              fid_num * tolerance_ratio);
  EXPECT_NEAR(fid_num / slot_occurrence_threshold * 2, count2,
              fid_num * tolerance_ratio * 2);
}

// The admission rate should be close to 1 - (1 - p)^count, with p derived
// from the threshold the same way the filter derives it.
TEST(ProbabilisticFilterTest, EqualProbability) {
  ProbabilisticFilter filter;
  int fid_num = 10000, slot_occurrence_threshold = 7;
  int count1 = 0, count2 = 0;
  // Used both as the epsilon in the formula for p and as the tolerance.
  float tolerance_ratio = 0.05;
  float p = 1 - std::pow(tolerance_ratio,
                         1.f / static_cast(slot_occurrence_threshold));
  for (int i = 0; i < fid_num; ++i) {
    if (filter.InsertedIntoHashTableEqualProbability(
            1, slot_occurrence_threshold)) {
      ++count1;
    }
    if (filter.InsertedIntoHashTableEqualProbability(
            2, slot_occurrence_threshold)) {
      ++count2;
    }
  }
  EXPECT_NEAR(fid_num * p, count1, fid_num * tolerance_ratio);
  EXPECT_NEAR(fid_num * (1.f - std::pow(1 - p, 2)), count2,
              fid_num * tolerance_ratio * 2);
}

// With an empty table every FID goes through the random admission test; with
// a null table nothing is filtered.
TEST(ProbabilisticFilterTest, ShouldBeFiltered) {
  EmbeddingHashTableConfig config;
  EXPECT_TRUE(proto2::TextFormat::ParseFromString(R"( entry_config { segments { dim_size: 1 init_config { zeros {} } opt_config { sgd {} } } } cuckoo {} )", &config));
  std::unique_ptr table = NewEmbeddingHashTableFromConfig(config);
  ProbabilisticFilter filter(false);
  int fid_num = 10000, slot_occurrence_threshold = 7;
  int count1 = 0, count2 = 0, count3 = 0;
  float tolerance_ratio = 0.01f;
  for (int i = 0; i < fid_num; ++i) {
    if (!filter.ShouldBeFiltered(i, 1, slot_occurrence_threshold,
                                 table.get())) {
      ++count1;
    }
    if (!filter.ShouldBeFiltered(i, 2, slot_occurrence_threshold,
                                 table.get())) {
      ++count2;
    }
    if (!filter.ShouldBeFiltered(i, 3, slot_occurrence_threshold, nullptr)) {
      ++count3;
    }
  }
  EXPECT_NEAR(fid_num / slot_occurrence_threshold, count1,
              fid_num * tolerance_ratio);
  // NOTE(review): the UnequalProbability test above allows twice this
  // tolerance for count2; the tighter bound here may make this check flaky.
  EXPECT_NEAR(fid_num / slot_occurrence_threshold * 2, count2,
              fid_num * tolerance_ratio);
  EXPECT_EQ(fid_num, count3);
}

}  // namespace
}  // namespace hash_filter
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_filter/sliding_hash_filter.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "absl/strings/str_format.h" #include "glog/logging.h" #include "monolith/native_training/runtime/hash_filter/sliding_hash_filter.h" namespace monolith { namespace hash_filter { using ::monolith::hash_table::HashFilterSplitDataDump; using ::monolith::hash_table::HashFilterSplitMetaDump; using ::monolith::hash_table::SlidingHashFilterMetaDump; SlidingHashFilter::SlidingHashFilter(size_t capacity, int split_num) : split_num_(split_num), head_(0), head_increment_(0), failure_count_(0) { capacity_ = capacity; if (capacity_ < 300) capacity_ = 300; if (split_num < 5) split_num = 5; filters_.resize(split_num); max_backward_step_ = split_num - max_forward_step_; // max_forward_step_ - 1 blocks are kept empty for looing forward size_t split_capacity = get_split_capacity(capacity_, split_num); for (auto& filter : filters_) { filter.reset(new HashFilter(split_capacity, 1.2)); } } SlidingHashFilter::SlidingHashFilter(const SlidingHashFilter& other) : max_backward_step_(other.max_backward_step_), filters_(other.filters_.size()), head_(other.head_), head_increment_(other.head_increment_), failure_count_(other.failure_count_) { capacity_ = other.capacity_; if (&other == this) { return; } for (size_t i = 0; i != filters_.size(); ++i) { filters_[i].reset(other.filters_[i]->clone()); } } uint32_t SlidingHashFilter::add(FID fid, uint32_t count) { uint32_t old_count = 0; // Look forward to find current value HashFilterIterator curr_iter = bidirectional_find( head_, max_forward_step_, fid, false, std::bind(&SlidingHashFilter::next, this, std::placeholders::_1)); if (curr_iter.valid()) { if (!curr_iter.empty()) { return curr_iter.add(count); } } else { failure_count_ += 1; return HashFilter::max_count(); } // Look backward to find old value HashFilterIterator old_iter = bidirectional_find( prev(head_), std::min(head_increment_, max_backward_step_), fid, true, std::bind(&SlidingHashFilter::prev, this, std::placeholders::_1)); if (old_iter.valid()) { old_count = 
old_iter.get(); curr_iter.add(old_count + count); } else { curr_iter.add(count); } if (filters_[head_]->full()) { head_ = next(head_); head_increment_ += 1; filters_[(head_ + max_forward_step_ - 1) % filters_.size()]->async_clear(); } return old_count; } uint32_t SlidingHashFilter::get(FID fid) const { // Look forward to find current value HashFilterIterator curr_iter = bidirectional_find( head_, max_forward_step_, fid, false, std::bind(&SlidingHashFilter::next, this, std::placeholders::_1)); if (curr_iter.valid()) { if (!curr_iter.empty()) { return curr_iter.get(); } } else { return HashFilter::max_count(); } // Look backward to find old value HashFilterIterator iter = bidirectional_find( prev(head_), std::min(head_increment_, max_backward_step_), fid, true, std::bind(&SlidingHashFilter::prev, this, std::placeholders::_1)); if (iter.valid()) { return iter.get(); } else { return 0; } } HashFilterIterator SlidingHashFilter::bidirectional_find( size_t begin, int max_look, FID fid, bool exhaust, std::function go) const { size_t index = begin; for (int i = 0; i != max_look; ++i) { HashFilterIterator iter = filters_[index]->find(fid, MAX_STEP); // Looking forward only needs a valid position // Looking backward needs a non-empty position if (iter.valid() && (!exhaust || (exhaust && !iter.empty()))) return iter; index = go(index); } return HashFilterIterator(); } size_t SlidingHashFilter::estimated_total_element() const { size_t result = 0; for (auto& filter : filters_) { result += filter->estimated_total_element(); } return result; } SlidingHashFilter* SlidingHashFilter::clone() const { return new SlidingHashFilter(*this); } bool SlidingHashFilter::operator==(const SlidingHashFilter& other) const { if (!(max_forward_step_ == other.max_forward_step_ && head_ == other.head_ && head_increment_ % filters_.size() == other.head_increment_ % filters_.size() && capacity_ == other.capacity_ && filters_.size() == other.filters_.size())) { return false; } for (size_t i = 0; i != 
filters_.size(); ++i) { if (!(*filters_[i] == *other.filters_[i])) { return false; } } return true; } SlidingHashFilterMetaDump SlidingHashFilter::GetMetaDump() const { SlidingHashFilterMetaDump dump; dump.set_split_num(split_num_); dump.set_max_forward_step(max_forward_step_); dump.set_max_backward_step(max_backward_step_); dump.set_max_step(MAX_STEP); dump.set_head(head_); dump.set_head_increment(head_increment_); dump.set_failure_count(failure_count_); return dump; } // Saves the data. void SlidingHashFilter::Save( int split_idx, std::function write_meta_fn, std::function write_data_fn) const { auto meta_dump = GetMetaDump(); filters_[split_idx]->Save(meta_dump, std::move(write_meta_fn), std::move(write_data_fn)); } void SlidingHashFilter::ValidateData(uint32_t expect_value, uint32_t ckpt_value, const char* msg) { if (ckpt_value != expect_value) { throw std::runtime_error( absl::StrFormat("%s: %d does't match with : %d read from hash " "filter checkpoint file.", msg, expect_value, ckpt_value)); } } void SlidingHashFilter::RestoreMetaDump(const HashFilterSplitMetaDump& dump) { auto& sliding_hash_filter_meta_dump = dump.sliding_hash_filter_meta(); ValidateData(split_num_, sliding_hash_filter_meta_dump.split_num(), "split_num"); ValidateData(max_forward_step_, sliding_hash_filter_meta_dump.max_forward_step(), "max_forward_step"); ValidateData(max_backward_step_, sliding_hash_filter_meta_dump.max_backward_step(), "max_backward_step"); ValidateData(MAX_STEP, sliding_hash_filter_meta_dump.max_step(), "max_step"); head_ = sliding_hash_filter_meta_dump.head(); head_increment_ = sliding_hash_filter_meta_dump.head_increment(); failure_count_ = sliding_hash_filter_meta_dump.failure_count(); } // Restores the data from get_fn. 
void SlidingHashFilter::Restore( int split_idx, std::function get_meta_fn, std::function get_data_fn) { HashFilterSplitMetaDump meta_dump; get_meta_fn(&meta_dump); RestoreMetaDump(meta_dump); filters_[split_idx]->Restore(meta_dump, std::move(get_data_fn)); } } // namespace hash_filter } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_filter/sliding_hash_filter.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_SLIDING_HASH_FILTER_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_SLIDING_HASH_FILTER_H_ #include #include "monolith/native_training/runtime/hash_filter/filter.h" #include "monolith/native_training/runtime/hash_filter/hash_filter.h" namespace monolith { namespace hash_filter { class SlidingHashFilter : public Filter { public: SlidingHashFilter(size_t capacity, int split_num); SlidingHashFilter(const SlidingHashFilter& other); uint32_t add(FID fid, uint32_t count); uint32_t get(FID fid) const; uint32_t size_mb() const { return filters_[0]->size_mb() * filters_.size(); } static size_t get_split_capacity(size_t capacity, int split_num) { return capacity / (split_num - max_forward_step_ + 1); } static uint32_t size_byte(size_t capacity, int split_num) { return HashFilter::size_byte( get_split_capacity(capacity, split_num), 1.2) * split_num; } size_t estimated_total_element() const; size_t failure_count() const { return failure_count_; } size_t split_num() const { return split_num_; } bool ShouldBeFiltered( int64_t fid, int64_t count, int64_t slot_occurrence_threshold, const monolith::hash_table::EmbeddingHashTableInterface* table) override { if (slot_occurrence_threshold <= 0) { return false; } return this->add(fid, count) < slot_occurrence_threshold; } SlidingHashFilter& operator=(SlidingHashFilter const&) = delete; SlidingHashFilter* clone() const; bool operator==(SlidingHashFilter const&) const; // Saves the data. virtual void Save( int split_idx, std::function write_meta_fn, std::function write_data_fn) const; // Restores the data from get_fn. 
virtual void Restore( int split_idx, std::function get_meta_fn, std::function get_data_fn); private: size_t prev(size_t index) const { if (index == 0) return filters_.size() - 1; return index - 1; } size_t next(size_t index) const { if (index == filters_.size() - 1) return 0; return index + 1; } HashFilterIterator bidirectional_find( size_t begin, int max_look, FID fid, bool exhaust, std::function go) const; ::monolith::hash_table::SlidingHashFilterMetaDump GetMetaDump() const; void RestoreMetaDump( const ::monolith::hash_table::HashFilterSplitMetaDump& dump); void ValidateData(uint32_t expect_value, uint32_t ckpt_value, const char* msg); size_t split_num_; constexpr static int max_forward_step_ = 2; int max_backward_step_; constexpr static int MAX_STEP = 16; std::vector>> filters_; size_t head_; int head_increment_; size_t failure_count_; }; } // namespace hash_filter } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_SLIDING_HASH_FILTER_H_ ================================================ FILE: monolith/native_training/runtime/hash_filter/sliding_hash_filter_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_filter/sliding_hash_filter.h" #include #include #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_filter/types.h" namespace monolith { namespace hash_filter { namespace { static void test_simple_bloom(size_t key_num) { SlidingHashFilter counter(key_num, 10); for (uint32_t i = 0; i <= HashFilter::max_count(); ++i) { EXPECT_EQ(i, counter.add(1, 1)); } ASSERT_EQ(HashFilter::max_count(), counter.add(1, 1)); } TEST(SlidingHashFilterTest, test_simple) { test_simple_bloom(1); test_simple_bloom(3); test_simple_bloom(100); } TEST(SlidingHashFilterTest, test_count) { SlidingHashFilter filter(1000000, 10); filter.add(std::rand(), 2); EXPECT_EQ(1llu, filter.estimated_total_element()); size_t key_number = 10; for (size_t i = 0; i != key_number; ++i) { filter.add(std::rand(), 2); } EXPECT_EQ(key_number + 1, filter.estimated_total_element()); SlidingHashFilter filter2(1000000, 10); filter2.add(10000002961562801052lu, 1); filter2.add(10000002961562801052lu, 20); filter2.add(10000002961562801052lu, 1); EXPECT_EQ(1llu, filter2.estimated_total_element()); std::unique_ptr filter3(filter2.clone()); EXPECT_TRUE(filter2 == *filter3); } template static void check_conflict_rate(SlidingHashFilter& filter, const Container& map_counter, double expected_rate) { int error_counter = 0; for (auto iter = map_counter.begin(); iter != map_counter.end(); ++iter) { if (filter.get(iter->first) != iter->second) error_counter += 1; } double error_rate = error_counter / double(map_counter.size()); std::cout << "expect:" << expected_rate << " actual:" << error_rate << std::endl; EXPECT_NEAR(expected_rate, error_rate, expected_rate / 2); EXPECT_GT(map_counter.size() / 10000, filter.failure_count()); } static void compare_to_unordered_map(int key_number, double expected_rate, size_t capacity) { std::srand(capacity); int split_num = 10; SlidingHashFilter filter(capacity, split_num); std::unordered_map map_counter; for (int i = 0; i != 
key_number; ++i) { int num = std::rand(); if (map_counter[num] < HashFilter::max_count() - 1) { map_counter[num] += 2; filter.add(num, 2); } } check_conflict_rate(filter, map_counter, expected_rate); } TEST(SlidingHashFilterTest, compare_to_unordered_map) { compare_to_unordered_map(1000000, 0.00908, 1000000); compare_to_unordered_map(1000000, 0.50, 500000); } TEST(SlidingHashFilterTest, SkipZeroThresholdFeatures) { SlidingHashFilter filter(1000000, 10); for (uint32_t i = 0; i < 5; ++i) { int64_t fid_with_zero_threshold = i; EXPECT_FALSE(filter.ShouldBeFiltered(fid_with_zero_threshold, 1, 0, nullptr /* table */)); int64_t normal_fid = i * 2; EXPECT_TRUE(filter.ShouldBeFiltered(normal_fid, 1, 1, nullptr /* table */)); } // Only the normal_fids can be added to the filter. EXPECT_EQ(5llu, filter.estimated_total_element()); } /* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */ } // namespace } // namespace hash_filter } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_filter/types.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_TYPES_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_TYPES_H_ #include namespace monolith { namespace hash_filter { using FID = uint64_t; } // namespace hash_filter } // namesapce monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_FILTER_TYPES_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") load( "@org_tensorflow//tensorflow:tensorflow.bzl", "if_cuda_is_configured_compat", "tf_gpu_kernel_library_allow_except", ) package(default_visibility = ["//monolith/native_training/runtime:__subpackages__"]) proto_library( name = "embedding_hash_table_proto", srcs = ["embedding_hash_table.proto"], deps = [ "//monolith/native_training/runtime/hash_table/compressor:float_compressor_proto", "//monolith/native_training/runtime/hash_table/initializer:initializer_config_proto", "//monolith/native_training/runtime/hash_table/optimizer:optimizer_proto", ], ) cc_proto_library( name = "embedding_hash_table_cc_proto", visibility = ["//visibility:public"], deps = [":embedding_hash_table_proto"], ) py_proto_library( name = "embedding_hash_table_py_proto", srcs = ["embedding_hash_table.proto"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ "//monolith/native_training/runtime/hash_table/compressor:float_compressor_py_proto", "//monolith/native_training/runtime/hash_table/initializer:initializer_config_py_proto", "//monolith/native_training/runtime/hash_table/optimizer:optimizer_py_proto", ], ) cc_library( name = "entry_accessor", srcs = ["entry_accessor.cc"], hdrs = [ "entry_accessor.h", "entry_accessor_decorator.h", "quantized_entry_accessor.h", ], deps = [ ":embedding_hash_table_cc_proto", ":utils", 
"//monolith/native_training/runtime/hash_table/compressor:float_compressor", "//monolith/native_training/runtime/hash_table/initializer:initializer_combination", "//monolith/native_training/runtime/hash_table/initializer:initializer_factory", "//monolith/native_training/runtime/hash_table/optimizer:optimizer_combination", "//monolith/native_training/runtime/hash_table/optimizer:optimizer_factory", "//monolith/native_training/runtime/hash_table/retriever:fake_quant_retriever", "//monolith/native_training/runtime/hash_table/retriever:hash_net_retriever", "//monolith/native_training/runtime/hash_table/retriever:raw_retriever", "//monolith/native_training/runtime/hash_table/retriever:retriever_combination", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_glog//:glog", ], ) cc_test( name = "entry_accessor_test", srcs = ["entry_accessor_test.cc"], deps = [ ":embedding_hash_table_cc_proto", ":entry_accessor", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "quantized_entry_accessor", hdrs = [ "entry_accessor_decorator.h", "quantized_entry_accessor.h", ], deps = [ ":entry_accessor", "//monolith/native_training/runtime/hash_table/compressor:fake_quantizer", ], ) cc_test( name = "quantized_entry_accessor_test", srcs = ["quantized_entry_accessor_test.cc"], deps = [ ":quantized_entry_accessor", "@com_google_googletest//:gtest_main", ], ) tf_gpu_kernel_library_allow_except( name = "embedding_hash_table_interface", srcs = [], hdrs = ["embedding_hash_table_interface.h"], deps = [ ":embedding_hash_table_cc_proto", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) cc_library( name = "embedding_hash_table_factory_internal_deps", ) tf_gpu_kernel_library_allow_except( name = "embedding_hash_table_factory", srcs = ["embedding_hash_table_factory.cc"], hdrs = ["embedding_hash_table_factory.h"], deps = [ ":embedding_hash_table_cc_proto", 
":embedding_hash_table_factory_internal_deps", ":embedding_hash_table_interface", ":entry_accessor", "//monolith/native_training/runtime/hash_table/cuckoohash:cuckoo_embedding_hash_table", "@com_google_absl//absl/strings:str_format", ], ) cc_library( name = "embedding_hash_table_test", hdrs = ["embedding_hash_table_test.h"], deps = [ ":embedding_hash_table_cc_proto", ":embedding_hash_table_factory", ":embedding_hash_table_interface", "@com_google_googletest//:gtest_main", ], ) cc_binary( name = "hash_table_benchmark", srcs = ["hash_table_benchmark.cc"], deps = [ ":embedding_hash_table_cc_proto", ":embedding_hash_table_factory", ":embedding_hash_table_interface", "//monolith/native_training/runtime/concurrency:thread_pool", "@com_github_google_benchmark//:benchmark", "@com_google_absl//absl/random", "@com_google_glog//:glog", ], ) cc_library( name = "utils", hdrs = ["utils.h"], ) cc_library( name = "entry_defs", hdrs = ["entry_defs.h"], deps = [ "//monolith/native_training/runtime/allocator:block_allocator", ], ) cc_test( name = "entry_defs_test", srcs = ["entry_defs_test.cc"], deps = [ ":entry_defs", "@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") load("@rules_proto//proto:defs.bzl", "proto_library") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") package(default_visibility = ["//monolith/native_training/runtime/hash_table:__subpackages__"]) proto_library( name = "float_compressor_proto", srcs = ["float_compressor.proto"], ) cc_proto_library( name = "float_compressor_cc_proto", deps = [ ":float_compressor_proto", ], ) py_proto_library( name = "float_compressor_py_proto", srcs = ["float_compressor.proto"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], ) cc_library( name = "float_compressor", srcs = 
["float_compressor.cc"], hdrs = ["float_compressor.h"], defines = ["HALF_ENABLE_F16C_INTRINSICS=0"], deps = [ ":float_compressor_cc_proto", ":fake_quantizer", ":hash_net_quantizer", "//third_party/half_sourceforge_net:half", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) cc_test( name = "float_compressor_test", srcs = ["float_compressor_test.cc"], deps = [ ":float_compressor", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "fake_quantizer", hdrs = ["fake_quantizer.h"], deps = [], ) cc_test( name = "fake_quantizer_test", srcs = ["fake_quantizer_test.cc"], deps = [ ":fake_quantizer", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "hash_net_quantizer", hdrs = ["hash_net_quantizer.h"], deps = [ ":float_compressor_cc_proto", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_glog//:glog", "@org_tensorflow//tensorflow/core/platform:logging", ], ) cc_test( name = "hash_net_quantizer_test", srcs = ["hash_net_quantizer_test.cc"], deps = [ ":hash_net_quantizer", "@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/fake_quantizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_FAKE_QUANTIZER_H_
#define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_FAKE_QUANTIZER_H_

#include <cmath>
#include <cstdint>

namespace monolith {
namespace hash_table {

// Quantization Aware Training.
// This class quantizes a float32 number into an int8_t.
//
// The step between two adjacent slots is r / 128, so the representable
// values are k * (r / 128) for k in [-128, 127].
// TODO(zhangbiao.david): support specifying min, max, num_bits etc.
class FakeQuantizer {
 public:
  // `r` is the quantization range; the slot width is derived from it.
  explicit FakeQuantizer(float r)
      : r_(r), step_(r_ / kNegativeSlotNum), half_step_(step_ / 2) {}

  // Quantizes a given floating-point number, i.e. returns the representable
  // slot value nearest to `f` (saturating at both ends of the range).
  float Quantize(float f) const { return IntegerToFloat(QuantizeToInteger(f)); }

  // Quantizes a floating-point number into its integer slot index.
  //
  // Returns 0 for NaN input and saturates to [-128, 127] for everything that
  // falls outside the range, including +/-infinity.
  int8_t QuantizeToInteger(float f) const {
    if (std::isnan(f)) {
      return 0;
    }
    // Round f to the nearest slot (half away from zero). E.g., assuming
    // step = 1.0 and f = 3.6, we want slot 4.
    if (f >= 0) {
      f += half_step_;
    } else {
      f -= half_step_;
    }
    const float slots = f / step_;
    // 0/0 or inf/inf (degenerate range) yields NaN; treat it like NaN input.
    if (std::isnan(slots)) {
      return 0;
    }
    // Saturate while still in floating point: converting a float whose
    // truncated value does not fit the integer type is undefined behaviour
    // (on x86 it produces INT_MIN, which flipped huge positive inputs to the
    // most negative slot).
    if (slots >= static_cast<float>(kPositiveSlotNum)) {
      return static_cast<int8_t>(kPositiveSlotNum);
    }
    if (slots <= static_cast<float>(-kNegativeSlotNum)) {
      return static_cast<int8_t>(-kNegativeSlotNum);
    }
    // Here -128 < slots < 127, so truncation toward zero is well defined.
    return static_cast<int8_t>(slots);
  }

  // Restores an integer representation to a floating-point number.
  float IntegerToFloat(int8_t x) const { return x * step_; }

 private:
  static constexpr int kNumBits = sizeof(int8_t) * 8;
  static constexpr int kSlotNum = 1 << kNumBits;
  static constexpr int kPositiveSlotNum = kSlotNum / 2 - 1;  // 127
  static constexpr int kNegativeSlotNum = kSlotNum / 2;      // 128

  const float r_;
  const float step_;       // Width of one slot: r / 128.
  const float half_step_;  // Added before truncation to round to nearest.
};

}  // namespace hash_table
}  // namespace monolith

#endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_FAKE_QUANTIZER_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/fake_quantizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "gtest/gtest.h" #include "gmock/gmock.h" #include "monolith/native_training/runtime/hash_table/compressor/fake_quantizer.h" namespace monolith { namespace hash_table { namespace { using ::testing::Lt; using ::testing::Le; using ::testing::FloatEq; using ::testing::Not; using ::testing::Eq; TEST(FakeQuantizer, Quantization) { FakeQuantizer model(5.0f); EXPECT_THAT(model.Quantize(100.0), Lt(5.0)); // Symmetric EXPECT_THAT(model.Quantize(0.0), 0.0f); const float kStep = 5.0f / 128; // Make sure quantization result is small enough. EXPECT_THAT(std::abs(model.Quantize(3.5) - 3.5), Lt(kStep)); // Make sure round works correctly. EXPECT_THAT(model.Quantize(kStep * 1.4), kStep); EXPECT_THAT(model.Quantize(kStep * 1.6), kStep * 2); EXPECT_THAT(model.Quantize(-kStep * 1.4), -kStep); EXPECT_THAT(model.Quantize(-kStep * 1.6), -kStep * 2); EXPECT_THAT(model.Quantize(std::nan("")), 0); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/float_compressor.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/compressor/float_compressor.h" #include #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_table/compressor/fake_quantizer.h" #include "monolith/native_training/runtime/hash_table/compressor/hash_net_quantizer.h" #include "third_party/half_sourceforge_net/half.hpp" namespace monolith { namespace hash_table { namespace { class FloatCompressorBase : public FloatCompressorInterface { public: FloatCompressorBase(int dim_size, int64_t size_bytes, int64_t uncompressed_size_bytes) : dim_size_(dim_size), size_bytes_(size_bytes), uncompressed_size_bytes_(uncompressed_size_bytes) {} // Use final to inline this when possible. 
int64_t SizeBytes() const final { return size_bytes_; } int64_t UncompressedSizeBytes() const final { return uncompressed_size_bytes_; } int DimSize() const final { return dim_size_; } private: int dim_size_; int64_t size_bytes_; int64_t uncompressed_size_bytes_; }; class Fp32FloatCompressor final : public FloatCompressorBase { public: explicit Fp32FloatCompressor(const FloatCompressorConfig::Fp32& config) : FloatCompressorBase(config.dim_size(), config.dim_size() * sizeof(float), config.dim_size() * sizeof(float)) {} std::string DebugString() const override { return absl::StrFormat("Fp32(D=%d)", FloatCompressorBase::DimSize()); } void Encode(absl::Span num, void* compressed) const override { auto* f = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { f[i] = num[i]; } } void Decode(const void* compressed, absl::Span num) const override { const auto* f = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { num[i] = f[i]; } } }; // Converts a float to fp16 class Fp16FloatCompressor final : public FloatCompressorBase { public: explicit Fp16FloatCompressor(const FloatCompressorConfig::Fp16& config) : FloatCompressorBase(config.dim_size(), config.dim_size() * sizeof(int16_t), config.dim_size() * sizeof(float)) {} std::string DebugString() const override { return absl::StrFormat("Fp16(D=%d)", FloatCompressorBase::DimSize()); } void Encode(absl::Span num, void* compressed) const override { auto* i16 = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { half_float::half x(num[i]); i16[i] = *reinterpret_cast(&x); } } void Decode(const void* compressed, absl::Span num) const override { const auto* i16 = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { num[i] = *reinterpret_cast(&i16[i]); } } }; // Converts a float to fixed range int8. 
class FixedR8FloatCompressor final : public FloatCompressorBase { public: explicit FixedR8FloatCompressor(const FloatCompressorConfig::FixedR8& config) : FloatCompressorBase(config.dim_size(), config.dim_size() * sizeof(int8_t), config.dim_size() * sizeof(float)), fake_quantizer_(config.r()) { LOG_EVERY_N(INFO, 100) << "FixedR8FloatCompressor config: " << config.DebugString(); } std::string DebugString() const override { return absl::StrFormat("FixedR8(D=%d)", FloatCompressorBase::DimSize()); } void Encode(absl::Span num, void* compressed) const override { auto* i8 = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { i8[i] = fake_quantizer_.QuantizeToInteger(num[i]); } } void Decode(const void* compressed, absl::Span num) const override { const auto* i8 = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { num[i] = fake_quantizer_.IntegerToFloat(i8[i]); } } private: FakeQuantizer fake_quantizer_; }; // Converts a float to one bit. // We use an int8_t for testing stage class OneBitFloatCompressor final : public FloatCompressorBase { public: explicit OneBitFloatCompressor(const FloatCompressorConfig::OneBit& config) : FloatCompressorBase(config.dim_size(), config.dim_size() * sizeof(int8_t), config.dim_size() * sizeof(float)), hash_net_quantizer_(config) {} std::string DebugString() const override { return absl::StrFormat("OneBit(D=%d)", FloatCompressorBase::DimSize()); } void Encode(absl::Span num, void* compressed) const override { auto* i8 = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { i8[i] = hash_net_quantizer_.Forward(num[i]) > 0 ? 
1 : -1; } } void Decode(const void* compressed, absl::Span num) const override { float amplitude = hash_net_quantizer_.GetConfig().amplitude(); const auto* i8 = reinterpret_cast(compressed); for (int i = 0; i < DimSize(); ++i) { num[i] = static_cast(i8[i]) * amplitude; } } private: HashNetQuantizer hash_net_quantizer_; }; class CombinedFloatCompressor final : public FloatCompressorBase { public: CombinedFloatCompressor(std::unique_ptr compressor1, std::unique_ptr compressor2) : FloatCompressorBase(compressor1->DimSize() + compressor2->DimSize(), compressor1->SizeBytes() + compressor2->SizeBytes(), compressor1->UncompressedSizeBytes() + compressor2->UncompressedSizeBytes()), compressor1_(std::move(compressor1)), compressor2_(std::move(compressor2)), compressor1_dim_size_(compressor1_->DimSize()), compressor1_size_bytes_(compressor1_->SizeBytes()) {} std::string DebugString() const override { return absl::StrFormat("%s|%s", compressor1_->DebugString(), compressor2_->DebugString()); } void Encode(absl::Span num, void* compressed) const override { absl::Span num1 = num.subspan(0, compressor1_dim_size_); compressor1_->Encode(num1, compressed); absl::Span num2 = num.subspan(compressor1_dim_size_); void* compressed2 = reinterpret_cast(compressed) + compressor1_size_bytes_; compressor2_->Encode(num2, compressed2); } void Decode(const void* compressed, absl::Span num) const override { absl::Span num1 = num.subspan(0, compressor1_dim_size_); compressor1_->Decode(compressed, num1); absl::Span num2 = num.subspan(compressor1_dim_size_); const void* compressed2 = reinterpret_cast(compressed) + compressor1_size_bytes_; compressor2_->Decode(compressed2, num2); } private: std::unique_ptr compressor1_; std::unique_ptr compressor2_; const int compressor1_dim_size_; const int64_t compressor1_size_bytes_; }; } // namespace std::unique_ptr NewFloatCompressor( FloatCompressorConfig config) { switch (config.type_case()) { case FloatCompressorConfig::kFp32: return std::make_unique( 
std::move(*config.mutable_fp32())); case FloatCompressorConfig::kFp16: return std::make_unique( std::move(*config.mutable_fp16())); case FloatCompressorConfig::kFixedR8: return std::make_unique( std::move(*config.mutable_fixed_r8())); case FloatCompressorConfig::kOneBit: return std::make_unique( std::move(*config.mutable_one_bit())); default: throw std::invalid_argument(absl::StrFormat( "Unknown tpye of float compressor. %s", config.ShortDebugString())); } } std::unique_ptr CombineFloatCompressor( std::unique_ptr compressor1, std::unique_ptr compressor2) { return std::make_unique(std::move(compressor1), std::move(compressor2)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/float_compressor.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_ENTRY_SERVING_COMPRESSOR #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_ENTRY_SERVING_COMPRESSOR #include "absl/types/span.h" #include "monolith/native_training/runtime/hash_table/compressor/float_compressor.pb.h" namespace monolith { namespace hash_table { // Used to compress float number in online serving PS to save the memory. 
class FloatCompressorInterface { public: virtual ~FloatCompressorInterface() = default; virtual std::string DebugString() const = 0; // How many bytes are required for the compressor. virtual int64_t SizeBytes() const = 0; // How many bytes are required if not compressed. virtual int64_t UncompressedSizeBytes() const = 0; // How many dimensions this compressor support. virtual int DimSize() const = 0; // Encodes a list of floats into compressed. virtual void Encode(absl::Span num, void* compressed) const = 0; // Decodes a list of Int into a list of float. virtual void Decode(const void* compressed, absl::Span num) const = 0; }; std::unique_ptr NewFloatCompressor( FloatCompressorConfig config); std::unique_ptr CombineFloatCompressor( std::unique_ptr compressor1, std::unique_ptr compressor2); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_ENTRY_SERVING_COMPRESSOR ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/float_compressor.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; package monolith.hash_table; message FloatCompressorConfig { // Not compressed. Useful in the test. message Fp32 { optional int32 dim_size = 1; } // Using half-precision floating-point format. 
message Fp16 { optional int32 dim_size = 1; } // Corresponding to qat8 in Bytedance PS. message FixedR8 { optional int32 dim_size = 1; optional float r = 2 [default = 1.0]; } // HashNet message OneBit { optional int32 dim_size = 1; optional int64 step_size = 2 [default = 200]; optional float init_scale = 3 [default = 1.0]; optional float max_scale = 4 [default = 10000.0]; optional float amplitude = 5 [default = 0.1]; } oneof type { Fp32 fp32 = 1; Fp16 fp16 = 2; FixedR8 fixed_r8 = 3; OneBit one_bit = 4; } } ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/float_compressor_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/compressor/float_compressor.h" #include #include "absl/types/span.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" namespace monolith { namespace hash_table { namespace { using ::testing::ElementsAre; using ::testing::FloatNear; using ::testing::Pointwise; std::vector EncodeDecode(const FloatCompressorInterface& compressor, absl::Span num) { auto compressed = std::make_unique(compressor.SizeBytes()); compressor.Encode(num, compressed.get()); std::vector decoded(compressor.DimSize()); compressor.Decode(compressed.get(), absl::MakeSpan(decoded)); return decoded; } std::vector EncodeDecode(const FloatCompressorConfig& config, absl::Span num) { auto compressor = NewFloatCompressor(config); return EncodeDecode(*compressor, num); } FloatCompressorConfig ParseConfig(const std::string& text) { FloatCompressorConfig c; GOOGLE_CHECK(google::protobuf::TextFormat::ParseFromString(text, &c)); return c; } TEST(Fp32FloatCompressorTest, Basic) { EXPECT_THAT( EncodeDecode(ParseConfig(R"(fp32 { dim_size: 3})"), {0.1, 0.2, 10000.0}), ElementsAre(0.1, 0.2, 10000.0)); } TEST(Fp16FloatCompressorTest, Basic) { EXPECT_THAT( EncodeDecode(ParseConfig(R"(fp16 { dim_size: 3})"), {0.1, 0.2, 10000.0}), Pointwise(FloatNear(1e-4), {0.1, 0.2, 10000.0})); } TEST(FixedR8FloatCompressorTest, Basic) { const float kStep = 5.0f / 128; EXPECT_THAT(EncodeDecode(ParseConfig(R"(fixed_r8 { dim_size: 3 r : 5})"), {100.0, 0.0, 3.5}), Pointwise(FloatNear(kStep), {5.0, 0.0, 3.5})); EXPECT_THAT( EncodeDecode(ParseConfig(R"(fixed_r8 { dim_size: 4 r : 5})"), {kStep * 1.4f, kStep * 1.6f, -kStep * 1.4f, -kStep * 1.6f}), ElementsAre(kStep, kStep * 2, -kStep, -kStep * 2)); } TEST(OneBitFloatCompressorTest, Basic) { EXPECT_THAT( EncodeDecode( ParseConfig(R"(one_bit { dim_size: 7 step_size : 5 amplitude: 1.0})"), {100.0, 0.1, 0.00001, 0.0, -0.00001, -0.1, -100.0}), Pointwise(FloatNear(0.1f), {1.0, 1.0, 1.0, -1.0, -1.0, 
-1.0, -1.0})); } TEST(CombinedFloatCompressorTest, Basic) { auto compressor1 = NewFloatCompressor(ParseConfig(R"(fp16 { dim_size: 1 })")); auto compressor2 = NewFloatCompressor(ParseConfig(R"(fp16 { dim_size: 2 })")); auto compressor = CombineFloatCompressor(std::move(compressor1), std::move(compressor2)); EXPECT_THAT(EncodeDecode(*compressor, {1.0, 2.0, 3.0, 4.0}), Pointwise(FloatNear(1e-4), {1.0, 2.0, 3.0})); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/hash_net_quantizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_HASH_NET_QUANTIZER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_HASH_NET_QUANTIZER_H_ #include #include #include #include #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "glog/logging.h" #include "tensorflow/core/platform/logging.h" #include "monolith/native_training/runtime/hash_table/compressor/float_compressor.pb.h" namespace monolith { namespace hash_table { class HashNetQuantizer { public: explicit HashNetQuantizer(FloatCompressorConfig_OneBit config) : config_(std::move(config)) { scale_ = config_.init_scale(); LOG(INFO) << absl::StrFormat("HashNetQuantizer: %s, scale = %.6f", config_.ShortDebugString(), scale_.load()); } float Forward(float f) const { return config_.amplitude() * std::tanh(scale_ * f); } void Backward(float num, float* grad, int64_t global_step) const { if (global_step % config_.step_size() == 0) { scale_ = config_.init_scale() * std::pow(1.f + kGamma * static_cast(global_step), kPower); scale_ = std::min(scale_.load(), config_.max_scale()); LOG_EVERY_N_SEC(INFO, 60) << absl::StrFormat( "HashNetQuantizer: %s, scale = %.6f, global_step = %ld", config_.ShortDebugString(), scale_, global_step); } float y = std::tanh(scale_ * num); *grad *= config_.amplitude() * scale_ * (1.f - y * y); } float GetScale() const { return scale_; } const FloatCompressorConfig_OneBit& GetConfig() const { return config_; } private: static constexpr float kGamma = 0.005; static constexpr float kPower = 0.5; mutable std::atomic scale_; FloatCompressorConfig_OneBit config_; }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_COMPRESSOR_HASH_NET_QUANTIZER_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/compressor/hash_net_quantizer_test.cc ================================================ // Copyright 2022 ByteDance 
and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/compressor/hash_net_quantizer.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace monolith { namespace hash_table { namespace { TEST(HashNetQuantizer, ForwardAndBackward) { FloatCompressorConfig_OneBit config; config.set_step_size(1000); HashNetQuantizer model(config); EXPECT_FLOAT_EQ(model.GetScale(), 1.0f); EXPECT_FLOAT_EQ(model.Forward(1.0f), 0.07615941f); EXPECT_FLOAT_EQ(model.Forward(2.0f), 0.09640275f); float grad = 1.0f, grad2 = 2.0f; model.Backward(2.0f, &grad, 0); model.Backward(2.0f, &grad2, 999); EXPECT_FLOAT_EQ(grad, 0.00706508f); EXPECT_FLOAT_EQ(grad2, 0.01413016f); grad = 100.0f, grad2 = 200.0f; model.Backward(2.0f, &grad, 1000); model.Backward(2.0f, &grad2, 1001); EXPECT_FLOAT_EQ(grad, 0.005442613f); EXPECT_FLOAT_EQ(grad2, 0.01088523f); EXPECT_FLOAT_EQ(model.Forward(2.0f), 0.09998888f); EXPECT_FLOAT_EQ(model.GetScale(), 2.44948974f); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") package(default_visibility = ["//monolith/native_training/runtime:__subpackages__"]) cc_library( name = "cuckoohash", hdrs = [ "bucket_container.hpp", "cuckoohash_config.hpp", 
"cuckoohash_map.hpp", "cuckoohash_util.hpp", ], visibility = [ "//monolith/feature_engineering/runtime:__subpackages__", "//monolith/native_training/runtime:__subpackages__", ], deps = [ "//monolith/native_training/runtime/hash_table:embedding_hash_table_interface", "@com_google_absl//absl/container:flat_hash_map", "@com_google_glog//:glog", ], ) cc_library( name = "cuckoo_embedding_hash_table", srcs = ["cuckoo_embedding_hash_table.cc"], hdrs = ["cuckoo_embedding_hash_table.h"], deps = [ ":cuckoohash", "//monolith/native_training/runtime/allocator:block_allocator", "//monolith/native_training/runtime/common:linalg_utils", "//monolith/native_training/runtime/hash_table:embedding_hash_table_cc_proto", "//monolith/native_training/runtime/hash_table:embedding_hash_table_interface", "//monolith/native_training/runtime/hash_table:entry_accessor", "//monolith/native_training/runtime/hash_table:entry_defs", # TODO(zhen.li1): refactor the methods in this experimental library to normal # library. 
"//monolith/native_training/data/training_instance:reader_util", ], ) cc_test( name = "cuckoo_embedding_hash_table_test", srcs = ["cuckoo_embedding_hash_table_test.cc"], deps = [ "//monolith/native_training/runtime/hash_table:embedding_hash_table_test", ], ) cc_binary( name = "cuckoo_embedding_hash_table_benchmark", srcs = ["cuckoo_embedding_hash_table_benchmark.cc"], deps = [ "//monolith/native_training/runtime/concurrency:thread_pool", "//monolith/native_training/runtime/hash_table:embedding_hash_table_cc_proto", "//monolith/native_training/runtime/hash_table:embedding_hash_table_factory", "//monolith/native_training/runtime/hash_table:embedding_hash_table_interface", "@com_github_google_benchmark//:benchmark", "@com_google_absl//absl/random", "@com_google_glog//:glog", ], ) ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/CUCKOO_ORIGINAL_LICENSE ================================================ Copyright (C) 2013, Carnegie Mellon University and Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --------------------------- The third-party libraries have their own licenses, as detailed in their source files. ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/bucket_container.hpp ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_BUCKET_CONTAINER_HPP #define _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_BUCKET_CONTAINER_HPP #include #include #include #include #include // NOLINT #include #include #include #include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_util.hpp" namespace libcuckoo { /** * bucket_container manages storage of key-value pairs for the table. * It stores the items inline in uninitialized memory, and keeps track of which * slots have live data and which do not. It also stores a partial hash for * each live key. It is sized by powers of two. 
* * @tparam Key type of keys in the table * @tparam T type of values in the table * @tparam Allocator type of key-value pair allocator * @tparam Partial type of partial keys * @tparam SLOT_PER_BUCKET number of slots for each bucket in the table */ template class bucket_container { public: using key_type = Key; using mapped_type = T; using value_type = std::pair; private: using traits_ = typename std::allocator_traits< Allocator>::template rebind_traits; public: using allocator_type = typename traits_::allocator_type; using partial_t = Partial; using size_type = typename traits_::size_type; using reference = value_type &; using const_reference = const value_type &; using pointer = typename traits_::pointer; using const_pointer = typename traits_::const_pointer; /* * The bucket type holds SLOT_PER_BUCKET key-value pairs, along with their * partial keys and occupancy info. It uses aligned_storage arrays to store * the keys and values to allow constructing and destroying key-value pairs * in place. The lifetime of bucket data should be managed by the container. * It is the user's responsibility to confirm whether the data they are * accessing is live or not. 
*/ class bucket { public: bucket() noexcept : occupied_() {} const value_type &kvpair(size_type ind) const { return *static_cast( static_cast(&values_[ind])); } value_type &kvpair(size_type ind) { return *static_cast(static_cast(&values_[ind])); } const key_type &key(size_type ind) const { return storage_kvpair(ind).first; } key_type &&movable_key(size_type ind) { return std::move(storage_kvpair(ind).first); } const mapped_type &mapped(size_type ind) const { return storage_kvpair(ind).second; } mapped_type &mapped(size_type ind) { return storage_kvpair(ind).second; } partial_t partial(size_type ind) const { return partials_[ind]; } partial_t &partial(size_type ind) { return partials_[ind]; } bool occupied(size_type ind) const { return occupied_[ind]; } bool &occupied(size_type ind) { return occupied_[ind]; } private: friend class bucket_container; using storage_value_type = std::pair; const storage_value_type &storage_kvpair(size_type ind) const { return *static_cast( static_cast(&values_[ind])); } storage_value_type &storage_kvpair(size_type ind) { return *static_cast( static_cast(&values_[ind])); } std::array::type, SLOT_PER_BUCKET> values_; std::array partials_; std::array occupied_; }; bucket_container(size_type hp, const allocator_type &allocator) : allocator_(allocator), bucket_allocator_(allocator), hashpower_(hp), buckets_(bucket_allocator_.allocate(size())) { // The bucket default constructor is nothrow, so we don't have to // worry about dealing with exceptions when constructing all the // elements. 
static_assert(std::is_nothrow_constructible::value, "bucket_container requires bucket to be nothrow " "constructible"); for (size_type i = 0; i < size(); ++i) { traits_::construct(allocator_, &buckets_[i]); } } ~bucket_container() noexcept { destroy_buckets(); } bucket_container(const bucket_container &bc) : allocator_( traits_::select_on_container_copy_construction(bc.allocator_)), bucket_allocator_(allocator_), hashpower_(bc.hashpower()), buckets_(transfer(bc.hashpower(), bc, std::false_type())) {} bucket_container(const bucket_container &bc, const allocator_type &a) : allocator_(a), bucket_allocator_(allocator_), hashpower_(bc.hashpower()), buckets_(transfer(bc.hashpower(), bc, std::false_type())) {} bucket_container(bucket_container &&bc) : allocator_(std::move(bc.allocator_)), bucket_allocator_(allocator_), hashpower_(bc.hashpower()), buckets_(std::move(bc.buckets_)) { // De-activate the other buckets container bc.buckets_ = nullptr; } bucket_container(bucket_container &&bc, const allocator_type &a) : allocator_(a), bucket_allocator_(allocator_) { move_assign(bc, std::false_type()); } bucket_container &operator=(const bucket_container &bc) { destroy_buckets(); copy_allocator(allocator_, bc.allocator_, typename traits_::propagate_on_container_copy_assignment()); bucket_allocator_ = allocator_; hashpower(bc.hashpower()); buckets_ = transfer(bc.hashpower(), bc, std::false_type()); return *this; } bucket_container &operator=(bucket_container &&bc) { destroy_buckets(); move_assign(bc, typename traits_::propagate_on_container_move_assignment()); return *this; } void swap(bucket_container &bc) noexcept { swap_allocator(allocator_, bc.allocator_, typename traits_::propagate_on_container_swap()); swap_allocator(bucket_allocator_, bc.bucket_allocator_, typename traits_::propagate_on_container_swap()); // Regardless of whether we actually swapped the allocators or not, it will // always be okay to do the remainder of the swap. 
This is because if the // allocators were swapped, then the subsequent operations are okay. If the // allocators weren't swapped but compare equal, then we're okay. If they // weren't swapped and compare unequal, then behavior is undefined, so // we're okay. size_t bc_hashpower = bc.hashpower(); bc.hashpower(hashpower()); hashpower(bc_hashpower); std::swap(buckets_, bc.buckets_); } size_type hashpower() const { return hashpower_.load(std::memory_order_acquire); } void hashpower(size_type val) { hashpower_.store(val, std::memory_order_release); } size_type size() const { return size_type(1) << hashpower(); } allocator_type get_allocator() const { return allocator_; } bucket &operator[](size_type i) { return buckets_[i]; } const bucket &operator[](size_type i) const { return buckets_[i]; } // Constructs live data in a bucket template void setKV(size_type ind, size_type slot, partial_t p, K &&k, Args &&...args) { bucket &b = buckets_[ind]; assert(!b.occupied(slot)); b.partial(slot) = p; traits_::construct(allocator_, std::addressof(b.storage_kvpair(slot)), std::piecewise_construct, std::forward_as_tuple(std::forward(k)), std::forward_as_tuple(std::forward(args)...)); // This must occur last, to enforce a strong exception guarantee b.occupied(slot) = true; } // Destroys live data in a bucket void eraseKV(size_type ind, size_type slot) { bucket &b = buckets_[ind]; assert(b.occupied(slot)); b.occupied(slot) = false; traits_::destroy(allocator_, std::addressof(b.storage_kvpair(slot))); } // Destroys all the live data in the buckets. Does not deallocate the bucket // memory. void clear() noexcept { static_assert(std::is_nothrow_destructible::value && std::is_nothrow_destructible::value, "bucket_container requires key and value to be nothrow " "destructible"); for (size_type i = 0; i < size(); ++i) { bucket &b = buckets_[i]; for (size_type j = 0; j < SLOT_PER_BUCKET; ++j) { if (b.occupied(j)) { eraseKV(i, j); } } } } // Destroys and deallocates all data in the buckets. 
After this operation, // the bucket container will have no allocated data. It is still valid to // swap, move or copy assign to this container. void clear_and_deallocate() noexcept { destroy_buckets(); } private: using bucket_traits_ = typename traits_::template rebind_traits; using bucket_pointer = typename bucket_traits_::pointer; // true here means the allocators from `src` are propagated on libcuckoo_copy template void copy_allocator(A &dst, const A &src, std::true_type) { // NOLINT dst = src; } template void copy_allocator(A &dst, const A &src, std::false_type) {} // NOLINT // true here means the allocators from `src` are propagated on libcuckoo_swap template void swap_allocator(A &dst, A &src, std::true_type) { // NOLINT std::swap(dst, src); } template void swap_allocator(A &, A &, std::false_type) {} // true here means the bucket allocator should be propagated void move_assign(bucket_container &src, std::true_type) { // NOLINT allocator_ = std::move(src.allocator_); bucket_allocator_ = allocator_; hashpower(src.hashpower()); buckets_ = src.buckets_; src.buckets_ = nullptr; } void move_assign(bucket_container &src, std::false_type) { // NOLINT hashpower(src.hashpower()); if (allocator_ == src.allocator_) { buckets_ = src.buckets_; src.buckets_ = nullptr; } else { buckets_ = transfer(src.hashpower(), src, std::true_type()); } } void destroy_buckets() noexcept { if (buckets_ == nullptr) { return; } // The bucket default constructor is nothrow, so we don't have to // worry about dealing with exceptions when constructing all the // elements. 
static_assert(std::is_nothrow_destructible::value, "bucket_container requires bucket to be nothrow " "destructible"); clear(); for (size_type i = 0; i < size(); ++i) { traits_::destroy(allocator_, &buckets_[i]); } bucket_allocator_.deallocate(buckets_, size()); buckets_ = nullptr; } // `true` here refers to whether or not we should move void move_or_copy(size_type dst_ind, size_type dst_slot, bucket &src, // NOLINT size_type src_slot, std::true_type) { setKV(dst_ind, dst_slot, src.partial(src_slot), src.movable_key(src_slot), std::move(src.mapped(src_slot))); } void move_or_copy(size_type dst_ind, size_type dst_slot, bucket &src, // NOLINT size_type src_slot, std::false_type) { setKV(dst_ind, dst_slot, src.partial(src_slot), src.key(src_slot), src.mapped(src_slot)); } template bucket_pointer transfer( size_type dst_hp, typename std::conditional::type src, std::integral_constant move) { assert(dst_hp >= src.hashpower()); bucket_container dst(dst_hp, get_allocator()); // Move/copy all occupied slots of the source buckets for (size_t i = 0; i < src.size(); ++i) { for (size_t j = 0; j < SLOT_PER_BUCKET; ++j) { if (src.buckets_[i].occupied(j)) { dst.move_or_copy(i, j, src.buckets_[i], j, move); } } } // Take away the pointer from `dst` and return it bucket_pointer dst_pointer = dst.buckets_; dst.buckets_ = nullptr; return dst_pointer; } // This allocator matches the value_type, but is not used to construct // storage_value_type pairs, or allocate buckets allocator_type allocator_; // This allocator is used for actually allocating buckets. It is simply // copy-constructed from `allocator_`, and will always be copied whenever // allocator_ is copied. typename traits_::template rebind_alloc bucket_allocator_; // This needs to be atomic, since it can be read and written by multiple // threads not necessarily synchronized by a lock. 
std::atomic hashpower_; // These buckets are protected by striped locks (external to the // BucketContainer), which must be obtained before accessing a bucket. bucket_pointer buckets_; // If the key and value are Trivial, the bucket be serilizable. Since we // already disallow user-specialized instances of std::pair, we know that the // default implementation of std::pair uses a default copy constructor, so // this should be okay. We could in theory just check if the type is // TriviallyCopyable but this check is not available on some compilers we // want to support. template friend typename std::enable_if::value && std::is_trivial::value, std::ostream &>::type operator<<(std::ostream &os, const bucket_container &bc) { size_type hp = bc.hashpower(); os.write(reinterpret_cast(&hp), sizeof(size_type)); os.write(reinterpret_cast(bc.buckets_), sizeof(bucket) * bc.size()); return os; } template friend typename std::enable_if::value && std::is_trivial::value, std::istream &>::type operator>>(std::istream &is, bucket_container &bc) { size_type hp; is.read(reinterpret_cast(&hp), sizeof(size_type)); bucket_container new_bc(hp, bc.get_allocator()); is.read(reinterpret_cast(new_bc.buckets_), new_bc.size() * sizeof(bucket)); bc.swap(new_bc); return is; } }; } // namespace libcuckoo #endif // _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_BUCKET_CONTAINER_HPP ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/cuckoo_embedding_hash_table.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoo_embedding_hash_table.h" #include #include #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "monolith/native_training/runtime/allocator/block_allocator.h" #include "monolith/native_training/runtime/common/linalg_utils.h" #include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_map.hpp" #include "monolith/native_training/runtime/hash_table/entry_defs.h" namespace monolith { namespace hash_table { namespace { using allocator::EntryAddress; using allocator::TSEmbeddingBlockAllocator; using common::IsAlmostEqual; using common::L2NormSquare; const int64_t kSecPerDay = 24 * 60 * 60; // A helper that wraps the object with a init_fn. template class WithInitFn : public T { public: template explicit WithInitFn(const std::function& init_fn, Args&&... args) : T(std::forward(args)...) { init_fn(*this); } }; template class EntryHelper {}; template <> class EntryHelper { public: EntryHelper(size_t entry_size) : entry_size_(entry_size), alloc_(entry_size) {} template bool Upsert(Map* m, Args&&... 
args) { return m->upsert(std::forward(args)..., &alloc_); } void* Get(const PackedEntry& entry) const { return alloc_.GetEntryPointer(entry.get_entry_addr()); } void DeallocateAll() { alloc_.DeallocateAll(); } private: size_t entry_size_; allocator::TSEmbeddingBlockAllocator alloc_; }; template <> class EntryHelper { public: EntryHelper(size_t entry_size) : entry_size_(entry_size) {} template bool Upsert(Map* m, Args&&... args) { return m->upsert(std::forward(args)..., entry_size_); } void* Get(const RawEntry& entry) const { return entry.get(); } void DeallocateAll() {} private: size_t entry_size_; }; template class EntryHelper> { public: template bool Upsert(Map* m, Args&&... args) { return m->upsert(std::forward(args)...); } const void* Get(const InlineEntry& entry) const { return entry.get(); } void* Get(InlineEntry& entry) { return entry.get(); } void DeallocateAll() {} }; struct Params { CuckooEmbeddingHashTableConfig config; std::unique_ptr accessor; uint64_t initial_capacity; SlotExpireTimeConfig slot_expire_time_config; bool skip_zero_embedding = false; }; template class CuckooEmbeddingHashTable : public EmbeddingHashTableInterface { public: using MapType = libcuckoo::cuckoohash_map>; explicit CuckooEmbeddingHashTable(Params p, EntryHelper entry_helper) : config_(std::move(p.config)), accessor_(std::move(p.accessor)), entry_helper_(std::move(entry_helper)), default_expire_time_(p.slot_expire_time_config.default_expire_time()), skip_zero_embedding_(p.skip_zero_embedding), m_(p.initial_capacity) { slot_to_expire_time_ = std::make_unique>(); for (const auto& slot_expire_time : p.slot_expire_time_config.slot_expire_times()) { (*slot_to_expire_time_)[slot_expire_time.slot()] = slot_expire_time.expire_time(); } LOG_FIRST_N(INFO, 1) << "skip_zero_embedding: " << skip_zero_embedding_; } // Returns the corresponding entry for |ids|. 
int64_t BatchLookup(absl::Span ids, absl::Span> embeddings) const override { int64_t found = 0; for (unsigned int index = 0; index < ids.size(); ++index) { int64_t id = ids[index]; found += Lookup(id, embeddings[index]); } return found; } // Handles the corresponding entry for |ids|. void BatchLookupEntry(absl::Span ids, absl::Span entries) const override { for (unsigned int index = 0; index < ids.size(); ++index) { int64_t id = ids[index]; LookupEntry(id, entries.subspan(index, index + 1)); } } // Returns the corresponding entry for |id|. int64_t Lookup(int64_t id, absl::Span embedding) const override { auto find_fn = [&](EntryType& entry) { accessor_->Fill(entry_helper_.Get(entry), embedding); }; if (m_.find_fn(id, find_fn)) { return 1; } // By default, returns all zero. std::memset(embedding.data(), 0, sizeof(float) * embedding.size()); return 0; } // Handles the corresponding entry for |id|. void LookupEntry(int64_t id, absl::Span entry) const override { auto find_fn = [&](EntryType& raw_entry) { entry[0] = std::move(accessor_->Save(entry_helper_.Get(raw_entry), raw_entry.GetTimestamp())); }; if (m_.find_fn(id, find_fn)) { return; } } // Update the hash table entry directly. void Assign(absl::Span ids, absl::Span> updates, int64_t update_time) override { for (size_t i = 0; i < ids.size(); ++i) { int64_t id = ids[i]; auto update = updates[i]; if (skip_zero_embedding_ && IsAlmostEqual(L2NormSquare(update.data(), update.size()), 0.f)) { m_.erase(id); LOG_EVERY_N(INFO, 10000) << "Assign erase " << google::COUNTER << " zero embeddings."; } else { UpsertEntry(id, [&](EntryType& entry) { entry.SetTimestamp(update_time); accessor_->Assign(update, entry_helper_.Get(entry)); }); } } } // Update the hash table entry directly. 
void AssignAdd(int64_t id, absl::Span update, int64_t update_time) override { UpsertEntry(id, [&](EntryType& entry) { entry.SetTimestamp(update_time); accessor_->AssignAdd(update, entry_helper_.Get(entry)); }); } void Reinitialize(absl::Span ids, absl::Span status) override { int64_t update_time = absl::ToUnixSeconds(absl::Now()); for (size_t i = 0; i < ids.size(); ++i) { int64_t id = ids[i]; bool existed = !UpsertEntry(id, [&](EntryType& entry) { entry.SetTimestamp(update_time); accessor_->Init(entry_helper_.Get(entry)); }); status[i] = existed; } } // Update the hash table based on optimizer. void BatchOptimize(absl::Span ids, absl::Span> grads, absl::Span learning_rates, int64_t update_time, const int64_t global_step) override { for (size_t i = 0; i < ids.size(); ++i) { Optimize(ids[i], grads[i], learning_rates, update_time, global_step); } } // Update the hash table based on optimizer. void Optimize(int64_t id, absl::Span grad, absl::Span learning_rates, int64_t update_time, const int64_t global_step) override { UpsertEntry(id, [&](EntryType& entry) { entry.SetTimestamp(update_time); accessor_->Optimize(entry_helper_.Get(entry), grad, learning_rates, global_step); }); } // Evict the outdated hash table values based on the expire time and last // updated time. virtual void Evict(int64_t max_update_time) { auto should_be_evict_fn = [this, max_update_time](const int64_t& key, const EntryType& entry) { const int64_t timestamp = entry.GetTimestamp(); int expire_time = default_expire_time_; // TODO(zhen.li1): evict assumes the fid is v2 version. 
auto expire_time_iter = slot_to_expire_time_->find(slot_id_v2(key)); if (expire_time_iter != slot_to_expire_time_->end()) { expire_time = expire_time_iter->second; } return max_update_time - timestamp >= expire_time * kSecPerDay; }; m_.evict(should_be_evict_fn); } // Check if a given id exists in the hashtable bool Contains(const int64_t id) { return m_.contains(id); } class CuckooLockCtx : public LockCtx { public: explicit CuckooLockCtx(typename MapType::locked_table table) : table_(std::move(table)) {} ~CuckooLockCtx() override = default; private: typename MapType::locked_table table_; }; std::unique_ptr LockAll() override { return std::make_unique(m_.lock_table()); } // Saves the data. The implementation should guarantee that different shard // can be dumped in the parallel. void Save(DumpShard shard, WriteFn write_fn, DumpIterator* iter) const override { auto dump_fn = [&](const int64_t& key, const EntryType& entry) { EntryDump dump = accessor_->Save(entry_helper_.Get(entry), entry.GetTimestamp()); dump.set_id(key); return write_fn(std::move(dump)); }; m_.partial_dump(shard, dump_fn, iter); } // Restores the data from get_fn. The implementation should guarantee that // different shard can be dumped in the parallel. // |get_fn| returns false if it is end of stream. int64_t Restore(DumpShard shard, std::function get_fn) override { EntryDump dump; int64_t max_update_ts = 0; while (get_fn(&dump, &max_update_ts)) { if (skip_zero_embedding_ && IsAlmostEqual(L2NormSquare(dump.num().data(), dump.num_size()), 0.f)) { LOG_EVERY_N(INFO, 1000000) << "Restore skip " << google::COUNTER << " zero embeddings."; continue; } UpsertEntry(dump.id(), [&](EntryType& entry) { uint32_t timestamp_sec = 0; accessor_->Restore(entry_helper_.Get(entry), ×tamp_sec, std::move(dump)); entry.SetTimestamp(timestamp_sec); }); } return max_update_ts; } // Clears data of hash table. 
void Clear() override { auto fn = [this]() { entry_helper_.DeallocateAll(); }; m_.clear_with_callback(fn); } int64_t Size() const override { return m_.size(); } int DimSize() const override { return accessor_->DimSize(); } int SliceSize() const override { return accessor_->SliceSize(); } bool Contains(int64_t id) const override { return m_.contains(id); } std::string DebugString() const override { return absl::StrFormat( R"({"accessor": %s, "size": %ld, "memory": %ld, "memory_if_not_compressed": %ld, "load_factor": %f})", accessor_->DebugString(), Size(), Size() * (accessor_->SizeBytes() + sizeof(int64_t)), Size() * (accessor_->UncompressedSizeBytes() + sizeof(int64_t)), m_.load_factor()); } private: bool UpsertEntry(int64_t id, const std::function& upsert_fn) { auto init_fn = [&](EntryType& entry) { accessor_->Init(entry_helper_.Get(entry)); upsert_fn(entry); }; return entry_helper_.Upsert(&m_, id, upsert_fn, init_fn); } CuckooEmbeddingHashTableConfig config_; std::unique_ptr accessor_; EntryHelper entry_helper_; std::unique_ptr> slot_to_expire_time_; int64_t default_expire_time_; bool skip_zero_embedding_; MapType m_; }; template std::unique_ptr CreateInlineEntryTable( Params p, int64_t size_bytes) { if (size_bytes > InlineEntry::capacity()) { std::abort(); } return std::make_unique>>( std::move(p), EntryHelper>()); } } // namespace std::unique_ptr NewCuckooEmbeddingHashTable( CuckooEmbeddingHashTableConfig config, std::unique_ptr accessor, EmbeddingHashTableConfig::EntryType type, uint64_t initial_capacity, const SlotExpireTimeConfig& slot_expire_time_config, bool skip_zero_embedding) { const int64_t size_bytes = accessor->SizeBytes(); Params p = {std::move(config), std::move(accessor), initial_capacity, slot_expire_time_config, skip_zero_embedding}; if (type == EmbeddingHashTableConfig::PACKED) { EntryHelper helper(size_bytes); return std::make_unique>( std::move(p), std::move(helper)); } else if (type == EmbeddingHashTableConfig::RAW) { if (size_bytes <= 12) { 
return CreateInlineEntryTable<16>(std::move(p), size_bytes); } else if (size_bytes <= 20) { return CreateInlineEntryTable<24>(std::move(p), size_bytes); } else if (size_bytes <= 28) { return CreateInlineEntryTable<32>(std::move(p), size_bytes); } else { EntryHelper helper(size_bytes); return std::make_unique>( std::move(p), std::move(helper)); } } // Should not reach here. throw std::invalid_argument( absl::StrFormat("Unknown entry type table. %d", type)); return nullptr; } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/cuckoo_embedding_hash_table.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOO_EMBEDDING_HASH_TABLE #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOO_EMBEDDING_HASH_TABLE #include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_interface.h" #include "monolith/native_training/runtime/hash_table/entry_accessor.h" namespace monolith { namespace hash_table { std::unique_ptr NewCuckooEmbeddingHashTable( CuckooEmbeddingHashTableConfig config, std::unique_ptr accessor, EmbeddingHashTableConfig::EntryType type, uint64_t initial_capacity, const SlotExpireTimeConfig& slot_expire_time_config, bool skip_zero_embedding); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOO_EMBEDDING_HASH_TABLE ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/cuckoo_embedding_hash_table_benchmark.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "absl/random/random.h" #include "absl/strings/str_format.h" #include "benchmark/benchmark.h" #include "glog/logging.h" #include "google/protobuf/text_format.h" #include "monolith/native_training/runtime/concurrency/thread_pool.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_factory.h" namespace monolith { namespace hash_table { namespace { namespace proto2 = ::google::protobuf; constexpr int64_t kMaxId = 1 << 15; std::unique_ptr SetupHashTable() { EmbeddingHashTableConfig config; CHECK(proto2::TextFormat::ParseFromString(R"( entry_config { segments { dim_size: 1 init_config { zeros {} } opt_config { ftrl {} } } segments { dim_size: 32 init_config { zeros {} } opt_config { sgd {} } } } cuckoo {} )", &config)); auto table = NewEmbeddingHashTableFromConfig(config); for (int64_t i = 0; i < kMaxId; ++i) { table->AssignAdd(i, std::vector(33, 0.0f), 0); } return table; } std::vector SetupPickedIds(int num) { absl::BitGen bitgen; std::vector ids(num); for (int i = 0; i < num; ++i) { ids[i] = absl::Uniform(bitgen, 0u, kMaxId); } return ids; } std::vector ids = SetupPickedIds(1000 * 256); // NOLINT auto table = SetupHashTable(); // NOLINT void BM_LookUp(benchmark::State& state) { // NOLINT int64_t thread_num = state.range(0); monolith::concurrency::ThreadPool thread_pool(thread_num); for (auto _ : state) { std::atomic_int join(thread_num); auto optimize = [&]() { // OPTIMIZE: remove memory allocation overhead std::vector embeddings(33, 0); for (int64_t id : ids) { table->Lookup(id, absl::MakeSpan(embeddings)); } --join; }; // Simulate multi-workers lookup simultaneously for (int64_t i = 0; i < thread_num; ++i) { thread_pool.Schedule(optimize); } while (join) { } } } void BM_BatchLookUp(benchmark::State& state) { // NOLINT int64_t thread_num = state.range(0); monolith::concurrency::ThreadPool thread_pool(thread_num); for (auto _ : state) { std::atomic_int join(thread_num); auto optimize = [&]() { // OPTIMIZE: remove memory 
allocation overhead std::vector data(ids.size() * 33); std::vector> embeddings; embeddings.reserve(ids.size()); for (size_t i = 0; i < ids.size(); ++i) { embeddings.push_back(absl::MakeSpan(data.data() + i * 33, 33)); } table->BatchLookup(absl::MakeSpan(ids), absl::MakeSpan(embeddings)); --join; }; // Simulate multi-workers lookup simultaneously for (int64_t i = 0; i < thread_num; ++i) { thread_pool.Schedule(optimize); } while (join) { } } } void BM_Optimize(benchmark::State& state) { // NOLINT int64_t thread_num = state.range(0); monolith::concurrency::ThreadPool thread_pool(thread_num); std::vector grad(33, 1.f); for (auto _ : state) { std::atomic_int join(thread_num); auto optimize = [&]() { for (int64_t id : ids) { table->Optimize(id, absl::MakeSpan(grad), {0.01f, 0.01f}, 0); } --join; }; // Simulate multi-workers optimize simultaneously for (int64_t i = 0; i < thread_num; ++i) { thread_pool.Schedule(optimize); } while (join) { } } } void BM_BatchOptimize(benchmark::State& state) { // NOLINT int64_t thread_num = state.range(0); monolith::concurrency::ThreadPool thread_pool(thread_num); std::vector data(ids.size() * 33, 1.f); std::vector> grads; grads.reserve(ids.size()); for (size_t i = 0; i < ids.size(); ++i) { grads.emplace_back(absl::MakeSpan(data.data() + i * 33, 33)); } for (auto _ : state) { std::atomic_int join(thread_num); auto optimize = [&]() { table->BatchOptimize(absl::MakeSpan(ids), absl::MakeSpan(grads), {0.01f, 0.01f}, 0); --join; }; // Simulate multi-workers optimize simultaneously for (int64_t i = 0; i < thread_num; ++i) { thread_pool.Schedule(optimize); } while (join) { } } } BENCHMARK(BM_LookUp)->Arg(1)->Arg(10); BENCHMARK(BM_BatchLookUp)->Arg(1)->Arg(10); BENCHMARK(BM_Optimize)->Arg(1)->Arg(10); BENCHMARK(BM_BatchOptimize)->Arg(1)->Arg(10); } // namespace } // namespace hash_table } // namespace monolith BENCHMARK_MAIN(); ================================================ FILE: 
monolith/native_training/runtime/hash_table/cuckoohash/cuckoo_embedding_hash_table_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_test.h" namespace monolith { namespace hash_table { namespace { namespace proto2 = google::protobuf; using ::testing::ElementsAre; std::tuple> GetTestOneDimSgdHashTable( EmbeddingHashTableConfig::EntryType type = EmbeddingHashTableConfig::PACKED, bool skip_zero_embedding = false) { EmbeddingHashTableConfig config; if (skip_zero_embedding) { EXPECT_TRUE(proto2::TextFormat::ParseFromString(R"( entry_config { segments { dim_size: 1 comp_config { fp32 {} } } entry_type: SERVING } initial_capacity: 1 cuckoo {} skip_zero_embedding: true )", &config)); } else { EXPECT_TRUE(proto2::TextFormat::ParseFromString(R"( entry_config { segments { dim_size: 1 init_config { zeros {} } opt_config { sgd {} } } } initial_capacity: 1 cuckoo {} )", &config)); } config.set_entry_type(type); std::vector learning_rates(1, 0.01f); return std::make_tuple(config, learning_rates); } INSTANTIATE_TEST_CASE_P( CuckooHashmapReadWrite, ReadWriteEmbeddingHashTableTest, ::testing::Values( 
GetTestOneDimSgdHashTable(EmbeddingHashTableConfig::PACKED), GetTestOneDimSgdHashTable(EmbeddingHashTableConfig::RAW))); INSTANTIATE_TEST_CASE_P(CuckooHashmapRestore, SaveRestoreEmbeddingHashTestTest, ::testing::Values(GetTestOneDimSgdHashTable())); INSTANTIATE_TEST_CASE_P(OneTimeEvict, EmbeddingHashTableEvictTest, ::testing::Values(GetTestOneDimSgdHashTable())); INSTANTIATE_TEST_CASE_P(EvictWhileRehash, EmbeddingHashTableEvictTest, ::testing::Values(GetTestOneDimSgdHashTable())); INSTANTIATE_TEST_CASE_P(SkipZeroEmbedding, EmbeddingHashTableSkipZeroEmbeddingTest, ::testing::Values(GetTestOneDimSgdHashTable( EmbeddingHashTableConfig::PACKED, true))); } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_config.hpp ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
/** \file */ #ifndef _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_CONFIG_HPP #define _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_CONFIG_HPP #include #include namespace libcuckoo { // The default maximum number of keys per bucket constexpr size_t DEFAULT_SLOT_PER_BUCKET = 4; // The default number of elements in an empty hash table constexpr size_t DEFAULT_SIZE = (1U << 16) * DEFAULT_SLOT_PER_BUCKET; // The default minimum load factor that the table allows for automatic // expansion. It must be a number between 0.0 and 1.0. The table will throw // load_factor_too_low if the load factor falls below this value // during an automatic expansion. constexpr double DEFAULT_MINIMUM_LOAD_FACTOR = 0.05; // An alias for the value that sets no limit on the maximum hashpower. If this // value is set as the maximum hashpower limit, there will be no limit. This // is also the default initial value for the maximum hashpower in a table. constexpr size_t NO_MAXIMUM_HASHPOWER = std::numeric_limits::max(); // set LIBCUCKOO_DEBUG to 1 to enable debug output #define LIBCUCKOO_DEBUG 0 } // namespace libcuckoo #endif // _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_CONFIG_HPP ================================================ FILE: monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_map.hpp ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. /** This cuckoo hash implementation is adapted from * https://github.com/efficient/libcuckoo.git */ #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_MAP_HPP_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_MAP_HPP_ #include #include #include #include #include #include #include #include #include // NOLINT #include #include #include #include #include #include #include #include #include #include #include #include "absl/container/internal/hash_function_defaults.h" // IWYU pragma: export #include "glog/logging.h" #include "monolith/native_training/runtime/hash_table/cuckoohash/bucket_container.hpp" #include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_config.hpp" #include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_util.hpp" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_interface.h" namespace libcuckoo { /** * A concurrent hash table * * @tparam Key type of keys in the table * @tparam T type of values in the table * @tparam Hash type of hash functor * @tparam KeyEqual type of equality comparison functor * @tparam Allocator type of allocator. We suggest using an aligned allocator, * because the table relies on types that are over-aligned to optimize * concurrent cache usage. * @tparam SLOT_PER_BUCKET number of slots for each bucket in the table */ // We pick absl map over std::hash since the hash function is more evenly // distributed. 
template , class KeyEqual = absl::container_internal::hash_default_eq, class Allocator = std::allocator>, std::size_t SLOT_PER_BUCKET = DEFAULT_SLOT_PER_BUCKET> class cuckoohash_map { private: // Type of the partial key using partial_t = uint8_t; // The type of the buckets container using buckets_t = bucket_container; public: /** @name Type Declarations */ /**@{*/ using key_type = typename buckets_t::key_type; using mapped_type = typename buckets_t::mapped_type; /** * This type is defined as an @c std::pair. Note that table behavior is * undefined if a user-defined specialization of @c std::pair or @c * std::pair exists. */ using value_type = typename buckets_t::value_type; using size_type = typename buckets_t::size_type; using difference_type = std::ptrdiff_t; using hasher = Hash; using key_equal = KeyEqual; using allocator_type = typename buckets_t::allocator_type; using reference = typename buckets_t::reference; using const_reference = typename buckets_t::const_reference; using pointer = typename buckets_t::pointer; using const_pointer = typename buckets_t::const_pointer; class locked_table; /**@}*/ /** @name Table Parameters */ /**@{*/ /** * The number of slots per hash bucket */ static constexpr uint16_t slot_per_bucket() { return SLOT_PER_BUCKET; } /**@}*/ /** @name Constructors, Destructors, and Assignment */ /**@{*/ /** * Creates a new cuckohash_map instance * * @param n the number of elements to reserve space for initially * @param hf hash function instance to use * @param equal equality function instance to use * @param alloc allocator instance to use */ cuckoohash_map(size_type n = DEFAULT_SIZE, const Hash &hf = Hash(), const KeyEqual &equal = KeyEqual(), const Allocator &alloc = Allocator()) : hash_fn_(hf), eq_fn_(equal), buckets_(reserve_calc(n), alloc), old_buckets_(0, alloc), all_locks_(get_allocator()), num_remaining_lazy_rehash_locks_(0), minimum_load_factor_(DEFAULT_MINIMUM_LOAD_FACTOR), maximum_hashpower_(NO_MAXIMUM_HASHPOWER), 
max_num_worker_threads_(0) { all_locks_.emplace_back(std::min(bucket_count(), size_type(kMaxNumLocks)), spinlock(), get_allocator()); } /** * Constructs the map with the contents of the range @c [first, last]. If * multiple elements in the range have equivalent keys, it is unspecified * which element is inserted. * * @param first the beginning of the range to copy from * @param last the end of the range to copy from * @param n the number of elements to reserve space for initially * @param hf hash function instance to use * @param equal equality function instance to use * @param alloc allocator instance to use */ template cuckoohash_map(InputIt first, InputIt last, size_type n = DEFAULT_SIZE, const Hash &hf = Hash(), const KeyEqual &equal = KeyEqual(), const Allocator &alloc = Allocator()) : cuckoohash_map(n, hf, equal, alloc) { for (; first != last; ++first) { insert(first->first, first->second); } } /** * Copy constructor. If @p other is being modified concurrently, behavior is * unspecified. * * @param other the map being copied */ cuckoohash_map(const cuckoohash_map &other) = default; /** * Copy constructor with separate allocator. If @p other is being modified * concurrently, behavior is unspecified. * * @param other the map being copied * @param alloc the allocator instance to use with the map */ cuckoohash_map(const cuckoohash_map &other, const Allocator &alloc) : hash_fn_(other.hash_fn_), eq_fn_(other.eq_fn_), buckets_(other.buckets_, alloc), old_buckets_(other.old_buckets_, alloc), all_locks_(alloc), num_remaining_lazy_rehash_locks_( other.num_remaining_lazy_rehash_locks_), minimum_load_factor_(other.minimum_load_factor_), maximum_hashpower_(other.maximum_hashpower_), max_num_worker_threads_(other.max_num_worker_threads_) { if (other.get_allocator() == alloc) { all_locks_ = other.all_locks_; } else { add_locks_from_other(other); } } /** * Move constructor. If @p other is being modified concurrently, behavior is * unspecified. 
* * @param other the map being moved */ cuckoohash_map(cuckoohash_map &&other) = default; /** * Move constructor with separate allocator. If the map being moved is being * modified concurrently, behavior is unspecified. * * @param other the map being moved * @param alloc the allocator instance to use with the map */ cuckoohash_map(cuckoohash_map &&other, const Allocator &alloc) : hash_fn_(std::move(other.hash_fn_)), eq_fn_(std::move(other.eq_fn_)), buckets_(std::move(other.buckets_), alloc), old_buckets_(std::move(other.old_buckets_), alloc), all_locks_(alloc), num_remaining_lazy_rehash_locks_( other.num_remaining_lazy_rehash_locks_), minimum_load_factor_(other.minimum_load_factor_), maximum_hashpower_(other.maximum_hashpower_), max_num_worker_threads_(other.max_num_worker_threads_) { if (other.get_allocator() == alloc) { all_locks_ = std::move(other.all_locks_); } else { add_locks_from_other(other); } } /** * Constructs the map with the contents of initializer list @c init. * * @param init initializer list to initialize the elements of the map with * @param n the number of elements to reserve space for initially * @param hf hash function instance to use * @param equal equality function instance to use * @param alloc allocator instance to use */ cuckoohash_map(std::initializer_list init, size_type n = DEFAULT_SIZE, const Hash &hf = Hash(), const KeyEqual &equal = KeyEqual(), const Allocator &alloc = Allocator()) : cuckoohash_map(init.begin(), init.end(), n, hf, equal, alloc) {} /** * Exchanges the contents of the map with those of @p other * * @param other the map to exchange contents with */ void swap(cuckoohash_map &other) noexcept { std::swap(hash_fn_, other.hash_fn_); std::swap(eq_fn_, other.eq_fn_); buckets_.swap(other.buckets_); all_locks_.swap(other.all_locks_); other.minimum_load_factor_.store( minimum_load_factor_.exchange(other.minimum_load_factor(), std::memory_order_release), std::memory_order_release); other.maximum_hashpower_.store( 
maximum_hashpower_.exchange(other.maximum_hashpower(), std::memory_order_release), std::memory_order_release); } /** * Copy assignment operator. If @p other is being modified concurrently, * behavior is unspecified. * * @param other the map to assign from * @return @c *this */ cuckoohash_map &operator=(const cuckoohash_map &other) = default; /** * Move assignment operator. If @p other is being modified concurrently, * behavior is unspecified. * * @param other the map to assign from * @return @c *this */ cuckoohash_map &operator=(cuckoohash_map &&other) = default; /** * Initializer list assignment operator * * @param ilist an initializer list to assign from * @return @c *this */ cuckoohash_map &operator=(std::initializer_list ilist) { clear(); for (const auto &item : ilist) { insert(item.first, item.second); } return *this; } /**@}*/ /** @name Table Details * * Methods for getting information about the table. Methods that query * changing properties of the table are not synchronized with concurrent * operations, and may return out-of-date information if the table is being * concurrently modified. They will also continue to work after the container * has been moved. * */ /**@{*/ /** * Returns the function that hashes the keys * * @return the hash function */ hasher hash_function() const { return hash_fn_; } /** * Returns the function that compares keys for equality * * @return the key comparison function */ key_equal key_eq() const { return eq_fn_; } /** * Returns the allocator associated with the map * * @return the associated allocator */ allocator_type get_allocator() const { return buckets_.get_allocator(); } /** * Returns the hashpower of the table, which is log2(@ref * bucket_count()). * * @return the hashpower */ size_type hashpower() const { return buckets_.hashpower(); } /** * Returns the number of buckets in the table. * * @return the bucket count */ size_type bucket_count() const { return buckets_.size(); } /** * Returns whether the table is empty or not. 
* * @return true if the table is empty, false otherwise */ bool empty() const { return size() == 0; } /** * Returns the number of elements in the table. * * @return number of elements in the table */ size_type size() const { if (all_locks_.size() == 0) { return 0; } counter_type s = 0; for (spinlock &lock : get_current_locks()) { s += lock.elem_counter(); } assert(s >= 0); return static_cast(s); } /** Returns the current capacity of the table, that is, @ref bucket_count() * × @ref slot_per_bucket(). * * @return capacity of table */ size_type capacity() const { return bucket_count() * slot_per_bucket(); } /** * Returns the percentage the table is filled, that is, @ref size() ÷ * @ref capacity(). * * @return load factor of the table */ double load_factor() const { return static_cast(size()) / static_cast(capacity()); } /** * Sets the minimum load factor allowed for automatic expansions. If an * expansion is needed when the load factor of the table is lower than this * threshold, @ref load_factor_too_low is thrown. It will not be * thrown for an explicitly-triggered expansion. * * @param mlf the load factor to set the minimum to * @throw std::invalid_argument if the given load factor is less than 0.0 * or greater than 1.0 */ void minimum_load_factor(const double mlf) { if (mlf < 0.0) { throw std::invalid_argument("load factor " + std::to_string(mlf) + " cannot be " "less than 0"); } else if (mlf > 1.0) { throw std::invalid_argument("load factor " + std::to_string(mlf) + " cannot be " "greater than 1"); } minimum_load_factor_.store(mlf, std::memory_order_release); } /** * Returns the minimum load factor of the table * * @return the minimum load factor */ double minimum_load_factor() const { return minimum_load_factor_.load(std::memory_order_acquire); } /** * Sets the maximum hashpower the table can be. If set to @ref * NO_MAXIMUM_HASHPOWER, there will be no limit on the hashpower. 
* Otherwise, the table will not be able to expand beyond the given
 * hashpower, either by an explicit or an automatic expansion.
 *
 * @param mhp the hashpower to set the maximum to
 * @throw std::invalid_argument if the current hashpower exceeds the limit
 */
void maximum_hashpower(size_type mhp) {
  // Reject a limit that the table's current hashpower already exceeds.
  if (hashpower() > mhp) {
    throw std::invalid_argument("maximum hashpower " + std::to_string(mhp) +
                                " is less than current hashpower");
  }
  maximum_hashpower_.store(mhp, std::memory_order_release);
}

/**
 * Returns the maximum hashpower of the table
 *
 * @return the maximum hashpower
 */
size_type maximum_hashpower() const {
  return maximum_hashpower_.load(std::memory_order_acquire);
}

/**
 * Set the maximum number of extra worker threads the table can spawn when
 * doing large batch operations. Currently batch operations occur in the
 * following scenarios.
 *   - Any resizing operation which invokes cuckoo_expand_simple. This
 *   includes any explicit rehash/resize operation, or any general resize if
 *   the data is not nothrow-move-constructible.
 *   - Creating a locked_table or resizing within a locked_table.
 *
 * @param extra_threads the number of extra threads
 */
void max_num_worker_threads(size_type extra_threads) {
  max_num_worker_threads_.store(extra_threads, std::memory_order_release);
}

/**
 * Returns the maximum number of extra worker threads.
 *
 * @return the maximum number of extra worker threads
 */
size_type max_num_worker_threads() const {
  return max_num_worker_threads_.load(std::memory_order_acquire);
}

/**@}*/

/** @name Table Operations
 *
 * These are operations that affect the data in the table. They are safe to
 * call concurrently with each other.
 *
 */
/**@{*/

/**
 * Searches the table for @p key, and invokes @p fn on the value. @p fn is
 * not allowed to modify the contents of the value if found.
 *
 * @tparam K type of the key. This can be any type comparable with @c key_type
 * @tparam F type of the functor. It should implement the method
 * void operator()(const mapped_type&).
* @param key the key to search for
 * @param fn the functor to invoke if the element is found
 * @return true if the key was found and functor invoked, false otherwise
 */
template bool find_fn(const K &key, F fn) const {
  const hash_value hv = hashed_key(key);
  // b carries the two candidate bucket indices (i1, i2) that cuckoo_find
  // searches.
  const auto b = snapshot_and_lock_two(hv);
  const table_position pos = cuckoo_find(key, hv.partial, b.i1, b.i2);
  if (pos.status == ok) {
    fn(buckets_[pos.index].mapped(pos.slot));
    return true;
  } else {
    return false;
  }
}

/**
 * Searches the table for @p key, and invokes @p fn on the value. @p fn is
 * allowed to modify the contents of the value if found.
 *
 * @tparam K type of the key. This can be any type comparable with @c key_type
 * @tparam F type of the functor. It should implement the method
 * void operator()(mapped_type&).
 * @param key the key to search for
 * @param fn the functor to invoke if the element is found
 * @return true if the key was found and functor invoked, false otherwise
 */
template bool update_fn(const K &key, F fn) {
  const hash_value hv = hashed_key(key);
  // b carries the two candidate bucket indices (i1, i2) that cuckoo_find
  // searches.
  const auto b = snapshot_and_lock_two(hv);
  const table_position pos = cuckoo_find(key, hv.partial, b.i1, b.i2);
  if (pos.status == ok) {
    fn(buckets_[pos.index].mapped(pos.slot));
    return true;
  } else {
    return false;
  }
}

/**
 * Searches for @p key in the table, and invokes @p fn on the value if the
 * key is found. The functor can mutate the value, and should return @c true
 * in order to erase the element, and @c false otherwise.
 *
 * @tparam K type of the key
 * @tparam F type of the functor. It should implement the method
 * bool operator()(mapped_type&).
* @param key the key to possibly erase from the table * @param fn the functor to invoke if the element is found * @return true if @p key was found and @p fn invoked, false otherwise */ template bool erase_fn(const K &key, F fn) { const hash_value hv = hashed_key(key); const auto b = snapshot_and_lock_two(hv); const table_position pos = cuckoo_find(key, hv.partial, b.i1, b.i2); if (pos.status == ok) { if (fn(buckets_[pos.index].mapped(pos.slot))) { del_from_bucket(pos.index, pos.slot); } return true; } else { return false; } } /** * Searches for @p key in the table. If the key is found, then @p fn is * called on the existing value, and nothing happens to the passed-in key and * values. The functor can mutate the value, and should return @c true in * order to erase the element, and @c false otherwise. If the key is not * found and must be inserted, the pair will be constructed by forwarding the * given key and values. If there is no room left in the table, it will be * automatically expanded. Expansion may throw exceptions. * * @tparam K type of the key * @tparam F type of the functor. It should implement the method * bool operator()(mapped_type&). * @tparam Args list of types for the value constructor arguments * @param key the key to insert into the table * @param fn the functor to invoke if the element is found. If your @p fn * needs more data that just the value being modified, consider implementing * it as a lambda with captured arguments. * @param val a list of constructor arguments with which to create the value * @return true if a new key was inserted, false if the key was already in * the table */ template bool uprase_fn(K &&key, F fn, Args &&... 
val) { hash_value hv = hashed_key(key); auto b = snapshot_and_lock_two(hv); table_position pos = cuckoo_insert_loop(hv, b, key); if (pos.status == ok) { add_to_bucket(pos.index, pos.slot, hv.partial, std::forward(key), std::forward(val)...); } else { if (fn(buckets_[pos.index].mapped(pos.slot))) { del_from_bucket(pos.index, pos.slot); } } return pos.status == ok; } /** * Equivalent to calling @ref uprase_fn with a functor that modifies the * given value and always returns false (meaning the element is not removed). * The passed-in functor must implement the method void * operator()(mapped_type&). */ template bool upsert(K &&key, F fn, Args &&... val) { return uprase_fn(std::forward(key), [&fn](mapped_type &v) { fn(v); return false; }, std::forward(val)...); } /** * Copies the value associated with @p key into @p val. Equivalent to * calling @ref find_fn with a functor that copies the value into @p val. @c * mapped_type must be @c CopyAssignable. */ template bool find(const K &key, mapped_type &val) const { // NOLINT return find_fn(key, [&val](const mapped_type &v) mutable { val = v; }); } /** Searches the table for @p key, and returns the associated value it * finds. @c mapped_type must be @c CopyConstructible. * * @tparam K type of the key * @param key the key to search for * @return the value associated with the given key * @throw std::out_of_range if the key is not found */ template mapped_type find(const K &key) const { const hash_value hv = hashed_key(key); const auto b = snapshot_and_lock_two(hv); const table_position pos = cuckoo_find(key, hv.partial, b.i1, b.i2); if (pos.status == ok) { return buckets_[pos.index].mapped(pos.slot); } else { throw std::out_of_range("key not found in table"); } } /** * Returns whether or not @p key is in the table. Equivalent to @ref * find_fn with a functor that does nothing. 
*/ template bool contains(const K &key) const { return find_fn(key, [](const mapped_type &) {}); } /** * Updates the value associated with @p key to @p val. Equivalent to * calling @ref update_fn with a functor that assigns the existing mapped * value to @p val. @c mapped_type must be @c MoveAssignable or @c * CopyAssignable. */ template bool update(const K &key, V &&val) { return update_fn(key, [&val](mapped_type &v) { v = std::forward(val); }); } /** * Inserts the key-value pair into the table. Equivalent to calling @ref * upsert with a functor that does nothing. */ template bool insert(K &&key, Args &&... val) { return upsert(std::forward(key), [](mapped_type &) {}, std::forward(val)...); } /** * Inserts the key-value pair into the table. If the key is already in the * table, assigns the existing mapped value to @p val. Equivalent to * calling @ref upsert with a functor that assigns the mapped value to @p * val. */ template bool insert_or_assign(K &&key, V &&val) { return upsert(std::forward(key), [&val](mapped_type &m) { m = val; }, std::forward(val)); } /** * Erases the key from the table. Equivalent to calling @ref erase_fn with a * functor that just returns true. */ template bool erase(const K &key) { return erase_fn(key, [](mapped_type &) { return true; }); } /** * Resizes the table to the given hashpower. If this hashpower is not larger * than the current hashpower, then it decreases the hashpower to the * maximum of the specified value and the smallest hashpower that can hold * all the elements currently in the table. * * @param n the hashpower to set for the table * @return true if the table changed size, false otherwise */ bool rehash(size_type n) { return cuckoo_rehash(n); } /** * Reserve enough space in the table for the given number of elements. If * the table can already hold that many elements, the function will shrink * the table to the smallest hashpower that can hold the maximum of the * specified amount and the current table size. 
* * @param n the number of elements to reserve space for * @return true if the size of the table changed, false otherwise */ bool reserve(size_type n) { return cuckoo_reserve(n); } /** * Removes all elements in the table, calling their destructors. */ void clear() { auto all_locks_manager = lock_all(normal_mode()); cuckoo_clear(); } void clear_with_callback(std::function fn) { auto all_locks_manager = lock_all(normal_mode()); cuckoo_clear(); fn(); } /** * Construct a @ref locked_table object that owns all the locks in the * table. * * @return a \ref locked_table instance */ locked_table lock_table() { return locked_table(*this); } /** * The hashtable is equally partitioned into total_shards number of * partitions. * Calling partial_dump with shard_idx value i dumps ith parition using * dump_fn. * * @param shard * @param dump_fn * */ void partial_dump( monolith::hash_table::EmbeddingHashTableInterface::DumpShard shard, std::function dump_fn, monolith::hash_table::EmbeddingHashTableInterface::DumpIterator *iter) const { const size_type hash_size = hashsize(hashpower()); int Q = hash_size / shard.total; int R = hash_size % shard.total; int begin = (shard.idx * Q) + std::min(shard.idx, R); int end = begin + Q + (shard.idx < R ? 1 : 0); int64_t count = 0; const int64_t bucket_offset = iter->offset / slot_per_bucket(); int64_t slot_offset = iter->offset % slot_per_bucket(); auto get_offset = [this](size_type bucket, size_type slot) { return bucket * this->slot_per_bucket() + slot; }; const int64_t begin_bucket_idx = begin + bucket_offset; for (size_type i = begin_bucket_idx; i < end; ++i) { auto &bucket = buckets_[i]; for (size_type j = slot_offset; j < slot_per_bucket(); ++j) { if (bucket.occupied(j)) { ++count; const bool result = dump_fn(bucket.key(j), bucket.mapped(j)); if (!result || count >= shard.limit) { iter->offset = get_offset((i - begin), j + 1); return; } } } slot_offset = 0; } // Using +1 here since end might equal to begin. 
iter->offset = get_offset(end - begin + 1, 0); } void evict(std::function should_be_evict_fn) { locks_t &locks = get_current_locks(); for (size_t l = 0; l < locks.size(); ++l) { spinlock &lock = locks[l]; if (!lock.is_migrated()) continue; lock.lock(); const auto &lock_manager = LockManager(&lock); for (size_type bucket_ind = l; bucket_ind < buckets_.size(); bucket_ind += kMaxNumLocks) { auto &bucket = buckets_[bucket_ind]; for (size_type bucket_slot = 0; bucket_slot < slot_per_bucket(); ++bucket_slot) { if (!bucket.occupied(bucket_slot)) { continue; } const auto &kv = bucket.kvpair(bucket_slot); const auto &key = kv.first; const auto &entry = kv.second; if (should_be_evict_fn(key, entry)) { del_from_bucket(bucket_ind, bucket_slot); } } } } } private: // Constructor helpers void add_locks_from_other(const cuckoohash_map &other) { locks_t &other_locks = other.get_current_locks(); all_locks_.emplace_back(other_locks.size(), spinlock(), get_allocator()); std::copy(other_locks.begin(), other_locks.end(), get_current_locks().begin()); } // Hashing types and functions // true if the key is small and simple, which means using partial keys for // lookup would probably slow us down static constexpr bool is_simple() { return std::is_pod::value && sizeof(key_type) <= 8; } // Whether or not the data is nothrow-move-constructible. static constexpr bool is_data_nothrow_move_constructible() { return std::is_nothrow_move_constructible::value && std::is_nothrow_move_constructible::value; } // Contains a hash and partial for a given key. The partial key is used for // partial-key cuckoohashing, and for finding the alternate bucket of that a // key hashes to. 
struct hash_value { size_type hash; partial_t partial; }; template hash_value hashed_key(const K &key) const { const size_type hash = hash_function()(key); return {hash, partial_key(hash)}; } template size_type hashed_key_only_hash(const K &key) const { return hash_function()(key); } // hashsize returns the number of buckets corresponding to a given // hashpower. static inline size_type hashsize(const size_type hp) { return size_type(1) << hp; } // hashmask returns the bitmask for the buckets array corresponding to a // given hashpower. static inline size_type hashmask(const size_type hp) { return hashsize(hp) - 1; } // The partial key must only depend on the hash value. It cannot change with // the hashpower, because, in order for `cuckoo_fast_double` to work // properly, the alt_index must only grow by one bit at the top each time we // expand the table. static partial_t partial_key(const size_type hash) { const uint64_t hash_64bit = hash; const uint32_t hash_32bit = (static_cast(hash_64bit) ^ static_cast(hash_64bit >> 32)); const uint16_t hash_16bit = (static_cast(hash_32bit) ^ static_cast(hash_32bit >> 16)); const uint8_t hash_8bit = (static_cast(hash_16bit) ^ static_cast(hash_16bit >> 8)); return hash_8bit; } // index_hash returns the first possible bucket that the given hashed key // could be. static inline size_type index_hash(const size_type hp, const size_type hv) { return hv & hashmask(hp); } // alt_index returns the other possible bucket that the given hashed key // could be. It takes the first possible bucket as a parameter. Note that // this function will return the first possible bucket if index is the // second possible bucket, so alt_index(ti, partial, alt_index(ti, partial, // index_hash(ti, hv))) == index_hash(ti, hv). static inline size_type alt_index(const size_type hp, const partial_t partial, const size_type index) { // ensure tag is nonzero for the multiply. 
0xc6a4a7935bd1e995 is the // hash constant from 64-bit MurmurHash2 const size_type nonzero_tag = static_cast(partial) + 1; return (index ^ (nonzero_tag * 0xc6a4a7935bd1e995)) & hashmask(hp); } // Locking types // Counter type using counter_type = int64_t; // A fast, lightweight spinlock // // Per-spinlock, we also maintain some metadata about the contents of the // table. Storing data per-spinlock avoids false sharing issues when multiple // threads need to update this metadata. We store the following information: // // - elem_counter: A counter indicating how many elements in the table are // under this lock. One can compute the size of the table by summing the // elem_counter over all locks. // // - is_migrated: When resizing with cuckoo_fast_doulbe, we do not // immediately rehash elements from the old buckets array to the new one. // Instead, we'll mark all of the locks as not migrated. So anybody trying to // acquire the lock must also migrate the corresponding buckets if // !is_migrated. 
LIBCUCKOO_SQUELCH_PADDING_WARNING class LIBCUCKOO_ALIGNAS(64) spinlock { public: spinlock() : elem_counter_(0), is_migrated_(true) { lock_.clear(); } spinlock(const spinlock &other) noexcept : elem_counter_(other.elem_counter()), is_migrated_(other.is_migrated()) { lock_.clear(); } spinlock &operator=(const spinlock &other) noexcept { elem_counter() = other.elem_counter(); is_migrated() = other.is_migrated(); return *this; } void lock() noexcept { while (lock_.test_and_set(std::memory_order_acq_rel)) ; // NOLINT } void unlock() noexcept { lock_.clear(std::memory_order_release); } bool try_lock() noexcept { return !lock_.test_and_set(std::memory_order_acq_rel); } counter_type &elem_counter() noexcept { return elem_counter_; } counter_type elem_counter() const noexcept { return elem_counter_; } bool &is_migrated() noexcept { return is_migrated_; } bool is_migrated() const noexcept { return is_migrated_; } private: std::atomic_flag lock_; counter_type elem_counter_; bool is_migrated_; }; template using rebind_alloc = typename std::allocator_traits::template rebind_alloc; using locks_t = std::vector>; using all_locks_t = std::list>; // Classes for managing locked buckets. By storing and moving around sets of // locked buckets in these classes, we can ensure that they are unlocked // properly. struct LockDeleter { void operator()(spinlock *l) const { l->unlock(); } }; using LockManager = std::unique_ptr; // Each of the locking methods can operate in two modes: locked_table_mode // and normal_mode. When we're in locked_table_mode, we assume the caller has // already taken all locks on the buckets. We also require that all data is // rehashed immediately, so that the caller never has to look through any // locks. In normal_mode, we actually do take locks, and can rehash lazily. 
using locked_table_mode = std::integral_constant; using normal_mode = std::integral_constant; class TwoBuckets { public: TwoBuckets() {} TwoBuckets(size_type i1_, size_type i2_, locked_table_mode) : i1(i1_), i2(i2_) {} TwoBuckets(locks_t &locks, size_type i1_, size_type i2_, normal_mode) // NOLINT : i1(i1_), i2(i2_), first_manager_(&locks[lock_ind(i1)]), second_manager_((lock_ind(i1) != lock_ind(i2)) ? &locks[lock_ind(i2)] : nullptr) {} void unlock() { first_manager_.reset(); second_manager_.reset(); } size_type i1, i2; private: LockManager first_manager_, second_manager_; }; struct AllUnlocker { void operator()(cuckoohash_map *map) const { for (auto it = first_locked; it != map->all_locks_.end(); ++it) { locks_t &locks = *it; for (spinlock &lock : locks) { lock.unlock(); } } } typename all_locks_t::iterator first_locked; }; using AllLocksManager = std::unique_ptr; // This exception is thrown whenever we try to lock a bucket, but the // hashpower is not what was expected class hashpower_changed {}; // After taking a lock on the table for the given bucket, this function will // check the hashpower to make sure it is the same as what it was before the // lock was taken. If it isn't unlock the bucket and throw a // hashpower_changed exception. inline void check_hashpower(size_type hp, spinlock &lock) const { // NOLINT if (hashpower() != hp) { lock.unlock(); LIBCUCKOO_DBG("%s", "hashpower changed\n"); throw hashpower_changed(); } } // If necessary, rehashes the buckets corresponding to the given lock index, // and sets the is_migrated flag to true. We should only ever do migrations // if the data is nothrow move constructible, so this function is noexcept. // // This only works if our current locks array is at the maximum size, because // otherwise, rehashing could require taking other locks. Assumes the lock at // the given index is taken. 
// // If IS_LAZY is true, we assume the lock is being rehashed in a lazy // (on-demand) fashion, so we additionally decrement the number of locks we // need to lazy_rehash. This may trigger false sharing with other // lazy-rehashing threads, but the hope is that the fraction of such // operations is low-enough to not significantly impact overall performance. static constexpr bool kIsLazy = true; static constexpr bool kIsNotLazy = false; template void rehash_lock(size_t l) const noexcept { locks_t &locks = get_current_locks(); spinlock &lock = locks[l]; if (lock.is_migrated()) return; assert(is_data_nothrow_move_constructible()); assert(locks.size() == kMaxNumLocks); assert(old_buckets_.hashpower() + 1 == buckets_.hashpower()); assert(old_buckets_.size() >= kMaxNumLocks); // Iterate through all buckets in old_buckets that are controlled by this // lock, and move them into the current buckets array. for (size_type bucket_ind = l; bucket_ind < old_buckets_.size(); bucket_ind += kMaxNumLocks) { move_bucket(old_buckets_, buckets_, bucket_ind); } lock.is_migrated() = true; if (IS_LAZY) { decrement_num_remaining_lazy_rehash_locks(); } } // locks the given bucket index. // // throws hashpower_changed if it changed after taking the lock. LockManager lock_one(size_type, size_type, locked_table_mode) const { return LockManager(); } LockManager lock_one(size_type hp, size_type i, normal_mode) const { locks_t &locks = get_current_locks(); const size_type l = lock_ind(i); spinlock &lock = locks[l]; lock.lock(); check_hashpower(hp, lock); rehash_lock(l); return LockManager(&lock); } // locks the two bucket indexes, always locking the earlier index first to // avoid deadlock. If the two indexes are the same, it just locks one. // // throws hashpower_changed if it changed after taking the lock. 
TwoBuckets lock_two(size_type, size_type i1, size_type i2, locked_table_mode) const { return TwoBuckets(i1, i2, locked_table_mode()); } TwoBuckets lock_two(size_type hp, size_type i1, size_type i2, normal_mode) const { size_type l1 = lock_ind(i1); size_type l2 = lock_ind(i2); if (l2 < l1) { std::swap(l1, l2); } locks_t &locks = get_current_locks(); locks[l1].lock(); check_hashpower(hp, locks[l1]); if (l2 != l1) { locks[l2].lock(); } rehash_lock(l1); rehash_lock(l2); return TwoBuckets(locks, i1, i2, normal_mode()); } // lock_three locks the three bucket indexes in numerical order, returning // the containers as a two (i1 and i2) and a one (i3). The one will not be // active if i3 shares a lock index with i1 or i2. // // throws hashpower_changed if it changed after taking the lock. std::pair lock_three(size_type, size_type i1, size_type i2, size_type, locked_table_mode) const { return std::make_pair(TwoBuckets(i1, i2, locked_table_mode()), LockManager()); } std::pair lock_three(size_type hp, size_type i1, size_type i2, size_type i3, normal_mode) const { std::array l{{lock_ind(i1), lock_ind(i2), lock_ind(i3)}}; // Lock in order. if (l[2] < l[1]) std::swap(l[2], l[1]); if (l[2] < l[0]) std::swap(l[2], l[0]); if (l[1] < l[0]) std::swap(l[1], l[0]); locks_t &locks = get_current_locks(); locks[l[0]].lock(); check_hashpower(hp, locks[l[0]]); if (l[1] != l[0]) { locks[l[1]].lock(); } if (l[2] != l[1]) { locks[l[2]].lock(); } rehash_lock(l[0]); rehash_lock(l[1]); rehash_lock(l[2]); return std::make_pair(TwoBuckets(locks, i1, i2, normal_mode()), LockManager((lock_ind(i3) == lock_ind(i1) || lock_ind(i3) == lock_ind(i2)) ? nullptr : &locks[lock_ind(i3)])); } // snapshot_and_lock_two loads locks the buckets associated with the given // hash value, making sure the hashpower doesn't change before the locks are // taken. Thus it ensures that the buckets and locks corresponding to the // hash value will stay correct as long as the locks are held. 
It returns // the bucket indices associated with the hash value and the current // hashpower. template TwoBuckets snapshot_and_lock_two(const hash_value &hv) const { while (true) { // Keep the current hashpower and locks we're using to compute the buckets const size_type hp = hashpower(); const size_type i1 = index_hash(hp, hv.hash); const size_type i2 = alt_index(hp, hv.partial, i1); try { return lock_two(hp, i1, i2, TABLE_MODE()); } catch (hashpower_changed &) { // The hashpower changed while taking the locks. Try again. continue; } } } // lock_all takes all the locks, and returns a deleter object that releases // the locks upon destruction. It does NOT perform any hashpower checks, or // rehash any un-migrated buckets. // // Note that after taking all the locks, it is okay to resize the buckets_ // container, since no other threads should be accessing the buckets. AllLocksManager lock_all(locked_table_mode) { return AllLocksManager(); } AllLocksManager lock_all(normal_mode) { // all_locks_ should never decrease in size, so if it is non-empty now, it // will remain non-empty assert(!all_locks_.empty()); const auto first_locked = std::prev(all_locks_.end()); auto current_locks = first_locked; while (current_locks != all_locks_.end()) { locks_t &locks = *current_locks; for (spinlock &lock : locks) { lock.lock(); } ++current_locks; } // Once we have taken all the locks of the "current" container, nobody // else can do locking operations on the table. return AllLocksManager(this, AllUnlocker{first_locked}); } // lock_ind converts an index into buckets to an index into locks. 
static inline size_type lock_ind(const size_type bucket_ind) { return bucket_ind & (kMaxNumLocks - 1); } // Data storage types and functions // The type of the bucket using bucket = typename buckets_t::bucket; // Status codes for internal functions enum cuckoo_status { ok, failure, failure_key_not_found, failure_key_duplicated, failure_table_full, failure_under_expansion, }; // A composite type for functions that need to return a table position, and // a status code. struct table_position { size_type index; size_type slot; cuckoo_status status; }; // Searching types and functions // cuckoo_find searches the table for the given key, returning the position // of the element found, or a failure status code if the key wasn't found. // It expects the locks to be taken and released outside the function. template table_position cuckoo_find(const K &key, const partial_t partial, const size_type i1, const size_type i2) const { int slot = try_read_from_bucket(buckets_[i1], partial, key); if (slot != -1) { return table_position{i1, static_cast(slot), ok}; } slot = try_read_from_bucket(buckets_[i2], partial, key); if (slot != -1) { return table_position{i2, static_cast(slot), ok}; } return table_position{0, 0, failure_key_not_found}; } // try_read_from_bucket will search the bucket for the given key and return // the index of the slot if found, or -1 if not found. template int try_read_from_bucket(const bucket &b, const partial_t partial, const K &key) const { // Silence a warning from MSVC about partial being unused if is_simple. (void)partial; for (int i = 0; i < static_cast(slot_per_bucket()); ++i) { if (!b.occupied(i) || (!is_simple() && partial != b.partial(i))) { continue; } else if (key_eq()(b.key(i), key)) { return i; } } return -1; } // Insertion types and function /** * Runs cuckoo_insert in a loop until it succeeds in insert and upsert, so * we pulled out the loop to avoid duplicating logic. 
* * @param hv the hash value of the key * @param b bucket locks * @param key the key to insert * @return table_position of the location to insert the new element, or the * site of the duplicate element with a status code if there was a duplicate. * In either case, the locks will still be held after the function ends. * @throw load_factor_too_low if expansion is necessary, but the * load factor of the table is below the threshold */ template table_position cuckoo_insert_loop(hash_value hv, TwoBuckets &b, K &key) { // NOLINT table_position pos; while (true) { const size_type hp = hashpower(); pos = cuckoo_insert(hv, b, key); switch (pos.status) { case ok: case failure_key_duplicated: return pos; case failure_table_full: // Expand the table and try again, re-grabbing the locks cuckoo_fast_double(hp); b = snapshot_and_lock_two(hv); break; case failure_under_expansion: // The table was under expansion while we were cuckooing. Re-grab the // locks and try again. b = snapshot_and_lock_two(hv); break; default: assert(false); } } } // cuckoo_insert tries to find an empty slot in either of the buckets to // insert the given key into, performing cuckoo hashing if necessary. It // expects the locks to be taken outside the function. Before inserting, it // checks that the key isn't already in the table. cuckoo hashing presents // multiple concurrency issues, which are explained in the function. The // following return states are possible: // // ok -- Found an empty slot, locks will be held on both buckets after the // function ends, and the position of the empty slot is returned // // failure_key_duplicated -- Found a duplicate key, locks will be held, and // the position of the duplicate key will be returned // // failure_under_expansion -- Failed due to a concurrent expansion // operation. Locks are released. No meaningful position is returned. // // failure_table_full -- Failed to find an empty slot for the table. Locks // are released. No meaningful position is returned. 
template table_position cuckoo_insert(const hash_value hv, TwoBuckets &b, K &key) { // NOLINT int res1, res2; bucket &b1 = buckets_[b.i1]; if (!try_find_insert_bucket(b1, res1, hv.partial, key)) { return table_position{b.i1, static_cast(res1), failure_key_duplicated}; } bucket &b2 = buckets_[b.i2]; if (!try_find_insert_bucket(b2, res2, hv.partial, key)) { return table_position{b.i2, static_cast(res2), failure_key_duplicated}; } if (res1 != -1) { return table_position{b.i1, static_cast(res1), ok}; } if (res2 != -1) { return table_position{b.i2, static_cast(res2), ok}; } // We are unlucky, so let's perform cuckoo hashing. size_type insert_bucket = 0; size_type insert_slot = 0; cuckoo_status st = run_cuckoo(b, insert_bucket, insert_slot); if (st == failure_under_expansion) { // The run_cuckoo operation operated on an old version of the table, // so we have to try again. We signal to the calling insert method // to try again by returning failure_under_expansion. return table_position{0, 0, failure_under_expansion}; } else if (st == ok) { assert(TABLE_MODE() == locked_table_mode() || !get_current_locks()[lock_ind(b.i1)].try_lock()); assert(TABLE_MODE() == locked_table_mode() || !get_current_locks()[lock_ind(b.i2)].try_lock()); assert(!buckets_[insert_bucket].occupied(insert_slot)); assert(insert_bucket == index_hash(hashpower(), hv.hash) || insert_bucket == alt_index(hashpower(), hv.partial, index_hash(hashpower(), hv.hash))); // Since we unlocked the buckets during run_cuckoo, another insert // could have inserted the same key into either b.i1 or // b.i2, so we check for that before doing the insert. 
table_position pos = cuckoo_find(key, hv.partial, b.i1, b.i2); if (pos.status == ok) { pos.status = failure_key_duplicated; return pos; } return table_position{insert_bucket, insert_slot, ok}; } assert(st == failure); LIBCUCKOO_DBG( "hash table is full (hashpower = %zu, hash_items = %zu," "load factor = %.2f), need to increase hashpower\n", hashpower(), size(), load_factor()); return table_position{0, 0, failure_table_full}; } // add_to_bucket will insert the given key-value pair into the slot. The key // and value will be move-constructed into the table, so they are not valid // for use afterwards. template void add_to_bucket(const size_type bucket_ind, const size_type slot, const partial_t partial, K &&key, Args &&... val) { buckets_.setKV(bucket_ind, slot, partial, std::forward(key), std::forward(val)...); ++get_current_locks()[lock_ind(bucket_ind)].elem_counter(); } // try_find_insert_bucket will search the bucket for the given key, and for // an empty slot. If the key is found, we store the slot of the key in // `slot` and return false. If we find an empty slot, we store its position // in `slot` and return true. If no duplicate key is found and no empty slot // is found, we store -1 in `slot` and return true. template bool try_find_insert_bucket(const bucket &b, int &slot, // NOLINT const partial_t partial, const K &key) const { // Silence a warning from MSVC about partial being unused if is_simple. (void)partial; slot = -1; for (int i = 0; i < static_cast(slot_per_bucket()); ++i) { if (b.occupied(i)) { if (!is_simple() && partial != b.partial(i)) { continue; } if (key_eq()(b.key(i), key)) { slot = i; return false; } } else { slot = i; } } return true; } // CuckooRecord holds one position in a cuckoo path. Since cuckoopath // elements only define a sequence of alternate hashings for different hash // values, we only need to keep track of the hash values being moved, rather // than the keys themselves. 
typedef struct { size_type bucket; size_type slot; hash_value hv; } CuckooRecord; // The maximum number of items in a cuckoo BFS path. It determines the // maximum number of slots we search when cuckooing. static constexpr uint8_t MAX_BFS_PATH_LEN = 5; // An array of CuckooRecords using CuckooRecords = std::array; // run_cuckoo performs cuckoo hashing on the table in an attempt to free up // a slot on either of the insert buckets, which are assumed to be locked // before the start. On success, the bucket and slot that was freed up is // stored in insert_bucket and insert_slot. In order to perform the search // and the swaps, it has to release the locks, which can lead to certain // concurrency issues, the details of which are explained in the function. // If run_cuckoo returns ok (success), then `b` will be active, otherwise it // will not. template cuckoo_status run_cuckoo(TwoBuckets &b, size_type &insert_bucket, // NOLINT size_type &insert_slot) { // NOLINT // We must unlock the buckets here, so that cuckoopath_search and // cuckoopath_move can lock buckets as desired without deadlock. // cuckoopath_move has to move something out of one of the original // buckets as its last operation, and it will lock both buckets and // leave them locked after finishing. This way, we know that if // cuckoopath_move succeeds, then the buckets needed for insertion are // still locked. If cuckoopath_move fails, the buckets are unlocked and // we try again. This unlocking does present two problems. The first is // that another insert on the same key runs and, finding that the key // isn't in the table, inserts the key into the table. Then we insert // the key into the table, causing a duplication. To check for this, we // search the buckets for the key we are trying to insert before doing // so (this is done in cuckoo_insert, and requires that both buckets are // locked). 
Another problem is that an expansion runs and changes the // hashpower, meaning the buckets may not be valid anymore. In this // case, the cuckoopath functions will have thrown a hashpower_changed // exception, which we catch and handle here. size_type hp = hashpower(); b.unlock(); CuckooRecords cuckoo_path; bool done = false; try { while (!done) { const int depth = cuckoopath_search(hp, cuckoo_path, b.i1, b.i2); if (depth < 0) { break; } if (cuckoopath_move(hp, cuckoo_path, depth, b)) { insert_bucket = cuckoo_path[0].bucket; insert_slot = cuckoo_path[0].slot; assert(insert_bucket == b.i1 || insert_bucket == b.i2); assert(TABLE_MODE() == locked_table_mode() || !get_current_locks()[lock_ind(b.i1)].try_lock()); assert(TABLE_MODE() == locked_table_mode() || !get_current_locks()[lock_ind(b.i2)].try_lock()); assert(!buckets_[insert_bucket].occupied(insert_slot)); done = true; break; } } } catch (hashpower_changed &) { // The hashpower changed while we were trying to cuckoo, which means // we want to retry. b.i1 and b.i2 should not be locked // in this case. return failure_under_expansion; } return done ? ok : failure; } // cuckoopath_search finds a cuckoo path from one of the starting buckets to // an empty slot in another bucket. It returns the depth of the discovered // cuckoo path on success, and -1 on failure. Since it doesn't take locks on // the buckets it searches, the data can change between this function and // cuckoopath_move. Thus cuckoopath_move checks that the data matches the // cuckoo path before changing it. // // throws hashpower_changed if it changed during the search. template int cuckoopath_search(const size_type hp, CuckooRecords &cuckoo_path, // NOLINT const size_type i1, const size_type i2) { b_slot x = slot_search(hp, i1, i2); if (x.depth == -1) { return -1; } // Fill in the cuckoo path slots from the end to the beginning. 
for (int i = x.depth; i >= 0; i--) { cuckoo_path[i].slot = x.pathcode % slot_per_bucket(); x.pathcode /= slot_per_bucket(); } // Fill in the cuckoo_path buckets and keys from the beginning to the // end, using the final pathcode to figure out which bucket the path // starts on. Since data could have been modified between slot_search // and the computation of the cuckoo path, this could be an invalid // cuckoo_path. CuckooRecord &first = cuckoo_path[0]; if (x.pathcode == 0) { first.bucket = i1; } else { assert(x.pathcode == 1); first.bucket = i2; } { const auto lock_manager = lock_one(hp, first.bucket, TABLE_MODE()); const bucket &b = buckets_[first.bucket]; if (!b.occupied(first.slot)) { // We can terminate here return 0; } first.hv = hashed_key(b.key(first.slot)); } for (int i = 1; i <= x.depth; ++i) { CuckooRecord &curr = cuckoo_path[i]; const CuckooRecord &prev = cuckoo_path[i - 1]; assert(prev.bucket == index_hash(hp, prev.hv.hash) || prev.bucket == alt_index(hp, prev.hv.partial, index_hash(hp, prev.hv.hash))); // We get the bucket that this slot is on by computing the alternate // index of the previous bucket curr.bucket = alt_index(hp, prev.hv.partial, prev.bucket); const auto lock_manager = lock_one(hp, curr.bucket, TABLE_MODE()); const bucket &b = buckets_[curr.bucket]; if (!b.occupied(curr.slot)) { // We can terminate here return i; } curr.hv = hashed_key(b.key(curr.slot)); } return x.depth; } // cuckoopath_move moves keys along the given cuckoo path in order to make // an empty slot in one of the buckets in cuckoo_insert. Before the start of // this function, the two insert-locked buckets were unlocked in run_cuckoo. // At the end of the function, if the function returns true (success), then // both insert-locked buckets remain locked. If the function is // unsuccessful, then both insert-locked buckets will be unlocked. // // throws hashpower_changed if it changed during the move. 
template bool cuckoopath_move(const size_type hp, CuckooRecords &cuckoo_path, // NOLINT size_type depth, TwoBuckets &b) { // NOLINT if (depth == 0) { // There is a chance that depth == 0, when try_add_to_bucket sees // both buckets as full and cuckoopath_search finds one empty. In // this case, we lock both buckets. If the slot that // cuckoopath_search found empty isn't empty anymore, we unlock them // and return false. Otherwise, the bucket is empty and insertable, // so we hold the locks and return true. const size_type bucket_i = cuckoo_path[0].bucket; assert(bucket_i == b.i1 || bucket_i == b.i2); b = lock_two(hp, b.i1, b.i2, TABLE_MODE()); if (!buckets_[bucket_i].occupied(cuckoo_path[0].slot)) { return true; } else { b.unlock(); return false; } } while (depth > 0) { CuckooRecord &from = cuckoo_path[depth - 1]; CuckooRecord &to = cuckoo_path[depth]; const size_type fs = from.slot; const size_type ts = to.slot; TwoBuckets twob; LockManager extra_manager; if (depth == 1) { // Even though we are only swapping out of one of the original // buckets, we have to lock both of them along with the slot we // are swapping to, since at the end of this function, they both // must be locked. We store tb inside the extrab container so it // is unlocked at the end of the loop. std::tie(twob, extra_manager) = lock_three(hp, b.i1, b.i2, to.bucket, TABLE_MODE()); } else { twob = lock_two(hp, from.bucket, to.bucket, TABLE_MODE()); } bucket &fb = buckets_[from.bucket]; bucket &tb = buckets_[to.bucket]; // We plan to kick out fs, but let's check if it is still there; // there's a small chance we've gotten scooped by a later cuckoo. If // that happened, just... try again. Also the slot we are filling in // may have already been filled in by another thread, or the slot we // are moving from may be empty, both of which invalidate the swap. 
// We only need to check that the hash value is the same, because, // even if the keys are different and have the same hash value, then // the cuckoopath is still valid. if (tb.occupied(ts) || !fb.occupied(fs) || hashed_key_only_hash(fb.key(fs)) != from.hv.hash) { return false; } buckets_.setKV(to.bucket, ts, fb.partial(fs), fb.movable_key(fs), std::move(fb.mapped(fs))); buckets_.eraseKV(from.bucket, fs); if (depth == 1) { // Hold onto the locks contained in twob b = std::move(twob); } depth--; } return true; } // A constexpr version of pow that we can use for various compile-time // constants and checks. static constexpr size_type const_pow(size_type a, size_type b) { return (b == 0) ? 1 : a * const_pow(a, b - 1); } // b_slot holds the information for a BFS path through the table. struct b_slot { // The bucket of the last item in the path. size_type bucket; // a compressed representation of the slots for each of the buckets in // the path. pathcode is sort of like a base-slot_per_bucket number, and // we need to hold at most MAX_BFS_PATH_LEN slots. Thus we need the // maximum pathcode to be at least slot_per_bucket()^(MAX_BFS_PATH_LEN). uint16_t pathcode; static_assert(const_pow(slot_per_bucket(), MAX_BFS_PATH_LEN) < std::numeric_limits::max(), "pathcode may not be large enough to encode a cuckoo " "path"); // The 0-indexed position in the cuckoo path this slot occupies. It must // be less than MAX_BFS_PATH_LEN, and also able to hold negative values. int8_t depth; static_assert(MAX_BFS_PATH_LEN - 1 <= std::numeric_limits::max(), "The depth type must able to hold a value of" " MAX_BFS_PATH_LEN - 1"); static_assert(-1 >= std::numeric_limits::min(), "The depth type must be able to hold a value of -1"); b_slot() {} b_slot(const size_type b, const uint16_t p, const decltype(depth) d) : bucket(b), pathcode(p), depth(d) { assert(d < MAX_BFS_PATH_LEN); } }; // b_queue is the queue used to store b_slots for BFS cuckoo hashing. 
class b_queue { public: b_queue() noexcept : first_(0), last_(0) {} void enqueue(b_slot x) { assert(!full()); slots_[last_++] = x; } b_slot dequeue() { assert(!empty()); assert(first_ < last_); b_slot &x = slots_[first_++]; return x; } bool empty() const { return first_ == last_; } bool full() const { return last_ == MAX_CUCKOO_COUNT; } private: // The size of the BFS queue. It holds just enough elements to fulfill a // MAX_BFS_PATH_LEN search for two starting buckets, with no circular // wrapping-around. For one bucket, this is the geometric sum // sum_{k=0}^{MAX_BFS_PATH_LEN-1} slot_per_bucket()^k // = (1 - slot_per_bucket()^MAX_BFS_PATH_LEN) / (1 - slot_per_bucket()) // // Note that if slot_per_bucket() == 1, then this simply equals // MAX_BFS_PATH_LEN. static_assert(slot_per_bucket() > 0, "SLOT_PER_BUCKET must be greater than 0."); static constexpr size_type MAX_CUCKOO_COUNT = 2 * ((slot_per_bucket() == 1) ? MAX_BFS_PATH_LEN : (const_pow(slot_per_bucket(), MAX_BFS_PATH_LEN) - 1) / (slot_per_bucket() - 1)); // An array of b_slots. Since we allocate just enough space to complete a // full search, we should never exceed the end of the array. b_slot slots_[MAX_CUCKOO_COUNT]; // The index of the head of the queue in the array size_type first_; // One past the index of the last_ item of the queue in the array. size_type last_; }; // slot_search searches for a cuckoo path using breadth-first search. It // starts with the i1 and i2 buckets, and, until it finds a bucket with an // empty slot, adds each slot of the bucket in the b_slot. If the queue runs // out of space, it fails. 
// // throws hashpower_changed if it changed during the search template b_slot slot_search(const size_type hp, const size_type i1, const size_type i2) { b_queue q; // The initial pathcode informs cuckoopath_search which bucket the path // starts on q.enqueue(b_slot(i1, 0, 0)); q.enqueue(b_slot(i2, 1, 0)); while (!q.empty()) { b_slot x = q.dequeue(); auto lock_manager = lock_one(hp, x.bucket, TABLE_MODE()); bucket &b = buckets_[x.bucket]; // Picks a (sort-of) random slot to start from size_type starting_slot = x.pathcode % slot_per_bucket(); for (size_type i = 0; i < slot_per_bucket(); ++i) { uint16_t slot = (starting_slot + i) % slot_per_bucket(); if (!b.occupied(slot)) { // We can terminate the search here x.pathcode = x.pathcode * slot_per_bucket() + slot; return x; } // If x has less than the maximum number of path components, // create a new b_slot item, that represents the bucket we would // have come from if we kicked out the item at this slot. const partial_t partial = b.partial(slot); if (x.depth < MAX_BFS_PATH_LEN - 1) { assert(!q.full()); b_slot y(alt_index(hp, partial, x.bucket), x.pathcode * slot_per_bucket() + slot, x.depth + 1); q.enqueue(y); } } } // We didn't find a short-enough cuckoo path, so the search terminated. // Return a failure value. return b_slot(0, 0, -1); } // cuckoo_fast_double will double the size of the table by taking advantage // of the properties of index_hash and alt_index. If the key's move // constructor is not noexcept, we use cuckoo_expand_simple, since that // provides a strong exception guarantee. 
template cuckoo_status cuckoo_fast_double(size_type current_hp) { if (!is_data_nothrow_move_constructible()) { LIBCUCKOO_DBG("%s", "cannot run cuckoo_fast_double because key-value" " pair is not nothrow move constructible"); return cuckoo_expand_simple(current_hp + 1); } const size_type new_hp = current_hp + 1; auto all_locks_manager = lock_all(TABLE_MODE()); cuckoo_status st = check_resize_validity(current_hp, new_hp); if (st != ok) { return st; } // Finish rehashing any un-rehashed buckets, so that we can move out any // remaining data in old_buckets_. We should be running cuckoo_fast_double // only after trying to cuckoo for a while, which should mean we've tried // going through most of the table and thus done a lot of rehashing // already. So this shouldn't be too expensive. // // We restrict ourselves to the current thread because we want to avoid // possibly spawning extra threads in this function, unless the // circumstances are predictable (i.e. data is nothrow move constructible, // we're in locked_table mode and must keep the buckets_ container // up-to-date, etc). // // If we have fewer than kNumLocks buckets, there shouldn't be any buckets // left to rehash, so this should be a no-op. { locks_t ¤t_locks = get_current_locks(); for (size_t i = 0; i < current_locks.size(); ++i) { rehash_lock(i); } num_remaining_lazy_rehash_locks(0); } // Resize the locks array if necessary. This is done before we update the // hashpower so that other threads don't grab the new hashpower and the old // locks. maybe_resize_locks(size_type(1) << new_hp); locks_t ¤t_locks = get_current_locks(); // Move the current buckets into old_buckets_, and create a new empty // buckets container, which will become the new current one. The // old_buckets_ data will be destroyed when move-assigning to buckets_. old_buckets_.swap(buckets_); buckets_ = buckets_t(new_hp, get_allocator()); // If we have less than kMaxNumLocks buckets, we do a full rehash in the // current thread. 
On-demand rehashing wouldn't be very easy with less than // kMaxNumLocks buckets, because it would require taking extra lower-index // locks to do the rehashing. Because kMaxNumLocks is relatively small, // this should not be very expensive. We have already set all locks to // migrated at the start of the function, so we shouldn't have to touch // them again. // // Otherwise, if we're in locked_table_mode, the expectation is that we can // access the latest data in buckets_ without taking any locks. So we must // rehash the data immediately. This would not be much different from // lazy-rehashing in locked_table_mode anyways, because it would still be // going on in one thread. if (old_buckets_.size() < kMaxNumLocks) { for (size_type i = 0; i < old_buckets_.size(); ++i) { move_bucket(old_buckets_, buckets_, i); } // This will also delete the old_buckets_ data. num_remaining_lazy_rehash_locks(0); } else { // Mark all current locks as un-migrated, so that we rehash the data // on-demand when the locks are taken. for (spinlock &lock : current_locks) { lock.is_migrated() = false; } num_remaining_lazy_rehash_locks(current_locks.size()); if (std::is_same::value) { rehash_with_workers(); } } return ok; } void move_bucket(buckets_t &old_buckets, buckets_t &new_buckets, // NOLINT size_type old_bucket_ind) const noexcept { const size_t old_hp = old_buckets.hashpower(); const size_t new_hp = new_buckets.hashpower(); // By doubling the table size, the index_hash and alt_index of each key got // one bit added to the top, at position old_hp, which means anything we // have to move will either be at the same bucket position, or exactly // hashsize(old_hp) later than the current bucket. 
bucket &old_bucket = old_buckets_[old_bucket_ind]; const size_type new_bucket_ind = old_bucket_ind + hashsize(old_hp); size_type new_bucket_slot = 0; // For each occupied slot, either move it into its same position in the // new buckets container, or to the first available spot in the new // bucket in the new buckets container. for (size_type old_bucket_slot = 0; old_bucket_slot < slot_per_bucket(); ++old_bucket_slot) { if (!old_bucket.occupied(old_bucket_slot)) { continue; } const hash_value hv = hashed_key(old_bucket.key(old_bucket_slot)); const size_type old_ihash = index_hash(old_hp, hv.hash); const size_type old_ahash = alt_index(old_hp, hv.partial, old_ihash); const size_type new_ihash = index_hash(new_hp, hv.hash); const size_type new_ahash = alt_index(new_hp, hv.partial, new_ihash); size_type dst_bucket_ind, dst_bucket_slot; if ((old_bucket_ind == old_ihash && new_ihash == new_bucket_ind) || (old_bucket_ind == old_ahash && new_ahash == new_bucket_ind)) { // We're moving the key to the new bucket dst_bucket_ind = new_bucket_ind; dst_bucket_slot = new_bucket_slot++; } else { // We're moving the key to the old bucket assert((old_bucket_ind == old_ihash && new_ihash == old_ihash) || (old_bucket_ind == old_ahash && new_ahash == old_ahash)); dst_bucket_ind = old_bucket_ind; dst_bucket_slot = old_bucket_slot; } new_buckets.setKV(dst_bucket_ind, dst_bucket_slot++, old_bucket.partial(old_bucket_slot), old_bucket.movable_key(old_bucket_slot), std::move(old_bucket.mapped(old_bucket_slot))); } } // Checks whether the resize is okay to proceed. Returns a status code, or // throws an exception, depending on the error type. 
using automatic_resize = std::integral_constant; using manual_resize = std::integral_constant; template cuckoo_status check_resize_validity(const size_type orig_hp, const size_type new_hp) { const size_type mhp = maximum_hashpower(); if (mhp != NO_MAXIMUM_HASHPOWER && new_hp > mhp) { throw maximum_hashpower_exceeded(new_hp); } if (AUTO_RESIZE::value && load_factor() < minimum_load_factor()) { throw load_factor_too_low(minimum_load_factor()); } if (hashpower() != orig_hp) { // Most likely another expansion ran before this one could grab the // locks LIBCUCKOO_DBG("%s", "another expansion is on-going\n"); return failure_under_expansion; } return ok; } // When we expand the contanier, we may need to expand the locks array, if // the current locks array is smaller than the maximum size and also smaller // than the number of buckets in the upcoming buckets container. In this // case, we grow the locks array to the smaller of the maximum lock array // size and the bucket count. This is done by allocating an entirely new lock // container, taking all the locks, copying over the counters, and then // finally adding it to the end of `all_locks_`, thereby designating it the // "current" locks container. It is the responsibility of the caller to // unlock all locks taken, including the new locks, whenever it is done with // them, so that old threads can resume and potentially re-start. void maybe_resize_locks(size_type new_bucket_count) { locks_t ¤t_locks = get_current_locks(); if (!(current_locks.size() < kMaxNumLocks && current_locks.size() < new_bucket_count)) { return; } locks_t new_locks(std::min(size_type(kMaxNumLocks), new_bucket_count), spinlock(), get_allocator()); assert(new_locks.size() > current_locks.size()); std::copy(current_locks.begin(), current_locks.end(), new_locks.begin()); for (spinlock &lock : new_locks) { lock.lock(); } all_locks_.emplace_back(std::move(new_locks)); } // cuckoo_expand_simple will resize the table to at least the given // new_hashpower. 
When we're shrinking the table, if the current table // contains more elements than can be held by new_hashpower, the resulting // hashpower will be greater than `new_hp`. It needs to take all the bucket // locks, since no other operations can change the table during expansion. // Throws maximum_hashpower_exceeded if we're expanding beyond the // maximum hashpower, and we have an actual limit. template cuckoo_status cuckoo_expand_simple(size_type new_hp) { auto all_locks_manager = lock_all(TABLE_MODE()); const size_type hp = hashpower(); cuckoo_status st = check_resize_validity(hp, new_hp); if (st != ok) { return st; } // Finish rehashing any data into buckets_. rehash_with_workers(); // Creates a new hash table with hashpower new_hp and adds all the elements // from buckets_ and old_buckets_. Allow this map to spawn extra threads if // it needs to resize during the resize. cuckoohash_map new_map(hashsize(new_hp) * slot_per_bucket(), hash_function(), key_eq(), get_allocator()); new_map.max_num_worker_threads(max_num_worker_threads()); parallel_exec(0, hashsize(hp), [this, &new_map](size_type i, size_type end, std::exception_ptr &eptr) { try { for (; i < end; ++i) { auto &bucket = buckets_[i]; for (size_type j = 0; j < slot_per_bucket(); ++j) { if (bucket.occupied(j)) { new_map.insert(bucket.movable_key(j), std::move(bucket.mapped(j))); } } } } catch (...) { eptr = std::current_exception(); } }); // Finish rehashing any data in new_map. new_map.rehash_with_workers(); // Swap the buckets_ container with new_map's. This is okay, because we // have all the locks, so nobody else should be reading from the buckets // array. Then the old buckets will be deleted when new_map is deleted. maybe_resize_locks(new_map.bucket_count()); buckets_.swap(new_map.buckets_); return ok; } // Executes the function over the given range, splitting the work between the // current thread and any available worker threads. 
// // In the noexcept version, the functor must implement operator()(size_type // start, size_type end). // // In the non-noexcept version, the functor will receive an additional // std::exception_ptr& argument. template void parallel_exec_noexcept(size_type start, size_type end, F func) { const size_type num_extra_threads = max_num_worker_threads(); const size_type num_workers = 1 + num_extra_threads; size_type work_per_thread = (end - start) / num_workers; std::vector> threads( get_allocator()); threads.reserve(num_extra_threads); for (size_type i = 0; i < num_extra_threads; ++i) { threads.emplace_back(func, start, start + work_per_thread); start += work_per_thread; } func(start, end); for (std::thread &t : threads) { t.join(); } } template void parallel_exec(size_type start, size_type end, F func) { const size_type num_extra_threads = max_num_worker_threads(); const size_type num_workers = 1 + num_extra_threads; size_type work_per_thread = (end - start) / num_workers; std::vector> threads( get_allocator()); threads.reserve(num_extra_threads); std::vector> eptrs( num_workers, nullptr, get_allocator()); for (size_type i = 0; i < num_extra_threads; ++i) { threads.emplace_back(func, start, start + work_per_thread, std::ref(eptrs[i])); start += work_per_thread; } func(start, end, std::ref(eptrs.back())); for (std::thread &t : threads) { t.join(); } for (std::exception_ptr &eptr : eptrs) { if (eptr) std::rethrow_exception(eptr); } } // Does a batch resize of the remaining data in old_buckets_. Assumes all the // locks have already been taken. void rehash_with_workers() noexcept { locks_t ¤t_locks = get_current_locks(); parallel_exec_noexcept(0, current_locks.size(), [this](size_type start, size_type end) { for (size_type i = start; i < end; ++i) { rehash_lock(i); } }); num_remaining_lazy_rehash_locks(0); } // Deletion functions // Removes an item from a bucket, decrementing the associated counter as // well. 
void del_from_bucket(const size_type bucket_ind, const size_type slot) { buckets_.eraseKV(bucket_ind, slot); --get_current_locks()[lock_ind(bucket_ind)].elem_counter(); } // Empties the table, calling the destructors of all the elements it removes // from the table. It assumes the locks are taken as necessary. void cuckoo_clear() { buckets_.clear(); // This will also clear out any data in old_buckets and delete it, if we // haven't already. num_remaining_lazy_rehash_locks(0); for (spinlock &lock : get_current_locks()) { lock.elem_counter() = 0; lock.is_migrated() = true; } } // Rehashing functions template bool cuckoo_rehash(size_type n) { const size_type hp = hashpower(); if (n == hp) { return false; } return cuckoo_expand_simple(n) == ok; } template bool cuckoo_reserve(size_type n) { const size_type hp = hashpower(); const size_type new_hp = reserve_calc(n); if (new_hp == hp) { return false; } return cuckoo_expand_simple(new_hp) == ok; } // Miscellaneous functions // reserve_calc takes in a parameter specifying a certain number of slots // for a table and returns the smallest hashpower that will hold n elements. static size_type reserve_calc(const size_type n) { const size_type buckets = (n + slot_per_bucket() - 1) / slot_per_bucket(); size_type blog2; for (blog2 = 0; (size_type(1) << blog2) < buckets; ++blog2) ; // NOLINT assert(n <= buckets * slot_per_bucket() && buckets <= hashsize(blog2)); return blog2; } // This class is a friend for unit testing friend class UnitTestInternalAccess; static constexpr size_type kMaxNumLocks = 1UL << 16; locks_t &get_current_locks() const { return all_locks_.back(); } // Get/set/decrement num remaining lazy rehash locks. If we reach 0 remaining // lazy locks, we can deallocate the memory in old_buckets_. 
size_type num_remaining_lazy_rehash_locks() const { return num_remaining_lazy_rehash_locks_.load(std::memory_order_acquire); } void num_remaining_lazy_rehash_locks(size_type n) const { num_remaining_lazy_rehash_locks_.store(n, std::memory_order_release); if (n == 0) { old_buckets_.clear_and_deallocate(); } } void decrement_num_remaining_lazy_rehash_locks() const { size_type old_num_remaining = num_remaining_lazy_rehash_locks_.fetch_sub( 1, std::memory_order_acq_rel); assert(old_num_remaining >= 1); if (old_num_remaining == 1) { old_buckets_.clear_and_deallocate(); } } // Member variables // The hash function hasher hash_fn_; // The equality function key_equal eq_fn_; // container of buckets. The size or memory location of the buckets cannot be // changed unless all the locks are taken on the table. Thus, it is only safe // to access the buckets_ container when you have at least one lock held. // // Marked mutable so that const methods can rehash into this container when // necessary. mutable buckets_t buckets_; // An old container of buckets, containing data that may not have been // rehashed into the current one. If valid, this will always have a hashpower // exactly one less than the one in buckets_. // // Marked mutable so that const methods can rehash into this container when // necessary. mutable buckets_t old_buckets_; // A linked list of all lock containers. We never discard lock containers, // since there is currently no mechanism for detecting when all threads are // done looking at the memory. The back lock container in this list is // designated the "current" one, and is used by all operations taking locks. // This container can be modified if either it is empty (which should only // occur during construction), or if the modifying thread has taken all the // locks on the existing "current" container. 
In the latter case, a // modification must take place before a modification to the hashpower, so // that other threads can detect the change and adjust appropriately. Marked // mutable so that const methods can access and take locks. mutable all_locks_t all_locks_; // A small wrapper around std::atomic to make it copyable for constructors. template class CopyableAtomic : public std::atomic { public: using std::atomic::atomic; CopyableAtomic(const CopyableAtomic &other) noexcept : CopyableAtomic(other.load(std::memory_order_acquire)) {} CopyableAtomic &operator=(const CopyableAtomic &other) noexcept { this->store(other.load(std::memory_order_acquire), std::memory_order_release); return *this; } }; // We keep track of the number of remaining locks in the latest locks array, // that remain to be rehashed. Once this reaches 0, we can free the memory of // the old buckets. It should only be accessed or modified when // lazy-rehashing a lock, so not in the common case. // // Marked mutable so that we can modify this during rehashing. mutable CopyableAtomic num_remaining_lazy_rehash_locks_; // Stores the minimum load factor allowed for automatic expansions. Whenever // an automatic expansion is triggered (during an insertion where cuckoo // hashing fails, for example), we check the load factor against this // double, and throw an exception if it's lower than this value. It can be // used to signal when the hash function is bad or the input adversarial. CopyableAtomic minimum_load_factor_; // stores the maximum hashpower allowed for any expansions. If set to // NO_MAXIMUM_HASHPOWER, this limit will be disregarded. CopyableAtomic maximum_hashpower_; // Maximum number of extra threads to spawn when doing any large batch // operations. CopyableAtomic max_num_worker_threads_; public: /** * An ownership wrapper around a @ref cuckoohash_map table instance. When * given a table instance, it takes all the locks on the table, blocking all * outside operations on the table. 
Because the locked_table has unique * ownership of the table, it can provide a set of operations on the table * that aren't possible in a concurrent context. * * The locked_table interface is very similar to the STL unordered_map * interface, and for functions whose signatures correspond to unordered_map * methods, the behavior should be mostly the same. */ class locked_table { public: /** @name Type Declarations */ /**@{*/ using key_type = typename cuckoohash_map::key_type; using mapped_type = typename cuckoohash_map::mapped_type; using value_type = typename cuckoohash_map::value_type; using size_type = typename cuckoohash_map::size_type; using difference_type = typename cuckoohash_map::difference_type; using hasher = typename cuckoohash_map::hasher; using key_equal = typename cuckoohash_map::key_equal; using allocator_type = typename cuckoohash_map::allocator_type; using reference = typename cuckoohash_map::reference; using const_reference = typename cuckoohash_map::const_reference; using pointer = typename cuckoohash_map::pointer; using const_pointer = typename cuckoohash_map::const_pointer; /** * A constant iterator over a @ref locked_table, which allows read-only * access to the elements of the table. It fulfills the * BidirectionalIterator concept. */ class const_iterator { public: using difference_type = typename locked_table::difference_type; using value_type = typename locked_table::value_type; using pointer = typename locked_table::const_pointer; using reference = typename locked_table::const_reference; using iterator_category = std::bidirectional_iterator_tag; const_iterator() {} // Return true if the iterators are from the same locked table and // location, false otherwise. 
bool operator==(const const_iterator &it) const { return buckets_ == it.buckets_ && index_ == it.index_ && slot_ == it.slot_; } bool operator!=(const const_iterator &it) const { return !(operator==(it)); } reference operator*() const { return (*buckets_)[index_].kvpair(slot_); } pointer operator->() const { return std::addressof(operator*()); } // Advance the iterator to the next item in the table, or to the end // of the table. Returns the iterator at its new position. const_iterator &operator++() { // Move forward until we get to a slot that is occupied, or we // get to the end ++slot_; for (; index_ < buckets_->size(); ++index_) { for (; slot_ < slot_per_bucket(); ++slot_) { if ((*buckets_)[index_].occupied(slot_)) { return *this; } } slot_ = 0; } assert(std::make_pair(index_, slot_) == end_pos(*buckets_)); return *this; } // Advance the iterator to the next item in the table, or to the end // of the table. Returns the iterator at its old position. const_iterator operator++(int) { const_iterator old(*this); ++(*this); return old; } // Move the iterator back to the previous item in the table. Returns // the iterator at its new position. const_iterator &operator--() { // Move backward until we get to the beginning. Behavior is // undefined if we are iterating at the first element, so we can // assume we'll reach an element. This means we'll never reach // index_ == 0 and slot_ == 0. if (slot_ == 0) { --index_; slot_ = slot_per_bucket() - 1; } else { --slot_; } while (!(*buckets_)[index_].occupied(slot_)) { if (slot_ == 0) { --index_; slot_ = slot_per_bucket() - 1; } else { --slot_; } } return *this; } // Move the iterator back to the previous item in the table. // Returns the iterator at its old position. Behavior is undefined // if the iterator is at the beginning. const_iterator operator--(int) { const_iterator old(*this); --(*this); return old; } protected: // The buckets owned by the locked table being iterated over. 
Even // though const_iterator cannot modify the buckets, we don't mark // them const so that the mutable iterator can derive from this // class. Also, since iterators should be default constructible, // copyable, and movable, we have to make this a raw pointer type. buckets_t *buckets_; // The bucket index of the item being pointed to. For implementation // convenience, we let it take on negative values. size_type index_; // The slot in the bucket of the item being pointed to. For // implementation convenience, we let it take on negative values. size_type slot_; // Returns the position signifying the end of the table static std::pair end_pos(const buckets_t &buckets) { return std::make_pair(buckets.size(), 0); } // The private constructor is used by locked_table to create // iterators from scratch. If the given index_-slot_ pair is at the // end of the table, or the given spot is occupied, stay. Otherwise, // step forward to the next data item, or to the end of the table. const_iterator(buckets_t &buckets, size_type index, // NOLINT size_type slot) noexcept : buckets_(std::addressof(buckets)), index_(index), slot_(slot) { if (std::make_pair(index_, slot_) != end_pos(*buckets_) && !(*buckets_)[index_].occupied(slot_)) { operator++(); } } friend class locked_table; }; /** * An iterator over a @ref locked_table, which allows read-write access * to elements of the table. It fulfills the BidirectionalIterator * concept. 
*/ class iterator : public const_iterator { public: using pointer = typename cuckoohash_map::pointer; using reference = typename cuckoohash_map::reference; iterator() {} bool operator==(const iterator &it) const { return const_iterator::operator==(it); } bool operator!=(const iterator &it) const { return const_iterator::operator!=(it); } reference operator*() { return (*const_iterator::buckets_)[const_iterator::index_].kvpair( const_iterator::slot_); } pointer operator->() { return std::addressof(operator*()); } iterator &operator++() { const_iterator::operator++(); return *this; } iterator operator++(int) { iterator old(*this); const_iterator::operator++(); return old; } iterator &operator--() { const_iterator::operator--(); return *this; } iterator operator--(int) { iterator old(*this); const_iterator::operator--(); return old; } private: iterator(buckets_t &buckets, size_type index, size_type slot) noexcept // NOLINT : const_iterator(buckets, index, slot) {} friend class locked_table; }; /**@}*/ /** @name Table Parameters */ /**@{*/ static constexpr size_type slot_per_bucket() { return cuckoohash_map::slot_per_bucket(); } /**@}*/ /** @name Constructors, Destructors, and Assignment */ /**@{*/ locked_table() = delete; locked_table(const locked_table &) = delete; locked_table &operator=(const locked_table &) = delete; locked_table(locked_table &<) noexcept : map_(std::move(lt.map_)), all_locks_manager_(std::move(lt.all_locks_manager_)) {} locked_table &operator=(locked_table &<) noexcept { unlock(); map_ = std::move(lt.map_); all_locks_manager_ = std::move(lt.all_locks_manager_); return *this; } /** * Unlocks the table, thereby freeing the locks on the table, but also * invalidating all iterators and table operations with this object. It * is idempotent. */ void unlock() { all_locks_manager_.reset(); } /**@}*/ /** @name Table Details * * Methods for getting information about the table. Many are identical * to their @ref cuckoohash_map counterparts. 
Only new functions or * those with different behavior are documented. * */ /**@{*/ /** * Returns whether the locked table has ownership of the table * * @return true if it still has ownership, false otherwise */ bool is_active() const { return static_cast(all_locks_manager_); } hasher hash_function() const { return map_.get().hash_function(); } key_equal key_eq() const { return map_.get().key_eq(); } allocator_type get_allocator() const { return map_.get().get_allocator(); } size_type hashpower() const { return map_.get().hashpower(); } size_type bucket_count() const { return map_.get().bucket_count(); } bool empty() const { return map_.get().empty(); } size_type size() const { return map_.get().size(); } size_type capacity() const { return map_.get().capacity(); } double load_factor() const { return map_.get().load_factor(); } void minimum_load_factor(const double mlf) { map_.get().minimum_load_factor(mlf); } double minimum_load_factor() const { return map_.get().minimum_load_factor(); } void maximum_hashpower(size_type mhp) { map_.get().maximum_hashpower(mhp); } size_type maximum_hashpower() const { return map_.get().maximum_hashpower(); } void max_num_worker_threads(size_type extra_threads) { map_.get().max_num_worker_threads(extra_threads); } size_type max_num_worker_threads() const { return map_.get().max_num_worker_threads(); } /**@}*/ /** @name Iterators */ /**@{*/ /** * Returns an iterator to the beginning of the table. If the table is * empty, it will point past the end of the table. * * @return an iterator to the beginning of the table */ iterator begin() { return iterator(map_.get().buckets_, 0, 0); } const_iterator begin() const { return const_iterator(map_.get().buckets_, 0, 0); } const_iterator cbegin() const { return begin(); } /** * Returns an iterator past the end of the table. 
* * @return an iterator past the end of the table */ iterator end() { const auto end_pos = const_iterator::end_pos(map_.get().buckets_); return iterator(map_.get().buckets_, static_cast(end_pos.first), static_cast(end_pos.second)); } const_iterator end() const { const auto end_pos = const_iterator::end_pos(map_.get().buckets_); return const_iterator(map_.get().buckets_, static_cast(end_pos.first), static_cast(end_pos.second)); } const_iterator cend() const { return end(); } /**@}*/ /** @name Modifiers */ /**@{*/ void clear() { map_.get().cuckoo_clear(); } /** * This behaves like the @c unordered_map::try_emplace method. It will * always invalidate all iterators, due to the possibilities of cuckoo * hashing and expansion. */ template std::pair insert(K &&key, Args &&... val) { hash_value hv = map_.get().hashed_key(key); auto b = map_.get().template snapshot_and_lock_two(hv); table_position pos = map_.get().template cuckoo_insert_loop(hv, b, key); if (pos.status == ok) { map_.get().add_to_bucket(pos.index, pos.slot, hv.partial, std::forward(key), std::forward(val)...); } else { assert(pos.status == failure_key_duplicated); } return std::make_pair(iterator(map_.get().buckets_, pos.index, pos.slot), pos.status == ok); } iterator erase(const_iterator pos) { map_.get().del_from_bucket(pos.index_, pos.slot_); return iterator(map_.get().buckets_, pos.index_, pos.slot_); } iterator erase(iterator pos) { map_.get().del_from_bucket(pos.index_, pos.slot_); return iterator(map_.get().buckets_, pos.index_, pos.slot_); } template size_type erase(const K &key) { const hash_value hv = map_.get().hashed_key(key); const auto b = map_.get().template snapshot_and_lock_two(hv); const table_position pos = map_.get().cuckoo_find(key, hv.partial, b.i1, b.i2); if (pos.status == ok) { map_.get().del_from_bucket(pos.index, pos.slot); return 1; } else { return 0; } } /**@}*/ /** @name Lookup */ /**@{*/ template iterator find(const K &key) { const hash_value hv = map_.get().hashed_key(key); 
const auto b = map_.get().template snapshot_and_lock_two(hv); const table_position pos = map_.get().cuckoo_find(key, hv.partial, b.i1, b.i2); if (pos.status == ok) { return iterator(map_.get().buckets_, pos.index, pos.slot); } else { return end(); } } template const_iterator find(const K &key) const { const hash_value hv = map_.get().hashed_key(key); const auto b = map_.get().template snapshot_and_lock_two(hv); const table_position pos = map_.get().cuckoo_find(key, hv.partial, b.i1, b.i2); if (pos.status == ok) { return const_iterator(map_.get().buckets_, pos.index, pos.slot); } else { return end(); } } template mapped_type &at(const K &key) { auto it = find(key); if (it == end()) { throw std::out_of_range("key not found in table"); } else { return it->second; } } template const mapped_type &at(const K &key) const { auto it = find(key); if (it == end()) { throw std::out_of_range("key not found in table"); } else { return it->second; } } /** * This function has the same lifetime properties as @ref * cuckoohash_map::insert, except that the value is default-constructed, * with no parameters, if it is not already in the table. */ template T &operator[](K &&key) { auto result = insert(std::forward(key)); return result.first->second; } template size_type count(const K &key) const { const hash_value hv = map_.get().hashed_key(key); const auto b = map_.get().template snapshot_and_lock_two(hv); return map_.get().cuckoo_find(key, hv.partial, b.i1, b.i2).status == ok ? 
1 : 0; } template std::pair equal_range(const K &key) { auto it = find(key); if (it == end()) { return std::make_pair(it, it); } else { auto start_it = it++; return std::make_pair(start_it, it); } } template std::pair equal_range(const K &key) const { auto it = find(key); if (it == end()) { return std::make_pair(it, it); } else { auto start_it = it++; return std::make_pair(start_it, it); } } /**@}*/ /** @name Re-sizing */ /**@{*/ /** * This has the same behavior as @ref cuckoohash_map::rehash, except * that we don't return anything. */ void rehash(size_type n) { map_.get().template cuckoo_rehash(n); } /** * This has the same behavior as @ref cuckoohash_map::reserve, except * that we don't return anything. */ void reserve(size_type n) { map_.get().template cuckoo_reserve(n); } /**@}*/ /** @name Comparison */ /**@{*/ bool operator==(const locked_table <) const { if (size() != lt.size()) { return false; } for (const auto &elem : lt) { auto it = find(elem.first); if (it == end() || it->second != elem.second) { return false; } } return true; } bool operator!=(const locked_table <) const { if (size() != lt.size()) { return true; } for (const auto &elem : lt) { auto it = find(elem.first); if (it == end() || it->second != elem.second) { return true; } } return false; } /**@}*/ private: // The constructor locks the entire table. We keep this constructor private // (but expose it to the cuckoohash_map class), since we don't want users // calling it. We also complete any remaining rehashing in the table, so // that everything is in map.buckets_. 
locked_table(cuckoohash_map &map) noexcept // NOLINT : map_(map), all_locks_manager_(map.lock_all(normal_mode())) { map.rehash_with_workers(); } // Dispatchers for methods on cuckoohash_map buckets_t &buckets() { return map_.get().buckets_; } const buckets_t &buckets() const { return map_.get().buckets_; } void maybe_resize_locks(size_type new_bucket_count) { map_.get().maybe_resize_locks(new_bucket_count); } locks_t &get_current_locks() { return map_.get().get_current_locks(); } // A reference to the map owned by the table std::reference_wrapper map_; // A manager for all the locks we took on the table. AllLocksManager all_locks_manager_; friend class cuckoohash_map; friend std::ostream &operator<<(std::ostream &os, const locked_table <) { os << lt.buckets(); size_type size = lt.size(); os.write(reinterpret_cast(&size), sizeof(size_type)); double mlf = lt.minimum_load_factor(); size_type mhp = lt.maximum_hashpower(); os.write(reinterpret_cast(&mlf), sizeof(double)); os.write(reinterpret_cast(&mhp), sizeof(size_type)); return os; } friend std::istream &operator>>(std::istream &is, locked_table <) { is >> lt.buckets(); // Re-size the locks, and set the size to the stored size lt.maybe_resize_locks(lt.bucket_count()); for (auto &lock : lt.get_current_locks()) { lock.elem_counter() = 0; } size_type size; is.read(reinterpret_cast(&size), sizeof(size_type)); if (size > 0) { lt.get_current_locks()[0].elem_counter() = size; } double mlf; size_type mhp; is.read(reinterpret_cast(&mlf), sizeof(double)); is.read(reinterpret_cast(&mhp), sizeof(size_type)); lt.minimum_load_factor(mlf); lt.maximum_hashpower(mhp); return is; } }; }; /** * Specializes the @c std::swap algorithm for @c cuckoohash_map. Calls @c * lhs.swap(rhs). 
*
 * @param lhs the map on the left side to swap
 * @param rhs the map on the right side to swap
 */
template void swap(
    cuckoohash_map &lhs,
    cuckoohash_map &rhs) noexcept {
  lhs.swap(rhs);
}

}  // namespace libcuckoo

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_MAP_HPP_

================================================
FILE: monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_util.hpp
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/** \file */

#ifndef _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_UTIL_HPP
#define _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_UTIL_HPP

#include
#include
#include
#include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoohash_config.hpp"  // for LIBCUCKOO_DEBUG

namespace libcuckoo {

#if LIBCUCKOO_DEBUG
// When \ref LIBCUCKOO_DEBUG is nonzero, LIBCUCKOO_DBG prints a status message
// to stderr. Each message is wrapped in the "\x1b[32m" ... "\x1b[0m" escape
// sequences and prefixed with the source file, the line number and a hash of
// the calling thread's id.
// Because __VA_ARGS__ follows a comma in the expansion, at least one argument
// must be supplied after |fmt|.
#define LIBCUCKOO_DBG(fmt, ...)                                                \
  fprintf(stderr,                                                              \
          "\x1b[32m"                                                           \
          "[libcuckoo:%s:%d:%lu] " fmt                                         \
          ""                                                                   \
          "\x1b[0m",                                                           \
          __FILE__, __LINE__,                                                  \
          std::hash()(std::this_thread::get_id()),                             \
          __VA_ARGS__)
#else
// When \ref LIBCUCKOO_DEBUG is 0, LIBCUCKOO_DBG does nothing: it expands to an
// empty do/while statement and its arguments are not evaluated.
#define LIBCUCKOO_DBG(fmt, ...)                                                \
  do {                                                                         \
  } while (0)
#endif

/**
 * alignas() requires GCC >= 4.9, so we stick with the alignment attribute for
 * GCC.
*/
// Note: the attribute form is selected for every compiler that defines
// __GNUC__; all other compilers get the standard alignas() specifier.
#ifdef __GNUC__
#define LIBCUCKOO_ALIGNAS(x) __attribute__((aligned(x)))
#else
#define LIBCUCKOO_ALIGNAS(x) alignas(x)
#endif

/**
 * At higher warning levels, MSVC produces an annoying warning that alignment
 * may cause wasted space: "structure was padded due to __declspec(align())".
 *
 * Under MSVC this macro expands to a pragma suppressing warning 4324; on every
 * other compiler it expands to nothing.
 */
#ifdef _MSC_VER
#define LIBCUCKOO_SQUELCH_PADDING_WARNING __pragma(warning(suppress : 4324))
#else
#define LIBCUCKOO_SQUELCH_PADDING_WARNING
#endif

/**
 * At higher warning levels, MSVC may issue a deadcode warning which depends on
 * the template arguments given. For certain other template arguments, the code
 * is not really "dead".
 *
 * Under MSVC the BEGIN macro pushes the warning state and disables warning
 * 4702, and the END macro pops the warning state again, so the two are meant
 * to be used as a pair. On every other compiler both expand to nothing.
 */
#ifdef _MSC_VER
#define LIBCUCKOO_SQUELCH_DEADCODE_WARNING_BEGIN                               \
  do {                                                                         \
    __pragma(warning(push));                                                   \
    __pragma(warning(disable : 4702))                                          \
  } while (0)
#define LIBCUCKOO_SQUELCH_DEADCODE_WARNING_END __pragma(warning(pop))
#else
#define LIBCUCKOO_SQUELCH_DEADCODE_WARNING_BEGIN
#define LIBCUCKOO_SQUELCH_DEADCODE_WARNING_END
#endif

/**
 * Thrown when an automatic expansion is triggered, but the load factor of the
 * table is below a minimum threshold, which can be set by the \ref
 * cuckoohash_map::minimum_load_factor method. This can happen if the hash
 * function does not properly distribute keys, or for certain adversarial
 * workloads.
*/ class load_factor_too_low : public std::exception { public: /** * Constructor * * @param lf the load factor of the table when the exception was thrown */ load_factor_too_low(const double lf) noexcept : load_factor_(lf) {} // NOLINT /** * @return a descriptive error message */ virtual const char *what() const noexcept override { // NOLINT return "Automatic expansion triggered when load factor was below " "minimum threshold"; } /** * @return the load factor of the table when the exception was thrown */ double load_factor() const noexcept { return load_factor_; } private: const double load_factor_; }; /** * Thrown when an expansion is triggered, but the hashpower specified is greater * than the maximum, which can be set with the \ref * cuckoohash_map::maximum_hashpower method. */ class maximum_hashpower_exceeded : public std::exception { public: /** * Constructor * * @param hp the hash power we were trying to expand to */ maximum_hashpower_exceeded(const size_t hp) noexcept : hashpower_(hp) {} // NOLINT /** * @return a descriptive error message */ virtual const char *what() const noexcept override { // NOLINT return "Expansion beyond maximum hashpower"; } /** * @return the hashpower we were trying to expand to */ size_t hashpower() const noexcept { return hashpower_; } private: const size_t hashpower_; }; } // namespace libcuckoo #endif // _MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_CUCKOOHASH_CUCKOOHASH_UTIL_HPP ================================================ FILE: monolith/native_training/runtime/hash_table/embedding_hash_table.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

import public "monolith/native_training/runtime/hash_table/compressor/float_compressor.proto";
import public "monolith/native_training/runtime/hash_table/initializer/initializer_config.proto";
import public "monolith/native_training/runtime/hash_table/optimizer/optimizer.proto";

package monolith.hash_table;

// Describes how one hash table entry is laid out, as a list of segments.
message EntryConfig {
  // Configuration of a single segment of the entry.
  message Segment {
    optional InitializerConfig init_config = 1;
    optional OptimizerConfig opt_config = 2;
    optional FloatCompressorConfig comp_config = 3;

    // Will overwrite dim_size in init_config, opt_config and comp_config.
    optional int32 dim_size = 7;
  }
  repeated Segment segments = 1;

  enum EntryType {
    UNKNOWN = 0;
    TRAINING = 1;
    SERVING = 2;
  }

  // Whether this entry is used for training or for serving:
  // For training entry, comp_config is not used.
  // For serving entry, init_config & opt_config are not used.
  optional EntryType entry_type = 2 [default = TRAINING];
}

// Serialized form of one hash table entry.
message EntryDump {
  optional sfixed64 id = 1;
  repeated float num = 2;
  optional OptimizerDump opt = 3;
  optional int64 last_update_ts_sec = 4;
}

// Use per slot expire time (in days) to align with sail requirement.
// In future, slot settings might be deprecated.
message SlotExpireTimeConfig {
  // These slot and expire time are used to overwrite default slot expire time.
  message SlotExpireTime {
    optional uint32 slot = 1;
    optional uint32 expire_time = 2;
  }
  repeated SlotExpireTime slot_expire_times = 1;

  // Default expire time is 100 years (36500 days).
  optional uint32 default_expire_time = 2 [default = 36500];
}

// Options specific to the cuckoo hash table; currently has no fields.
message CuckooEmbeddingHashTableConfig {
}

// Top-level configuration of a single embedding hash table.
message EmbeddingHashTableConfig {
  optional EntryConfig entry_config = 1;

  enum EntryType {
    // Memory efficient, but slower.
    PACKED = 1;
    // Fastest
    RAW = 2;
  }
  optional EntryType entry_type = 6 [default = PACKED];

  optional uint64 initial_capacity = 2 [default = 1];

  optional SlotExpireTimeConfig slot_expire_time_config = 3;

  // Selects the hash table implementation; cuckoo is the only option declared.
  oneof type {
    CuckooEmbeddingHashTableConfig cuckoo = 5;
  }

  // Whether to evict features periodically during training and serving.
  optional bool enable_feature_eviction = 7;

  // Trigger features eviction every n hours.
  optional int32 feature_evict_every_n_hours = 8 [default = 240];

  // Whether to erase zero embeddings (l2norm = 0) when serving.
  // NewEmbeddingHashTableFromConfig rejects this flag unless
  // entry_config.entry_type is SERVING.
  optional bool skip_zero_embedding = 10 [default = false];
}

// Configuration of several named hash tables.
// NOTE(review): names and configs look like parallel lists (names[i] belongs
// to configs[i]) - confirm against the consumers.
message MultiEmbeddingHashTableConfig {
  repeated string names = 1;
  repeated EmbeddingHashTableConfig configs = 2;
}

// Use per slot occurrence threshold config to align with sail requirement.
// In future, slot settings might be deprecated.
message SlotOccurrenceThresholdConfig {
  // These slot and occurrence threshold are used to overwrite default
  // occurrence thresholds.
  message SlotOccurrenceThreshold {
    optional uint32 slot = 1;
    optional uint32 occurrence_threshold = 2;
  }
  repeated SlotOccurrenceThreshold slot_occurrence_thresholds = 1;

  optional uint32 default_occurrence_threshold = 2 [default = 0];
}

// Metadata shared by all splits of a sliding hash filter.
message SlidingHashFilterMetaDump {
  optional uint32 split_num = 1 [default = 0];
  optional uint32 max_forward_step = 2 [default = 0];
  optional uint32 max_backward_step = 3 [default = 0];
  optional uint32 max_step = 4 [default = 0];
  optional uint32 head = 5 [default = 0];
  optional uint32 head_increment = 6 [default = 0];
  optional uint64 failure_count = 7 [default = 0];
}

// Here we make each hash filter split keep the shared meta dump.
// This meta is small and it can help simplify the design to store
// the shared meta in a separate file. We will consider refining this in the
// future.
message HashFilterSplitMetaDump {
  optional uint64 failure_count = 1 [default = 0];
  optional uint64 total_size = 2 [default = 0];
  optional uint64 num_elements = 3 [default = 0];
  optional double fill_rate = 4 [default = 0];
  optional SlidingHashFilterMetaDump sliding_hash_filter_meta = 5;
}

// One chunk of a hash filter split's data.
// NOTE(review): offset presumably locates `data` within the split - confirm.
message HashFilterSplitDataDump {
  optional uint32 offset = 1;
  repeated uint32 data = 2;
}

// Per-table metadata: the table's name and its number of entries.
message MultiHashTableMetadata {
  optional string table_name = 1;
  optional uint64 num_entries = 2;
}

================================================
FILE: monolith/native_training/runtime/hash_table/embedding_hash_table_factory.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/hash_table/embedding_hash_table_factory.h"

#include

#include "absl/strings/str_format.h"
#include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoo_embedding_hash_table.h"
#include "monolith/native_training/runtime/hash_table/entry_accessor.h"

namespace monolith {
namespace hash_table {

// Builds the hash table implementation selected by the `type` oneof of
// |config|. Only the cuckoo implementation is handled.
//
// Throws std::invalid_argument when:
//   * the `type` oneof is anything other than cuckoo, or
//   * skip_zero_embedding is set but the entry type is not SERVING.
//
// |args| is accepted for interface compatibility and is not read here.
std::unique_ptr NewEmbeddingHashTableFromConfig(
    EmbeddingHashTableConfig config, GpuExtraArgs args) {
  // Guard: reject every table type except cuckoo up front.
  if (config.type_case() != EmbeddingHashTableConfig::kCuckoo) {
    throw std::invalid_argument(absl::StrFormat(
        "Unknown type of hash table. %s", config.ShortDebugString()));
  }

  // Guard: dropping zero embeddings only makes sense for serving entries.
  if (config.skip_zero_embedding()) {
    if (config.entry_config().entry_type() != EntryConfig_EntryType_SERVING) {
      throw std::invalid_argument(
          "Only EntryConfig_EntryType_SERVING supports skip_zero_embedding!");
    }
  }

  return NewCuckooEmbeddingHashTable(
      config.cuckoo(), NewEntryAccessor(config.entry_config()),
      config.entry_type(), config.initial_capacity(),
      config.slot_expire_time_config(), config.skip_zero_embedding());
}

}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/embedding_hash_table_factory.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_FACTORY
#define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_FACTORY

#include

#include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h"
#include "monolith/native_training/runtime/hash_table/embedding_hash_table_interface.h"

namespace monolith {
namespace hash_table {

// Creates the hash table implementation described by |config|.
// The definition in embedding_hash_table_factory.cc throws
// std::invalid_argument for an unknown table type, or when
// skip_zero_embedding is requested for a non-SERVING entry type.
// |args| defaults to a value-initialized GpuExtraArgs.
std::unique_ptr NewEmbeddingHashTableFromConfig(
    EmbeddingHashTableConfig config, GpuExtraArgs args = GpuExtraArgs{});

}  // namespace hash_table
}  // namespace monolith

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_FACTORY

================================================
FILE: monolith/native_training/runtime/hash_table/embedding_hash_table_interface.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_INTERFACE_H_
#define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_INTERFACE_H_

#include
#include

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h"

namespace monolith {
namespace hash_table {

// Only forward-declared here; the definition lives elsewhere.
class CucoMultiHashTableOp;

// Extra arguments accepted by the table factory.
// |shared_state| is a raw, non-owning-looking pointer with no default member
// initializer: it is null only when the struct is value-initialized (as in
// `GpuExtraArgs{}`) and indeterminate when default-initialized.
// NOTE(review): ownership/lifetime of |shared_state| is not visible here.
struct GpuExtraArgs {
  CucoMultiHashTableOp* shared_state;
};

// Hash table maps int64 to a float array with fixed length.
// Implementation of this interface should guarantee thread safety.
class EmbeddingHashTableInterface {
 public:
  virtual ~EmbeddingHashTableInterface() = default;

  // Looks up the entries for |ids| and writes each result into the matching
  // element of |embeddings|.
  // NOTE(review): the meaning of the returned int64_t is not documented in
  // this header - confirm against the implementations.
  virtual int64_t BatchLookup(
      absl::Span ids,
      absl::Span> embeddings) const = 0;

  // Looks up the entries for |ids| and writes each one, in EntryDump form,
  // into the matching element of |entries|.
  virtual void BatchLookupEntry(absl::Span ids,
                                absl::Span entries) const = 0;

  // Single-id variant of BatchLookup: writes the entry for |id| into
  // |embedding|.
  virtual int64_t Lookup(int64_t id, absl::Span embedding) const = 0;

  // Single-id variant of BatchLookupEntry.
  virtual void LookupEntry(int64_t id, absl::Span entry) const = 0;

  // Update the hash table entries for |ids| directly with |updates|, stamping
  // them with |update_time|.
  virtual void Assign(absl::Span ids,
                      absl::Span> updates,
                      int64_t update_time) = 0;

  // Update the hash table entry for |id| directly.
  // NOTE(review): per the name, |update| is presumably added to the stored
  // value rather than replacing it - confirm against the implementations.
  virtual void AssignAdd(int64_t id, absl::Span update,
                         int64_t update_time) = 0;

  // Reinitialize the hash table entries for |ids|.
  // NOTE(review): |status| looks like a per-id output - confirm its contents.
  virtual void Reinitialize(absl::Span ids,
                            absl::Span status) = 0;

  // Update the hash table based on optimizer, using one gradient per id.
  // |global_step| defaults to 0; overriders that forward this call must pass
  // it on explicitly, otherwise the callee's default is used instead.
  virtual void BatchOptimize(absl::Span ids,
                             absl::Span> grads,
                             absl::Span learning_rates,
                             int64_t update_time,
                             const int64_t global_step = 0) = 0;

  // Single-id variant of BatchOptimize.
  virtual void Optimize(int64_t id, absl::Span grad,
                        absl::Span learning_rates,
                        int64_t update_time,
                        const int64_t global_step = 0) = 0;

  // Evict the outdated hash table values based on the last updated time.
  virtual void Evict(int64_t max_update_time) = 0;

  // Check if a given id exists in the hashtable.
  // Note: this non-const overload coexists with the const
  // Contains(int64_t) declared further down; implementations must provide
  // both.
  virtual bool Contains(const int64_t id) = 0;

  // To utilize multithread, we need to specify how many shard we will use.
  // Fields:
  //   idx   - index of this shard (presumably in [0, total) - confirm).
  //   total - total number of shards.
  //   limit - how many EntryDump will be fed into write_fn. Default to no
  //           limit (1LL << 61).
  struct DumpShard {
    int idx;
    int total;
    int64_t limit = 1LL << 61;
  };

  // Resumable position of a Save call.
  //   offset - The offset within the shard; should be either 0 or the value
  //            left in the iterator by a previous Save.
  struct DumpIterator {
    int64_t offset = 0;
  };

  // Callback that receives the entries produced by Save.
  using WriteFn = std::function;

  // Opaque handle returned by LockAll.
  // NOTE(review): presumably the lock is released when the handle is
  // destroyed - confirm against the implementations.
  class LockCtx {
   public:
    virtual ~LockCtx() = default;
  };

  // Locks all entries in the table. This is used together with Save to prevent
  // concurrent updates.
  virtual std::unique_ptr LockAll() = 0;

  // Saves the data. The implementation should guarantee that different shards
  // can be dumped in parallel.
  virtual void Save(DumpShard shard, WriteFn write_fn,
                    DumpIterator* iter) const = 0;

  // Restores the data from get_fn. The implementation should guarantee that
  // different shards can be restored in parallel.
  // |get_fn| returns false if it is end of stream.
  // Returns max_update_ts in this shard.
  virtual int64_t Restore(DumpShard shard,
                          std::function get_fn) = 0;

  // Clears data of hash table.
  virtual void Clear() = 0;

  // Returns the size of the current table.
  virtual int64_t Size() const = 0;

  // Returns the dimension size of the current table.
  virtual int DimSize() const = 0;

  // NOTE(review): undocumented in the original; semantics not visible here.
  virtual int SliceSize() const = 0;

  // Returns true if the current table contains the given key.
  virtual bool Contains(int64_t id) const = 0;

  // Returns a human-readable description of the table.
  virtual std::string DebugString() const = 0;
};

// A decorator that, by default, redirects every method to the base table.
class DefaultEmbeddingHashTableDecorator : public EmbeddingHashTableInterface { public: DefaultEmbeddingHashTableDecorator( std::unique_ptr base) : base_(std::move(base)) {} EmbeddingHashTableInterface* base() const { return base_.get(); } EmbeddingHashTableInterface* base() { return base_.get(); } int64_t BatchLookup(absl::Span ids, absl::Span> embeddings) const override { return base_->BatchLookup(ids, embeddings); } void BatchLookupEntry(absl::Span ids, absl::Span entries) const override { return base_->BatchLookupEntry(ids, entries); } int64_t Lookup(int64_t id, absl::Span embedding) const override { return base_->Lookup(id, embedding); } void LookupEntry(int64_t id, absl::Span entry) const override { return base_->LookupEntry(id, entry); } void Assign(absl::Span ids, absl::Span> updates, int64_t update_time) override { return base_->Assign(ids, updates, update_time); } void AssignAdd(int64_t id, absl::Span update, int64_t update_time) override { return base_->AssignAdd(id, update, update_time); } void Reinitialize(absl::Span ids, absl::Span status) override { base_->Reinitialize(ids, status); } void BatchOptimize(absl::Span ids, absl::Span> grads, absl::Span learning_rates, int64_t update_time, const int64_t global_step = 0) override { return base_->BatchOptimize(ids, grads, learning_rates, update_time); } void Optimize(int64_t id, absl::Span grad, absl::Span learning_rates, int64_t update_time, const int64_t global_step = 0) override { return base_->Optimize(id, grad, learning_rates, update_time); } std::unique_ptr LockAll() override { return base_->LockAll(); } void Save(DumpShard shard, WriteFn write_fn, DumpIterator* iter) const override { return base_->Save(shard, std::move(write_fn), iter); } int64_t Restore(DumpShard shard, std::function get_fn) override { return base_->Restore(std::move(shard), std::move(get_fn)); } void Evict(int64_t max_update_time) { base_->Evict(max_update_time); } bool Contains(const int64_t id) { return base_->Contains(id); } 
void Clear() override { return base_->Clear(); } int64_t Size() const override { return base_->Size(); } int DimSize() const override { return base_->DimSize(); } int SliceSize() const override { return base_->SliceSize(); } bool Contains(int64_t id) const override { return base_->Contains(id); } std::string DebugString() const override { return base_->DebugString(); } private: std::unique_ptr base_; }; // A class that provides some useful functionality. Like default values for // some method. class EmbeddingHashTableHelper : public DefaultEmbeddingHashTableDecorator { public: explicit EmbeddingHashTableHelper( std::unique_ptr base) : DefaultEmbeddingHashTableDecorator(std::move(base)) {} using DefaultEmbeddingHashTableDecorator::Assign; // Provide default parameters. void Assign(absl::Span ids, absl::Span> updates) { return base()->Assign(ids, updates, 0); } using DefaultEmbeddingHashTableDecorator::AssignAdd; void AssignAdd(int64_t id, absl::Span update) { return base()->AssignAdd(id, update, 0); } using DefaultEmbeddingHashTableDecorator::BatchOptimize; void BatchOptimize(absl::Span ids, absl::Span> grads, absl::Span learning_rates) { return base()->BatchOptimize(ids, grads, learning_rates, 0, 0); } // Some wrapper for easy use. 
void AssignOne(int64_t id, absl::Span update, int64_t update_time = 0) { return base()->Assign(absl::MakeConstSpan({id}), absl::MakeConstSpan({update}), update_time); } using DefaultEmbeddingHashTableDecorator::Contains; using DefaultEmbeddingHashTableDecorator::Evict; using DefaultEmbeddingHashTableDecorator::Save; void Save(DumpShard shard, WriteFn write_fn) { DumpIterator iter; return base()->Save(std::move(shard), std::move(write_fn), &iter); } }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_INTERFACE_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/embedding_hash_table_test.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_TEST_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_TEST_H_ #include #include #include #include "absl/synchronization/mutex.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_factory.h" #include "monolith/native_training/runtime/hash_table/entry_accessor.h" namespace monolith { namespace hash_table { namespace proto2 = google::protobuf; constexpr int64_t kSecondsPerDay = 24 * 60 * 60; // The tests assume dim size == 1, sgd optimizer and zero initializer // Please see the config in flat_embedding_hash_table_test.cc as a reference class ReadWriteEmbeddingHashTableTest : public ::testing::TestWithParam< std::tuple>> {}; TEST_P(ReadWriteEmbeddingHashTableTest, SingleThread) { auto p = GetParam(); EmbeddingHashTableConfig config = std::get<0>(p); const auto& learning_rates = std::get<1>(p); std::unique_ptr table = NewEmbeddingHashTableFromConfig(config); std::vector num_buffer(1); absl::Span num = absl::MakeSpan(num_buffer); table->Lookup(5, num); EXPECT_THAT(num, ::testing::ElementsAre(0)); table->AssignAdd(-10, {2.5}, 100LL); table->Lookup(-10, num); EXPECT_THAT(num, ::testing::ElementsAre(2.5)); if (config.entry_config().entry_type() == EntryConfig::TRAINING) { table->Optimize(13, {1.0f}, learning_rates, 0, 0); table->Lookup(13, num); EXPECT_THAT(num, ::testing::ElementsAre(-0.01)); } std::vector ids{-10, 13, 100}; std::vector update1 = {1.0f}, update2 = {2.0f}, update3 = {3.0f}; std::vector> updates = {absl::MakeSpan(update1), absl::MakeSpan(update2), absl::MakeSpan(update3)}; table->Assign(absl::MakeSpan(ids), absl::MakeSpan(updates), 0); std::vector lookup_ids{5, -10, 13, 100}; std::vector emb1 = {0}, emb2 = {0}, emb3 = {0}, emb4 = {0}; std::vector> embeddings = { absl::MakeSpan(emb1), absl::MakeSpan(emb2), absl::MakeSpan(emb3), 
absl::MakeSpan(emb4)}; table->BatchLookup(absl::MakeSpan(lookup_ids), absl::MakeSpan(embeddings)); EXPECT_THAT(embeddings[0], ::testing::ElementsAre(0)); EXPECT_THAT(embeddings[1], ::testing::ElementsAre(1)); EXPECT_THAT(embeddings[2], ::testing::ElementsAre(2)); EXPECT_THAT(embeddings[3], ::testing::ElementsAre(3)); std::vector entries(4); table->BatchLookupEntry(absl::MakeSpan(lookup_ids), absl::MakeSpan(entries)); EXPECT_EQ(entries[0].SerializeAsString(), ""); EntryDump expect; proto2::TextFormat::ParseFromString(R"( num: 1 opt { dump { sgd { } } } last_update_ts_sec: 0 )", &expect); EXPECT_EQ(entries[1].SerializeAsString(), expect.SerializeAsString()); } TEST_P(ReadWriteEmbeddingHashTableTest, MultiThread) { auto p = GetParam(); EmbeddingHashTableConfig config = std::get<0>(p); const auto& learning_rates = std::get<1>(p); std::unique_ptr table = NewEmbeddingHashTableFromConfig(config); auto func = [&table](int id) { table->AssignAdd(id, {static_cast(id)}, 0); }; auto func2 = [&table](int id, const std::vector& learning_rates) { table->Optimize(id, {static_cast(-200 * id)}, learning_rates, 0, 0); }; const int kNumThread = 100; std::vector> threads; for (int i = 0; i < kNumThread; ++i) { threads.emplace_back(std::make_unique(func, i)); if (config.entry_config().entry_type() == EntryConfig::TRAINING) { threads.emplace_back( std::make_unique(func2, i, learning_rates)); } } for (auto& thread : threads) { thread->join(); } for (int i = 0; i < kNumThread; ++i) { std::vector num(1); table->Lookup(i, absl::MakeSpan(num)); if (config.entry_config().entry_type() == EntryConfig::TRAINING) { EXPECT_THAT(num, ::testing::ElementsAre(3 * i)); } else { EXPECT_THAT(num, ::testing::ElementsAre(i)); } } } TEST_P(ReadWriteEmbeddingHashTableTest, Clear) { auto p = GetParam(); EmbeddingHashTableConfig config = std::get<0>(p); std::unique_ptr table = std::make_unique( NewEmbeddingHashTableFromConfig(config)); table->Assign({1}, {{2.0f}}); std::vector emb(1); table->Lookup(1, 
absl::MakeSpan(emb)); EXPECT_THAT(emb, testing::ElementsAre(2.0f)); table->Clear(); table->Lookup(1, absl::MakeSpan(emb)); EXPECT_THAT(emb, testing::ElementsAre(0.0f)); } class SaveRestoreEmbeddingHashTestTest : public ::testing::TestWithParam< std::tuple>> {}; TEST_P(SaveRestoreEmbeddingHashTestTest, SaveRestore) { auto p = GetParam(); auto table = std::make_unique( NewEmbeddingHashTableFromConfig(std::get<0>(p))); table->AssignAdd(5, {2.5}, 0); table->AssignAdd(-3, {-0.5}, 0); std::vector dumps; auto write_fn = [&dumps](EntryDump dump) { dumps.push_back(dump); return true; }; const EmbeddingHashTableInterface::DumpShard kSingleShard{0, 1}; table->Save(kSingleShard, write_fn); std::unique_ptr table2 = NewEmbeddingHashTableFromConfig(std::get<0>(p)); int idx = 0; auto get_fn = [&idx, &dumps](EntryDump* dump, int64_t* max_update_ts) { if (idx == static_cast(dumps.size())) return false; *dump = dumps[idx++]; return true; }; table2->Restore(kSingleShard, get_fn); std::vector num(1); table2->Lookup(5, absl::MakeSpan(num)); EXPECT_THAT(num, ::testing::ElementsAre(2.5)); table2->Lookup(-3, absl::MakeSpan(num)); EXPECT_THAT(num, ::testing::ElementsAre(-0.5)); } TEST_P(SaveRestoreEmbeddingHashTestTest, SaveWithOffset) { auto p = GetParam(); std::unique_ptr table = NewEmbeddingHashTableFromConfig(std::get<0>(p)); table->AssignAdd(5, {2.5}, 0); table->AssignAdd(-3, {-0.5}, 0); std::vector dumps; auto write_fn = [&dumps](EntryDump dump) { dumps.push_back(dump); return true; }; EmbeddingHashTableInterface::DumpShard shard{0, 1}; shard.limit = 1; EmbeddingHashTableInterface::DumpIterator iter; table->Save(shard, write_fn, &iter); EXPECT_THAT(dumps.size(), 1); table->Save(shard, write_fn, &iter); std::vector ids; for (int i = 0; i < dumps.size(); ++i) { ids.push_back(dumps[i].id()); } EXPECT_THAT(ids, testing::UnorderedElementsAre(5, -3)); } TEST_P(SaveRestoreEmbeddingHashTestTest, SaveRestoreMultithreaded) { auto p = GetParam(); auto table = 
EmbeddingHashTableHelper(NewEmbeddingHashTableFromConfig(std::get<0>(p))); const int kNumThreads = 10; const int kPerThreadIds = 2600; for (int64_t i = 0; i < kNumThreads * kPerThreadIds; ++i) { table.AssignOne(i - kNumThreads * kPerThreadIds / 2, {float(i)}); } std::vector dumps; absl::Mutex mu; auto write_fn = [&dumps, &mu](EntryDump dump) { absl::MutexLock l(&mu); dumps.push_back(dump); return true; }; std::vector> save_threads; for (int i = 0; i < kNumThreads; ++i) { auto save_func = [kNumThreads, &table, &write_fn](int i) { EmbeddingHashTableInterface::DumpShard shard{i, kNumThreads}; table.Save(shard, write_fn); return true; }; save_threads.push_back(std::make_unique(save_func, i)); } for (int i = 0; i < kNumThreads; ++i) { save_threads[i]->join(); } ASSERT_THAT(dumps.size(), kNumThreads * kPerThreadIds); std::unique_ptr table2 = NewEmbeddingHashTableFromConfig(std::get<0>(p)); std::vector> restore_threads; for (int i = 0; i < kNumThreads; ++i) { auto restore_fn = [&table2, &dumps, kNumThreads, kPerThreadIds](int i) { int idx = i * kPerThreadIds; int end_idx = (i + 1) * kPerThreadIds; auto get_fn = [&dumps, &idx, &end_idx](EntryDump* dump, int64_t*) { if (idx == end_idx) return false; *dump = dumps[idx++]; return true; }; table2->Restore({i, kNumThreads}, get_fn); }; restore_threads.push_back(std::make_unique(restore_fn, i)); } for (int i = 0; i < kNumThreads; ++i) { restore_threads[i]->join(); } for (int64_t i = 0; i < kNumThreads * kPerThreadIds; ++i) { std::vector nums(1); table.Lookup(i - kNumThreads * kPerThreadIds / 2, absl::MakeSpan(nums)); ASSERT_THAT(nums[0], i); } } TEST_P(SaveRestoreEmbeddingHashTestTest, SaveWithStopEarly) { auto p = GetParam(); auto table = EmbeddingHashTableHelper(NewEmbeddingHashTableFromConfig(std::get<0>(p))); table.Assign({0, 1}, {{0.0}, {0.0}}); int called = 0; auto write_fn = [&called](EntryDump dump) { ++called; return false; }; EmbeddingHashTableInterface::DumpShard shard{0, 1}; table.Save(shard, write_fn); // Should 
stop early EXPECT_THAT(called, 1); } class EmbeddingHashTableEvictTest : public ::testing::TestWithParam< std::tuple>> {}; TEST_P(EmbeddingHashTableEvictTest, OneTimeEvict) { auto p = GetParam(); auto embedding_hash_table_config = std::get<0>(p); auto* slot_expire_time_config = embedding_hash_table_config.mutable_slot_expire_time_config(); slot_expire_time_config->set_default_expire_time(14); const std::vector slot_to_expire_time = {0, 5, 6}; for (int i = 0; i < slot_to_expire_time.size(); ++i) { auto* expire_time = slot_expire_time_config->add_slot_expire_times(); expire_time->set_slot(i); expire_time->set_expire_time(slot_to_expire_time[i]); } auto table = NewEmbeddingHashTableFromConfig(embedding_hash_table_config); const int64_t kFidUpdateTime = 1234; const int64_t kSlot1Fid = ((1LL << 48) | (123)); const int64_t kSlot2Fid = ((2LL << 48) | (234)); const int64_t kSlot3Fid = ((3LL << 48) | (456)); table->Assign({kSlot1Fid}, {{2.0f}}, kFidUpdateTime); table->Assign({kSlot2Fid}, {{5.0f}}, kFidUpdateTime); table->Assign({kSlot3Fid}, {{7.0f}}, kFidUpdateTime); std::vector emb(1); table->Lookup(kSlot1Fid, absl::MakeSpan(emb)); EXPECT_THAT(emb, testing::ElementsAre(2.0f)); table->Lookup(kSlot2Fid, absl::MakeSpan(emb)); EXPECT_THAT(emb, testing::ElementsAre(5.0f)); table->Lookup(kSlot3Fid, absl::MakeSpan(emb)); EXPECT_THAT(emb, testing::ElementsAre(7.0f)); const int64_t current_time = kFidUpdateTime + 5 * kSecondsPerDay + 60; table->Evict(current_time); table->Lookup(kSlot1Fid, absl::MakeSpan(emb)); // Slot 1 expire time is 5 days, the time gap is 5 days + 60 seconds, so // should be evited. EXPECT_THAT(emb, testing::ElementsAre(0.0f)); table->Lookup(kSlot2Fid, absl::MakeSpan(emb)); // Slot 1 expire time is 6 days, the time gap is 5 days + 60 seconds, so // should NOT be evited. 
EXPECT_THAT(emb, testing::ElementsAre(5.0f)); table->Lookup(kSlot3Fid, absl::MakeSpan(emb)); // Slot 3 expire time should use default 14 days, the time gap is 5 days // + 60 seconds, so should NOT be evited. EXPECT_THAT(emb, testing::ElementsAre(7.0f)); } // Testing evict would work during the hash table rehashing. TEST_P(EmbeddingHashTableEvictTest, EvictWhileRehash) { auto p = GetParam(); auto embedding_hash_table_config = std::get<0>(p); // We keep the initial capacity very small so that inserting will trigger // rehash. const int64_t kInitialHashTableCapacity = 1; const int64_t kNumInsertThreads = 20; const int64_t kIdPerThread = 50; const float kDefaultValue = 123.0f; embedding_hash_table_config.set_initial_capacity(kInitialHashTableCapacity); auto* slot_expire_time_config = embedding_hash_table_config.mutable_slot_expire_time_config(); slot_expire_time_config->set_default_expire_time(0); for (int i = 0; i < kNumInsertThreads; ++i) { auto* expire_time = slot_expire_time_config->add_slot_expire_times(); expire_time->set_slot(i); expire_time->set_expire_time(i); } auto table = NewEmbeddingHashTableFromConfig(embedding_hash_table_config); const int64_t kFidUpdateTime = 1234; std::vector> insert_threads; for (int i = 0; i < kNumInsertThreads; ++i) { auto insert_func = [&table, kDefaultValue](int i) { for (int id = 0; id < kIdPerThread; ++id) { const int64_t slot_id = i; const int64_t fid = ((slot_id << 48) | id); table->Assign({fid}, {{kDefaultValue}}, kFidUpdateTime + id); std::this_thread::sleep_for(std::chrono::milliseconds(100)); } }; insert_threads.push_back(std::make_unique(insert_func, i)); } const int64_t kCurrentTime = kFidUpdateTime + 5 * kSecondsPerDay + kSecondsPerDay / 2; // Run Evict every 2 seconds 3 times. 
std::unique_ptr evict_thread; auto evict_func = [&table, kCurrentTime]() { for (int i = 0; i < 3; ++i) { table->Evict(kCurrentTime); std::this_thread::sleep_for(std::chrono::seconds(2)); } }; evict_thread = std::make_unique(evict_func); for (int i = 0; i < kNumInsertThreads; ++i) { insert_threads[i]->join(); } evict_thread->join(); // Have a final evict to make sure the test is not flaky. table->Evict(kCurrentTime); for (int i = 0; i < kNumInsertThreads; ++i) { std::vector num(1); for (int id = 0; id < kIdPerThread; ++id) { const int64_t slot_id = i; const int64_t fid = ((slot_id << 48) | id); table->Lookup(fid, absl::MakeSpan(num)); if (i <= 5) { EXPECT_THAT(num, ::testing::ElementsAre(0)); } else { EXPECT_THAT(num, ::testing::ElementsAre(kDefaultValue)); } } } } class EmbeddingHashTableSkipZeroEmbeddingTest : public ::testing::TestWithParam< std::tuple>> {}; TEST_P(EmbeddingHashTableSkipZeroEmbeddingTest, AssignSkipZeroEmbedding) { auto p = GetParam(); auto table = EmbeddingHashTableHelper(NewEmbeddingHashTableFromConfig(std::get<0>(p))); const int kNumThreads = 10; const int kNumIds = 3000; auto AssignFn = [&]() { for (int i = 0; i < kNumIds; ++i) { table.Assign({i}, {{static_cast(i % 2)}}); } }; std::vector> threads; for (int i = 0; i < kNumThreads; ++i) { threads.emplace_back(std::make_unique(AssignFn)); } for (auto& thread : threads) { thread->join(); } for (int64_t i = 0; i < kNumIds; ++i) { std::vector nums(1); table.Lookup(i, absl::MakeSpan(nums)); ASSERT_THAT(nums[0], i % 2); if (i % 2 == 0) { EXPECT_FALSE(table.Contains(i)); } else { EXPECT_TRUE(table.Contains(i)); } } } TEST_P(EmbeddingHashTableSkipZeroEmbeddingTest, RestoreSkipZeroEmbedding) { EmbeddingHashTableConfig config; EXPECT_TRUE(proto2::TextFormat::ParseFromString(R"( entry_config { segments { dim_size: 1 init_config { zeros {} } opt_config { sgd {} } } } initial_capacity: 1 cuckoo {} )", &config)); auto table = EmbeddingHashTableHelper(NewEmbeddingHashTableFromConfig(config)); const int 
kNumThreads = 10; const int kNumIds = 3000; // Assign auto AssignFn = [&]() { for (int i = 0; i < kNumIds; ++i) { table.Assign({i}, {{static_cast(i % 2)}}); } }; std::vector> threads; for (int i = 0; i < kNumThreads; ++i) { threads.emplace_back(std::make_unique(AssignFn)); } for (auto& thread : threads) { thread->join(); } for (int64_t i = 0; i < kNumIds; ++i) { std::vector nums(1); table.Lookup(i, absl::MakeSpan(nums)); ASSERT_THAT(nums[0], i % 2); EXPECT_TRUE(table.Contains(i)); } // Save std::vector dumps; absl::Mutex mu; auto write_fn = [&dumps, &mu](EntryDump dump) { absl::MutexLock l(&mu); dumps.push_back(dump); return true; }; std::vector> save_threads; for (int i = 0; i < kNumThreads; ++i) { auto save_func = [kNumThreads, kNumIds, &table, &write_fn](int i) { EmbeddingHashTableInterface::DumpShard shard{i, kNumThreads}; table.Save(shard, write_fn); return true; }; save_threads.push_back(std::make_unique(save_func, i)); } for (int i = 0; i < kNumThreads; ++i) { save_threads[i]->join(); } ASSERT_THAT(dumps.size(), kNumIds); // Restore auto p = GetParam(); auto table2 = EmbeddingHashTableHelper(NewEmbeddingHashTableFromConfig(std::get<0>(p))); std::vector> restore_threads; for (int i = 0; i < kNumThreads; ++i) { auto restore_fn = [&table2, &dumps, kNumThreads, kNumIds](int i) { const int kPerThreadIds = kNumIds / kNumThreads; int idx = i * kPerThreadIds; int end_idx = (i + 1) * kPerThreadIds; auto get_fn = [&dumps, &idx, &end_idx](EntryDump* dump, int64_t*) { if (idx == end_idx) return false; *dump = dumps[idx++]; return true; }; table2.Restore({i, kNumThreads}, get_fn); }; restore_threads.push_back(std::make_unique(restore_fn, i)); } for (int i = 0; i < kNumThreads; ++i) { restore_threads[i]->join(); } for (int64_t i = 0; i < kNumIds; ++i) { std::vector nums(1); table2.Lookup(i, absl::MakeSpan(nums)); ASSERT_THAT(nums[0], i % 2); if (i % 2 == 0) { EXPECT_FALSE(table2.Contains(i)); } else { EXPECT_TRUE(table2.Contains(i)); } } } } // namespace hash_table } // 
namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_EMBEDDING_HASH_TABLE_TEST_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/entry_accessor.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/entry_accessor.h" #include #include #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/strings/str_format.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "monolith/native_training/runtime/hash_table/compressor/float_compressor.h" #include "monolith/native_training/runtime/hash_table/initializer/initializer_combination.h" #include "monolith/native_training/runtime/hash_table/initializer/initializer_config.pb.h" #include "monolith/native_training/runtime/hash_table/initializer/initializer_factory.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_combination.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_factory.h" #include "monolith/native_training/runtime/hash_table/retriever/fake_quant_retriever.h" #include "monolith/native_training/runtime/hash_table/retriever/hash_net_retriever.h" #include "monolith/native_training/runtime/hash_table/retriever/raw_retriever.h" 
#include "monolith/native_training/runtime/hash_table/retriever/retriever_combination.h"

namespace monolith {
namespace hash_table {
namespace {

namespace proto2 = google::protobuf;

// Entry accessor for SERVING entries. The whole entry is the compressor's
// encoded embedding; there is no optimizer state, so SliceSize, Optimize and
// Save are unsupported and throw std::runtime_error.
class ServingEntryAccessor final : public EntryAccessorInterface {
 public:
  explicit ServingEntryAccessor(
      std::unique_ptr compressor)
      : compressor_(std::move(compressor)),
        size_bytes_(compressor_->SizeBytes()),
        uncompressed_size_bytes_(compressor_->UncompressedSizeBytes()) {}

  int64_t SizeBytes() const override { return size_bytes_; }

  int64_t UncompressedSizeBytes() const override {
    return uncompressed_size_bytes_;
  }

  std::string DebugString() const override {
    return absl::StrFormat(
        R"({"compressor": "%s", "size_bytes": %ld, "uncompressed_size_bytes": %ld})",
        compressor_->DebugString(), SizeBytes(), UncompressedSizeBytes());
  }

  int DimSize() const override { return compressor_->DimSize(); }

  int SliceSize() const override {
    throw std::runtime_error("ServingEntryAccessor doesn't support SliceSize");
  }

  void Init(void* ctx) const override {
    // No need to initialize serving entry
  }

  // Decodes the entry at |ctx| into |num|.
  void Fill(const void* ctx, absl::Span num) const override {
    compressor_->Decode(ctx, num);
  }

  // Encodes |num| into the entry at |ctx|.
  void Assign(absl::Span num, void* ctx) const override {
    compressor_->Encode(num, ctx);
  }

  // Decode, add element-wise, re-encode.
  void AssignAdd(absl::Span num, void* ctx) const override {
    std::vector embedding(num.size());
    compressor_->Decode(ctx, absl::MakeSpan(embedding));
    for (size_t i = 0; i < num.size(); ++i) {
      embedding[i] += num[i];
    }
    compressor_->Encode(embedding, ctx);
  }

  void Optimize(void* ctx, absl::Span grad,
                absl::Span learning_rates,
                const int64_t global_step) const override {
    throw std::runtime_error("ServingEntryAccessor doesn't support Optimize");
  }

  EntryDump Save(const void* ctx, uint32_t timestamp_sec) const override {
    throw std::runtime_error("ServingEntryAccessor doesn't support Save");
  }

  // Encodes the values of |dump| into the entry at |ctx|.
  // NOTE: |timestamp_sec| is deliberately left untouched, so callers must not
  // rely on it being written by this accessor.
  void Restore(void* ctx, uint32_t* timestamp_sec,
               EntryDump dump) const override {
    (void)timestamp_sec;
    std::vector num(dump.num_size());
    absl::c_copy(dump.num(), num.begin());
    compressor_->Encode(num, ctx);
  }

 private:
  std::unique_ptr compressor_;
  int size_bytes_;
  int uncompressed_size_bytes_;
};

// The layout of ctx is:
// float * dim_size_ | Info | optimizer_data
class EntryAccessor final : public EntryAccessorInterface {
 public:
  // Throws std::invalid_argument if the three components do not agree on the
  // dim size.
  EntryAccessor(std::unique_ptr initializer,
                std::unique_ptr optimizer,
                std::unique_ptr retriever)
      : initializer_(std::move(initializer)),
        optimizer_(std::move(optimizer)),
        retriever_(std::move(retriever)),
        optimizer_bytes_(optimizer_->SizeBytes()),
        uncompressed_optimizer_bytes_(optimizer_->UncompressedSizeBytes()),
        dim_size_(initializer_->DimSize()),
        slice_size_(optimizer_->SliceSize()),
        num_bytes_(retriever_->SizeBytes()) {
    if (initializer_->DimSize() != optimizer_->DimSize() ||
        initializer_->DimSize() != retriever_->DimSize()) {
      throw std::invalid_argument(
          absl::StrFormat("Initializer/Optimizer/Retriever dim size should "
                          "match. But got %d vs %d vs %d",
                          initializer_->DimSize(), optimizer_->DimSize(),
                          retriever_->DimSize()));
    }
  }

  EntryAccessor(EntryAccessor&&) = default;
  EntryAccessor& operator=(EntryAccessor&&) = default;

  int64_t SizeBytes() const override { return num_bytes_ + optimizer_bytes_; }

  int64_t UncompressedSizeBytes() const override {
    return num_bytes_ + uncompressed_optimizer_bytes_;
  }

  std::string DebugString() const override {
    return absl::StrFormat(
        R"({"initializer": "%s", "optimizer": "%s", "retriever": "%s"})",
        initializer_->DebugString(), optimizer_->DebugString(),
        retriever_->DebugString());
  }

  int DimSize() const override { return dim_size_; }

  int SliceSize() const override { return slice_size_; }

  // Returns the first dim_size_ floats of the entry as a mutable span.
  absl::Span GetMutableNum(void* ctx) const {
    float* ctx_float = static_cast(ctx);
    return absl::MakeSpan(ctx_float, ctx_float + dim_size_);
  }

  // Initializes both the embedding and the optimizer state of the entry.
  void Init(void* ctx) const override {
    auto num_span = GetMutableNum(ctx);
    initializer_->Initialize(num_span);
    optimizer_->Init(GetMutableOptimizerCtx(ctx));
  }

  void Fill(const void* ctx, absl::Span num) const override {
    retriever_->Retrieve(ctx, num);
  }

  // Copies |num| over the start of the entry. The caller is responsible for
  // passing at most dim_size_ values; the size is not checked here.
  void Assign(absl::Span num, void* ctx) const override {
    auto ctx_float = static_cast(ctx);
    auto embedding = absl::MakeSpan(ctx_float, ctx_float + dim_size_);
    std::memcpy(embedding.data(), num.data(), sizeof(float) * num.size());
  }

  void AssignAdd(absl::Span num, void* ctx) const override {
    auto ctx_float = static_cast(ctx);
    auto embedding = absl::MakeSpan(ctx_float, ctx_float + dim_size_);
    for (size_t i = 0; i < num.size(); ++i) {
      embedding[i] += num[i];
    }
  }

  void Optimize(void* ctx, absl::Span grad,
                absl::Span learning_rates,
                const int64_t global_step) const override {
    // Backward takes a mutable span, so constness is cast away from the
    // caller's gradient buffer before the optimizer consumes |grad|.
    auto* mutable_grad = const_cast(grad.data());
    retriever_->Backward(GetNum(ctx), absl::MakeSpan(mutable_grad, grad.size()),
                         global_step);
    optimizer_->Optimize(GetMutableOptimizerCtx(ctx), GetMutableNum(ctx), grad,
                         learning_rates, global_step);
  }

  EntryDump Save(const void* ctx, uint32_t timestamp) const override;

  void Restore(void* ctx, uint32_t* timestamp_sec,
               EntryDump dump) const override;

 private:
  // Read-only view of the first dim_size_ floats of the entry.
  absl::Span GetNum(const void* ctx) const {
    const float* ctx_float = static_cast(ctx);
    return absl::MakeConstSpan(ctx_float, ctx_float + dim_size_);
  }

  // The optimizer state starts num_bytes_ into the entry.
  void* GetMutableOptimizerCtx(void* ctx) const {
    return AddOffset(ctx, num_bytes_);
  }

  const void* GetOptimizerCtx(const void* ctx) const {
    return AddOffset(ctx, num_bytes_);
  }

  std::unique_ptr initializer_;
  std::unique_ptr optimizer_;
  std::unique_ptr retriever_;
  const int64_t optimizer_bytes_ = 0;
  const int64_t uncompressed_optimizer_bytes_ = 0;
  const int dim_size_ = 0;
  const int slice_size_ = 0;
  const int64_t num_bytes_ = 0;
};

// Serializes the embedding, the optimizer state and |timestamp_sec|.
EntryDump EntryAccessor::Save(const void* ctx, uint32_t timestamp_sec) const {
  EntryDump dump;
  absl::c_copy(GetNum(ctx),
               proto2::RepeatedFieldBackInserter(dump.mutable_num()));
  *dump.mutable_opt() = optimizer_->Save(GetOptimizerCtx(ctx));
  dump.set_last_update_ts_sec(timestamp_sec);
  return dump;
}

// Restores the embedding, |*timestamp_sec| and the optimizer state from
// |dump|.
void EntryAccessor::Restore(void* ctx, uint32_t* timestamp_sec,
                            EntryDump dump) const {
  auto num = GetMutableNum(ctx);
  // |num| covers exactly dim_size_ floats. A dump that carries more values
  // (e.g. written with a larger dim size) must not write past the embedding,
  // so only the values that fit are copied.
  const int num_to_copy =
      dump.num_size() < dim_size_ ? dump.num_size() : dim_size_;
  for (int i = 0; i < num_to_copy; ++i) {
    num[i] = dump.num(i);
  }
  *timestamp_sec = dump.last_update_ts_sec();
  optimizer_->Restore(GetMutableOptimizerCtx(ctx),
                      std::move(*dump.mutable_opt()));
}

// Write dim_size into sub field of T (T can be OptimizerConfig,
// InitializerConfig or FloatCompressorConfig).
//
// Throws std::invalid_argument when |conf| has no oneof named "type", when
// that oneof is unset or not a message, or when the selected message has no
// "dim_size" field. The reflection lookups return nullptr in those cases and
// must not be passed on unchecked.
template void WriteDimSize(T* conf, int dim_size) {
  const proto2::Descriptor* descriptor = conf->GetDescriptor();
  const proto2::Reflection* reflection = conf->GetReflection();
  const proto2::OneofDescriptor* type = descriptor->FindOneofByName("type");
  if (type == nullptr) {
    throw std::invalid_argument(absl::StrFormat(
        "%s has no oneof named type. Got %s", descriptor->name(),
        conf->ShortDebugString()));
  }
  const proto2::FieldDescriptor* type_field =
      reflection->GetOneofFieldDescriptor(*conf, type);
  if (type_field == nullptr ||
      type_field->type() != proto2::FieldDescriptor::TYPE_MESSAGE) {
    throw std::invalid_argument(absl::StrFormat("%s must be set type. Got %s",
                                                descriptor->name(),
                                                conf->ShortDebugString()));
  }
  proto2::Message* type_msg = reflection->MutableMessage(conf, type_field);
  const proto2::FieldDescriptor* dim_size_field =
      type_msg->GetDescriptor()->FindFieldByName("dim_size");
  if (dim_size_field == nullptr) {
    throw std::invalid_argument(absl::StrFormat(
        "%s has no dim_size field. Got %s", type_msg->GetDescriptor()->name(),
        conf->ShortDebugString()));
  }
  type_msg->GetReflection()->SetInt32(type_msg, dim_size_field, dim_size);
}

// The per-entry components built from the segments of an EntryConfig. A
// member stays null when no segment provides the corresponding config.
struct Objects {
  std::unique_ptr init;
  std::unique_ptr opt;
  std::unique_ptr comp;
  std::unique_ptr retriever;
};

// Stores |t2| into |*t1| if it is still null, otherwise replaces |*t1| with
// combine_fn(old *t1, t2).
template void AssignOrCombine(T* t1, T t2, F combine_fn) {
  if (*t1 == nullptr) {
    *t1 = std::move(t2);
    return;
  }
  *t1 = combine_fn(std::move(*t1), std::move(t2));
}

// Builds the components for every segment and combines them in segment
// order. The segment configs are mutated: dim_size is written into each
// optimizer / initializer / compressor config before it is instantiated.
Objects GenerateObjFromSegments(
    proto2::RepeatedPtrField* segments) {
  Objects obj;
  for (EntryConfig::Segment& segment : *segments) {
    // Every segment contributes exactly one retriever, chosen by its
    // compression config.
    if (segment.has_comp_config() && segment.comp_config().has_fixed_r8()) {
      auto retriever = NewFakeQuantRetriever(
          segment.dim_size(),
          FakeQuantizer(segment.comp_config().fixed_r8().r()));
      AssignOrCombine(&obj.retriever, std::move(retriever), CombineRetrievers);
    } else if (segment.has_comp_config() &&
               segment.comp_config().has_one_bit()) {
      auto hash_net_quantizer =
          std::make_unique(segment.comp_config().one_bit());
      auto retriever = NewHashNetRetriever(segment.dim_size(),
                                           std::move(hash_net_quantizer));
      AssignOrCombine(&obj.retriever, std::move(retriever), CombineRetrievers);
    } else {
      auto retriever = NewRawRetriever(segment.dim_size());
      AssignOrCombine(&obj.retriever, std::move(retriever), CombineRetrievers);
    }

    if (segment.has_opt_config()) {
      WriteDimSize(segment.mutable_opt_config(), segment.dim_size());
      auto new_opt = NewOptimizerFromConfig(segment.opt_config());
      AssignOrCombine(&obj.opt, std::move(new_opt), CombineOptimizers);
    }
    if (segment.has_init_config()) {
      WriteDimSize(segment.mutable_init_config(), segment.dim_size());
      auto new_init = NewInitializerFromConfig(segment.init_config());
      AssignOrCombine(&obj.init, std::move(new_init), CombineInitializers);
    }
    if (segment.has_comp_config()) {
      WriteDimSize(segment.mutable_comp_config(), segment.dim_size());
      auto new_comp = NewFloatCompressor(segment.comp_config());
      AssignOrCombine(&obj.comp, std::move(new_comp), CombineFloatCompressor);
    }
  }
  return obj;
}

}  // namespace

// Creates the accessor matching config.entry_type(). Throws
// std::invalid_argument when a TRAINING config lacks an init or opt config,
// when a SERVING config lacks a comp config, or when the type is unknown.
std::unique_ptr NewEntryAccessor(EntryConfig config) {
  Objects obj = GenerateObjFromSegments(config.mutable_segments());
  switch (config.entry_type()) {
    case EntryConfig::TRAINING:
      if (obj.init == nullptr || obj.opt == nullptr) {
        throw std::invalid_argument(absl::StrFormat(
            "init or opt config is missing from entry config : %s",
            config.ShortDebugString()));
      }
      return std::make_unique(
          std::move(obj.init), std::move(obj.opt), std::move(obj.retriever));
    case EntryConfig::SERVING:
      if (obj.comp == nullptr) {
        throw std::invalid_argument(
            absl::StrFormat("comp config is missing form entry config: %s",
                            config.ShortDebugString()));
      }
      return std::make_unique(std::move(obj.comp));
    default:
      throw std::invalid_argument(
          absl::StrFormat("Unknown entry type: %s", config.ShortDebugString()));
  }
}

}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/entry_accessor.h
================================================
// Copyright 2022
// ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_ENTRY_ACCESSOR
#define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_ENTRY_ACCESSOR

#include
#include
#include
#include

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h"
#include "monolith/native_training/runtime/hash_table/initializer/initializer_interface.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h"
#include "monolith/native_training/runtime/hash_table/utils.h"

namespace monolith {
namespace hash_table {

// Interface for reading and writing a single hash table entry.
//
// The accessor itself holds no entry data: every operation receives the
// entry as an untyped |ctx| pointer. All methods are const. Callers allocate
// SizeBytes() bytes per entry and pass that buffer as |ctx|.
class EntryAccessorInterface {
 public:
  virtual ~EntryAccessorInterface() = default;

  // Size bytes need to be allocated in this entry.
  virtual int64_t SizeBytes() const = 0;

  // Size bytes need to be allocated in this entry if not compressed.
  virtual int64_t UncompressedSizeBytes() const = 0;

  // Human-readable description of the accessor.
  virtual std::string DebugString() const = 0;

  // The dim that this entry accessor can support
  virtual int DimSize() const = 0;

  // The number of slices in this entry.
  virtual int SliceSize() const = 0;

  // Initialize the given entry.
  virtual void Init(void* ctx) const = 0;

  // Fills the num based on entry.
  virtual void Fill(const void* ctx, absl::Span num) const = 0;

  // Assign the entry using num
  virtual void Assign(absl::Span num, void* ctx) const = 0;

  // AssignAdd the entry using num
  virtual void AssignAdd(absl::Span num, void* ctx) const = 0;

  // Optimizes the entry with |grad|.
  // NOTE: the default for |global_step| is declared only on this interface,
  // so it applies only to calls made through an EntryAccessorInterface
  // reference or pointer.
  virtual void Optimize(void* ctx, absl::Span grad,
                        absl::Span learning_rates,
                        const int64_t global_step = 0) const = 0;

  // Converts an entry to EntryDump.
  virtual EntryDump Save(const void* ctx, uint32_t timestamp_sec) const = 0;

  // Restores the entry from |dump|.
  // |timestamp_sec| is an output pointer. NOTE(review): not every
  // implementation writes it, so initialize the variable before the call.
  virtual void Restore(void* ctx, uint32_t* timestamp_sec,
                       EntryDump dump) const = 0;
};

// Creates the accessor described by |config|.
std::unique_ptr NewEntryAccessor(EntryConfig config);

}  // namespace hash_table
}  // namespace monolith

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_ENTRY_ACCESSOR

================================================
FILE: monolith/native_training/runtime/hash_table/entry_accessor_decorator.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_ENTRY_ACCESSOR_DECORATOR_H_
#define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_ENTRY_ACCESSOR_DECORATOR_H_

#include "monolith/native_training/runtime/hash_table/entry_accessor.h"

namespace monolith {
namespace hash_table {

// DEPRECATED: Prefer using retriever
//
// The base class of decorator.
By default, it delegates all requests to the // base entry accessor. class EntryAccessorDecorator : public EntryAccessorInterface { public: explicit EntryAccessorDecorator( std::unique_ptr entry_accessor) : entry_accessor_(std::move(entry_accessor)) {} int64_t SizeBytes() const override { return entry_accessor_->SizeBytes(); } int64_t UncompressedSizeBytes() const override { return entry_accessor_->UncompressedSizeBytes(); } std::string DebugString() const override { return entry_accessor_->DebugString(); } int DimSize() const override { return entry_accessor_->DimSize(); } int SliceSize() const override { return entry_accessor_->SliceSize(); } void Init(void* ctx) const override { entry_accessor_->Init(ctx); } void Fill(const void* ctx, absl::Span num) const override { entry_accessor_->Fill(ctx, num); } void Assign(absl::Span num, void* ctx) const override { entry_accessor_->Assign(num, ctx); } void AssignAdd(absl::Span num, void* ctx) const override { entry_accessor_->AssignAdd(num, ctx); } void Optimize(void* ctx, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { entry_accessor_->Optimize(ctx, grad, learning_rates, global_step); } EntryDump Save(const void* ctx, uint32_t timestamp_sec) const override { return entry_accessor_->Save(ctx, timestamp_sec); } void Restore(void* ctx, uint32_t* timestamp_sec, EntryDump dump) const override { entry_accessor_->Restore(ctx, timestamp_sec, dump); } protected: std::unique_ptr entry_accessor_; }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_ENTRY_ACCESSOR_DECORATOR_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/entry_accessor_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/entry_accessor.h" #include #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h" namespace monolith { namespace hash_table { namespace { namespace proto2 = google::protobuf; using ::testing::ElementsAre; using ::testing::FloatEq; using ::testing::FloatNear; TEST(EntryAccessorTest, FromConfig) { EntryConfig config; ASSERT_TRUE(proto2::TextFormat::ParseFromString(R"( segments { dim_size: 1 init_config { zeros {} } opt_config { sgd {} } } segments { dim_size: 2 init_config { zeros {} } opt_config { sgd {} } } )", &config)); auto accessor = NewEntryAccessor(config); auto entry = std::make_unique(accessor->SizeBytes()); accessor->Init(entry.get()); accessor->Optimize(entry.get(), {1.0f, 2.0f, 3.0f}, {1.0f, 2.0f}, 0); std::vector num(3); accessor->Fill(entry.get(), absl::MakeSpan(num)); EXPECT_THAT(absl::MakeSpan(num), ElementsAre(-1.0f, -4.0f, -6.0f)); } TEST(EntryAccessorTest, SaveRestore) { EntryConfig config; ASSERT_TRUE(proto2::TextFormat::ParseFromString(R"( segments { dim_size: 1 init_config { zeros {} } opt_config { adagrad { initial_accumulator_value: 0.1 } } } )", &config)); auto accessor = NewEntryAccessor(config); auto entry1 = std::make_unique(accessor->SizeBytes()); accessor->Init(entry1.get()); accessor->Optimize(entry1.get(), {1.0f}, {1.0f}); std::vector num(1); accessor->Fill(entry1.get(), absl::MakeSpan(num)); ASSERT_THAT(absl::MakeSpan(num), 
ElementsAre(FloatEq(-0.95346254f))); EntryDump dump = accessor->Save(entry1.get(), 100); auto entry2 = std::make_unique(accessor->SizeBytes()); uint32_t timestamp_sec; accessor->Restore(entry2.get(), ×tamp_sec, dump); EXPECT_EQ(timestamp_sec, 100); accessor->Optimize(entry2.get(), {1.0f}, {1.0f}, 0); accessor->Fill(entry2.get(), absl::MakeSpan(num)); ASSERT_THAT(absl::MakeSpan(num), ElementsAre(FloatEq(-1.643528f))); } TEST(EntryAccessorTest, Update) { std::unordered_map configs = {{"fp32", R"( segments { dim_size: 3 init_config { zeros {} } opt_config { sgd {} } } )"}, {"fp16", R"( segments { dim_size: 3 init_config { zeros {} } opt_config { sgd {} } } )"}}; for (const auto& kv : configs) { EntryConfig config; ASSERT_TRUE(proto2::TextFormat::ParseFromString(kv.second, &config)); auto accessor = NewEntryAccessor(config); auto entry = std::make_unique(accessor->SizeBytes()); std::vector num = {0.1, 0.2, 0.3}; accessor->Assign(absl::MakeSpan(num), entry.get()); std::vector embedding(3); accessor->Fill(entry.get(), absl::MakeSpan(embedding)); if (kv.first == "fp32") { EXPECT_THAT(embedding, ElementsAre(0.1, 0.2, 0.3)); } if (kv.first == "fp16") { float eps = 0.0001; EXPECT_THAT(embedding, ElementsAre(FloatNear(0.1, eps), FloatNear(0.2, eps), FloatNear(0.3, eps))); } } } TEST(ServingEntryAccessorTest, Basic) { EntryConfig config; ASSERT_TRUE(proto2::TextFormat::ParseFromString(R"( segments { dim_size: 1 comp_config { fp32 {} } } entry_type: SERVING )", &config)); auto accessor = NewEntryAccessor(config); auto entry = std::make_unique(accessor->SizeBytes()); EntryDump dump; dump.add_num(1.0); dump.set_last_update_ts_sec(100); uint32_t timestamp_sec; accessor->Restore(entry.get(), ×tamp_sec, dump); EXPECT_THAT(timestamp_sec, timestamp_sec); std::vector out(1); accessor->Fill(entry.get(), absl::MakeSpan(out)); EXPECT_THAT(out, ElementsAre(1.0)); } TEST(ServingEntryAccessorTest, Update) { std::unordered_map configs = {{"fp32", R"( segments { dim_size: 3 comp_config { fp32 
{} } } entry_type: SERVING )"}, {"fp16", R"( segments { dim_size: 3 comp_config { fp16 {} } } entry_type: SERVING )"}}; for (const auto& kv : configs) { EntryConfig config; ASSERT_TRUE(proto2::TextFormat::ParseFromString(kv.second, &config)); auto accessor = NewEntryAccessor(config); auto entry = std::make_unique(accessor->SizeBytes()); std::vector num = {0.1, 0.2, 0.3}; accessor->Assign(absl::MakeSpan(num), entry.get()); std::vector embedding(3); accessor->Fill(entry.get(), absl::MakeSpan(embedding)); if (kv.first == "fp32") { EXPECT_THAT(embedding, ElementsAre(0.1, 0.2, 0.3)); } if (kv.first == "fp16") { float eps = 0.0001; EXPECT_THAT(embedding, ElementsAre(FloatNear(0.1, eps), FloatNear(0.2, eps), FloatNear(0.3, eps))); } } } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/entry_defs.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/allocator/block_allocator.h" namespace monolith { namespace hash_table { // A wrapper for raw pointer. This helps utilize try_emplace in map. // TODO(leqi.zou): Essentailly we want to deprecate this. Will remove once // we find this is not useful. 
/* Entry whose storage is a single slot obtained from a TSEmbeddingBlockAllocator; keeps the slot address and a last-update timestamp. */ class PackedEntry { public: explicit PackedEntry(allocator::TSEmbeddingBlockAllocator* alloc) : p_(alloc->AllocateOne()), timestamp_(0) {} allocator::EntryAddress get_entry_addr() const { return p_; } uint32_t GetTimestamp() const { return timestamp_; } void SetTimestamp(uint32_t timestamp_sec) { timestamp_ = timestamp_sec; } private: allocator::EntryAddress p_; // Unix timestamp in seconds, UINT32_MAX means 2106-02-07 14:28:15+08:00 uint32_t timestamp_; }; /* Entry that owns a heap buffer of entry_size bytes plus a last-update timestamp. */ class RawEntry { public: /* timestamp_ starts at 0, matching PackedEntry, so GetTimestamp() is well-defined before the first SetTimestamp(). */ RawEntry(size_t entry_size) : p_(new char[entry_size]), timestamp_(0) {} void* get() const { return p_.get(); } uint32_t GetTimestamp() const { return timestamp_; } void SetTimestamp(uint32_t timestamp_sec) { timestamp_ = timestamp_sec; } private: std::unique_ptr p_; // Unix timestamp in seconds, UINT32_MAX means 2106-02-07 14:28:15+08:00 uint32_t timestamp_; }; /* Fixed-size entry stored inline: the last 4 bytes of the buffer hold the timestamp and the preceding capacity() bytes hold the payload. */ template class InlineEntry { public: static_assert(length % 8 == 0 && length > 0, "InlineEntry's should be divisible by 8."); InlineEntry() { static_assert(sizeof(InlineEntry) == length, "InlineEntry's implementation is wrong"); } static int capacity() { return length - 4; } const void* get() const { return buffer_; } void* get() { return buffer_; } uint32_t GetTimestamp() const { return *reinterpret_cast(buffer_ + length - 4); } void SetTimestamp(uint32_t timestamp_sec) { *reinterpret_cast(buffer_ + length - 4) = timestamp_sec; } private: char buffer_[length]; }; } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/entry_defs_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/entry_defs.h" #include "gtest/gtest.h" #include "gmock/gmock.h" namespace monolith { namespace hash_table { TEST(InlineEntryTest, Basic) { InlineEntry<8> entry; *reinterpret_cast(entry.get()) = 1.0; EXPECT_THAT(entry.capacity(), 4); entry.SetTimestamp(1234); EXPECT_THAT(entry.GetTimestamp(), 1234); EXPECT_THAT(*reinterpret_cast(entry.get()), 1.0); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/hash_table_benchmark.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Created by david on 2020-11-27. 
// #include #include "absl/container/flat_hash_map.h" #include "absl/random/random.h" #include "absl/strings/str_format.h" #include "benchmark/benchmark.h" #include "glog/logging.h" #include "google/protobuf/text_format.h" #include "monolith/native_training/runtime/concurrency/thread_pool.h" #include "monolith/native_training/runtime/hash_table/cuckoohash/cuckoo_embedding_hash_table.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_factory.h" namespace monolith { namespace hash_table { namespace { namespace proto2 = ::google::protobuf; EmbeddingHashTableConfig SetupHashTableConfig(size_t dim) { EmbeddingHashTableConfig config; CHECK(proto2::TextFormat::ParseFromString(absl::StrFormat(R"( entry_config { segments { dim_size: %lu init_config { zeros {} } opt_config { sgd {} } } } cuckoo {} )", dim), &config)); return config; } void BM_Insert(benchmark::State& state) { // NOLINT auto entry_num = state.range(0); auto thread_num = state.range(1); auto dim = state.range(2); monolith::concurrency::ThreadPool thread_pool(thread_num); auto config = SetupHashTableConfig(dim); auto table = NewEmbeddingHashTableFromConfig(config); std::vector vector(dim); absl::BitGen bit_gen; for (auto& val : vector) { val = absl::Uniform(bit_gen, -1.f, 1.f); } for (auto _ : state) { std::atomic_int join(thread_num); auto AssignAdd = [&]() { for (size_t i = 0; i < entry_num; ++i) { table->AssignAdd(i, absl::MakeSpan(vector), 0); } --join; }; for (int64_t i = 0; i < thread_num; ++i) { thread_pool.Schedule(AssignAdd); } while (join) { } } } void BM_Find(benchmark::State& state) { // NOLINT auto entry_num = state.range(0); auto thread_num = state.range(1); auto dim = state.range(2); monolith::concurrency::ThreadPool thread_pool(thread_num); auto config = SetupHashTableConfig(dim); auto table = NewEmbeddingHashTableFromConfig(config); std::vector vector(dim); absl::BitGen bit_gen; for (auto& val : vector) { val = absl::Uniform(bit_gen, -1.f, 1.f); } for (size_t i = 0; i 
< entry_num; ++i) { table->AssignAdd(i, absl::MakeSpan(vector), 0); } std::vector ids_to_find(entry_num / thread_num); for (size_t i = 0; i < ids_to_find.size(); ++i) { ids_to_find[i] = absl::Uniform(bit_gen, 0, 2 * entry_num); } for (auto _ : state) { std::atomic_int join(thread_num); auto Lookup = [&]() { std::vector vector(dim); for (int64_t id : ids_to_find) { table->Lookup(id, absl::MakeSpan(vector)); } --join; }; for (int64_t i = 0; i < thread_num; ++i) { thread_pool.Schedule(Lookup); } while (join) { } } } /* Run on (12 X 2592 MHz CPU s) CPU Caches: L1 Data 32 KiB (x12) L1 Instruction 32 KiB (x12) L2 Unified 256 KiB (x12) L3 Unified 12288 KiB (x1) Load Average: 1.74, 1.35, 0.68 ------------------------------------------------------------------ Benchmark Time CPU Iterations ------------------------------------------------------------------ BM_Insert/10000000/1/32 7834986861 ns 7833447327 ns 1 BM_Insert/1000000/10/32 5021482710 ns 5019782037 ns 1 BM_Find/10000000/1/32 6015355836 ns 6015255039 ns 1 BM_Find/10000000/10/32 4677797500 ns 4677690867 ns 1 */ // single thread, insert 10^7 entries BENCHMARK(BM_Insert)->Args({1000 * 10000, 1, 32}); // 10 threads, insert 10^7 = 10^6 * 10 entries BENCHMARK(BM_Insert)->Args({100 * 10000, 10, 32}); // single thread, find 10^7 times from 10^7 entries BENCHMARK(BM_Find)->Args({1000 * 10000, 1, 32}); // 10 threads, find 10^7 = 10^6 * 10 times from 10^7 entries BENCHMARK(BM_Find)->Args({1000 * 10000, 10, 32}); } // namespace } // namespace hash_table } // namespace monolith BENCHMARK_MAIN(); ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") load("@rules_proto//proto:defs.bzl", "proto_library") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") package(default_visibility = 
["//monolith/native_training/runtime/hash_table:__subpackages__"]) cc_library( name = "initializer_interface", hdrs = ["initializer_interface.h"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) proto_library( name = "initializer_config_proto", srcs = ["initializer_config.proto"], ) cc_proto_library( name = "initializer_config_cc_proto", deps = [":initializer_config_proto"], ) py_proto_library( name = "initializer_config_py_proto", srcs = ["initializer_config.proto"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], ) cc_library( name = "initializer_factory", srcs = ["initializer_factory.cc"], hdrs = ["initializer_factory.h"], deps = [ "random_uniform_initializer", ":constants_initializer", ":initializer_config_cc_proto", ":initializer_interface", ], ) cc_library( name = "constants_initializer", srcs = ["constants_initializer.cc"], hdrs = ["constants_initializer.h"], deps = [ ":initializer_config_cc_proto", ":initializer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_library( name = "random_uniform_initializer", srcs = ["random_uniform_initializer.cc"], hdrs = ["random_uniform_initializer.h"], deps = [ ":initializer_config_cc_proto", ":initializer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "random_uniform_initializer_test", srcs = ["random_uniform_initializer_test.cc"], deps = [ ":initializer_config_cc_proto", ":initializer_interface", ":random_uniform_initializer", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "initializer_combination", srcs = ["initializer_combination.cc"], hdrs = ["initializer_combination.h"], deps = [ ":initializer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "initializer_combination_test", srcs = ["initializer_combination_test.cc"], deps = [ ":constants_initializer", ":initializer_combination", ":initializer_interface", "@com_google_absl//absl/types:span", 
"@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/constants_initializer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "monolith/native_training/runtime/hash_table/initializer/constants_initializer.h" #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { namespace { class ConstantsInitializer : public InitializerInterface { public: explicit ConstantsInitializer(int dim_size, float constant) : dim_size_(dim_size), constant_(constant) {} int DimSize() const override { return dim_size_; } void Initialize(absl::Span nums) const override { for (int i = 0; i < dim_size_; ++i) { nums[i] = constant_; } } std::string DebugString() const override { return absl::StrFormat("Constants(D=%d, C=%f)", dim_size_, constant_); } private: int dim_size_; float constant_; }; } // namespace std::unique_ptr NewZerosInitializer( ZerosInitializerConfig config) { return std::make_unique(config.dim_size(), 0); } std::unique_ptr NewZerosInitializer(int dim_size) { return std::make_unique(dim_size, 0); } std::unique_ptr NewOnesInitializer( OnesInitializerConfig config) { return std::make_unique(config.dim_size(), 1); } std::unique_ptr NewConstantsInitializer( ConstantsInitializerConfig config) { return std::make_unique(config.dim_size(), 
config.constant()); } std::unique_ptr NewConstantsInitializer(int dim_size, float constant) { return std::make_unique(dim_size, constant); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/constants_initializer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_CONSTANTS_INITIALIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_CONSTANTS_INITIALIZER #include #include "monolith/native_training/runtime/hash_table/initializer/initializer_config.pb.h" #include "monolith/native_training/runtime/hash_table/initializer/initializer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewZerosInitializer( ZerosInitializerConfig config); std::unique_ptr NewZerosInitializer(int dim_size); std::unique_ptr NewOnesInitializer( OnesInitializerConfig config); std::unique_ptr NewConstantsInitializer( ConstantsInitializerConfig config); std::unique_ptr NewConstantsInitializer(int dim_size, float constant); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_CONSTANTS_INITIALIZER ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/initializer_combination.cc 
================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/initializer/initializer_combination.h" #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { namespace { class CombinedInitializer : public InitializerInterface { public: CombinedInitializer(std::unique_ptr init1, std::unique_ptr init2) : init1_(std::move(init1)), init2_(std::move(init2)) {} int DimSize() const override { return init1_->DimSize() + init2_->DimSize(); } void Initialize(absl::Span nums) const override { init1_->Initialize(nums); init2_->Initialize(nums.subspan(init1_->DimSize())); } std::string DebugString() const override { return absl::StrFormat("%s|%s", init1_->DebugString(), init2_->DebugString()); } private: std::unique_ptr init1_; std::unique_ptr init2_; }; } // namespace std::unique_ptr CombineInitializers( std::unique_ptr init1, std::unique_ptr init2) { return std::make_unique(std::move(init1), std::move(init2)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/initializer_combination.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INITIALIZER_COMBINATION #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INITIALIZER_COMBINATION #include #include "monolith/native_training/runtime/hash_table/initializer/initializer_interface.h" namespace monolith { namespace hash_table { // An entry may be initialized by different initializers, so we need to combine // two initializers. The combined initializer's DimSize() is the sum of both // inputs' DimSize(); init1 initializes the span from its start and init2 // initializes the sub-span beginning at offset init1->DimSize(). std::unique_ptr CombineInitializers( std::unique_ptr init1, std::unique_ptr init2); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INITIALIZER_COMBINATION ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/initializer_combination_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "monolith/native_training/runtime/hash_table/initializer/initializer_combination.h" #include #include "absl/types/span.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/initializer/constants_initializer.h" namespace monolith { namespace hash_table { namespace { using ::testing::ElementsAre; TEST(RandomUniformInitializer, Basic) { std::vector num(3, 1); auto init1 = NewConstantsInitializer(1, 3); auto init2 = NewConstantsInitializer(2, 4); auto combined_init = CombineInitializers(std::move(init1), std::move(init2)); EXPECT_THAT(combined_init->DimSize(), 3); combined_init->Initialize(absl::Span(num)); EXPECT_THAT(num, ElementsAre(3, 4, 4)); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/initializer_config.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
syntax="proto2"; package monolith.hash_table; message ZerosInitializerConfig { optional int32 dim_size = 1; } message OnesInitializerConfig { optional int32 dim_size = 1; } message ConstantsInitializerConfig { optional int32 dim_size = 1; optional float constant = 2; } message RandomUniformInitializerConfig { optional int32 dim_size = 1; optional float minval = 2 [default=-0.05]; optional float maxval = 3 [default=0.05]; } message InitializerConfig { oneof type { ZerosInitializerConfig zeros = 1; RandomUniformInitializerConfig random_uniform = 2; OnesInitializerConfig ones = 3; ConstantsInitializerConfig constants = 15; } } ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/initializer_factory.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/initializer/initializer_factory.h" #include #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_table/initializer/constants_initializer.h" #include "monolith/native_training/runtime/hash_table/initializer/random_uniform_initializer.h" namespace monolith { namespace hash_table { std::unique_ptr NewInitializerFromConfig( InitializerConfig config) { switch (config.type_case()) { case InitializerConfig::kZeros: return NewZerosInitializer(std::move(*config.mutable_zeros())); case InitializerConfig::kRandomUniform: return NewRandomUniformInitializer( std::move(*config.mutable_random_uniform())); case InitializerConfig::kOnes: return NewOnesInitializer(std::move(*config.mutable_ones())); case InitializerConfig::kConstants: return NewConstantsInitializer(std::move(*config.mutable_constants())); default: throw std::invalid_argument(absl::StrFormat("Unsupported initializer: %s", config.ShortDebugString())); } } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/initializer_factory.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INITIALIZER_FACTORY #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INITIALIZER_FACTORY #include #include "monolith/native_training/runtime/hash_table/initializer/initializer_config.pb.h" #include "monolith/native_training/runtime/hash_table/initializer/initializer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewInitializerFromConfig( InitializerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INITIALIZER_FACTORY ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/initializer_interface.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INTERFACE #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INTERFACE #include "absl/types/span.h" namespace monolith { namespace hash_table { class InitializerInterface { public: virtual ~InitializerInterface() = default; virtual int DimSize() const = 0; virtual void Initialize(absl::Span nums) const = 0; virtual std::string DebugString() const = 0; }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_INTERFACE ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/random_uniform_initializer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/initializer/random_uniform_initializer.h" #include #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { namespace { class RandomUniformInitializer : public InitializerInterface { public: explicit RandomUniformInitializer(RandomUniformInitializerConfig conf) : conf_(std::move(conf)) {} int DimSize() const override { return conf_.dim_size(); } void Initialize(absl::Span nums) const override { thread_local std::mt19937 generator; std::uniform_real_distribution distribution(conf_.minval(), conf_.maxval()); for (int i = 0; i < conf_.dim_size(); ++i) { nums[i] = distribution(generator); } } std::string DebugString() const override { return absl::StrFormat("RandomUniform(D=%d, min=%f, max=%f)", DimSize(), conf_.minval(), conf_.maxval()); } private: RandomUniformInitializerConfig conf_; }; } // namespace std::unique_ptr NewRandomUniformInitializer( RandomUniformInitializerConfig config) { return std::make_unique(std::move(config)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/random_uniform_initializer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_RANDOM_UNIFORM_INITIALIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_RANDOM_UNIFORM_INITIALIZER #include #include "monolith/native_training/runtime/hash_table/initializer/initializer_config.pb.h" #include "monolith/native_training/runtime/hash_table/initializer/initializer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewRandomUniformInitializer( RandomUniformInitializerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_INITIALIZER_RANDOM_UNIFORM_INITIALIZER ================================================ FILE: monolith/native_training/runtime/hash_table/initializer/random_uniform_initializer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/initializer/random_uniform_initializer.h"

#include <algorithm>
#include <vector>

#include "absl/types/span.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/initializer/initializer_config.pb.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::Gt;
using ::testing::Lt;

// Draws 1000 samples from [-1, 1) and checks that both tails of the range
// are actually reached.
TEST(RandomUniformInitializer, Basic) {
  const int kDimSize = 1000;
  std::vector<float> num(kDimSize, 0);
  RandomUniformInitializerConfig config;
  config.set_dim_size(kDimSize);
  config.set_minval(-1);
  config.set_maxval(1);
  auto initializer = NewRandomUniformInitializer(config);
  initializer->Initialize(absl::Span<float>(num));
  EXPECT_THAT(*std::max_element(num.begin(), num.end()), Gt(0.9));
  // The lower tail must be checked against -0.9; comparing with +0.9 would
  // pass for almost any sample and verify nothing.
  EXPECT_THAT(*std::min_element(num.begin(), num.end()), Lt(-0.9));
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/BUILD
================================================
load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test")
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")

package(default_visibility = ["//monolith/native_training/runtime:__subpackages__"])

proto_library(
    name = "optimizer_proto",
    srcs = ["optimizer.proto"],
)

cc_proto_library(
    name = "optimizer_cc_proto",
    deps = [":optimizer_proto"],
)

py_proto_library(
    name = "optimizer_py_proto",
    srcs = ["optimizer.proto"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
)

cc_library(
    name = "optimizer_interface",
    hdrs = ["optimizer_interface.h"],
    deps = [
        ":optimizer_cc_proto",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "optimizer_decorator",
    hdrs = ["optimizer_decorator.h"],
    deps = [
        ":optimizer_cc_proto",
        ":optimizer_interface",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/types:span",
    ],
)
cc_library( name = "stochastic_rounding", srcs = ["stochastic_rounding.cc"], hdrs = ["stochastic_rounding.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "//third_party/half_sourceforge_net:half", ], ) cc_test( name = "stochastic_rounding_test", srcs = ["stochastic_rounding_test.cc"], deps = [ ":optimizer_cc_proto", ":optimizer_factory", ":stochastic_rounding", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "optimizer_factory", srcs = ["optimizer_factory.cc"], hdrs = ["optimizer_factory.h"], deps = [ ":adadelta_optimizer", ":adagrad_optimizer", ":adam_optimizer", ":amsgrad_optimizer", ":batch_softmax_optimizer", ":dynamic_wd_adagrad_optimizer", ":ftrl_optimizer", ":group_ftrl_optimizer", ":group_adagrad_optimizer", ":momentum_optimizer", ":moving_average_optimizer", ":rmsprop_optimizer", ":sgd_optimizer", ":stochastic_rounding", "@com_google_absl//absl/strings:str_format", ], ) cc_library( name = "adagrad_optimizer_internal_deps", ) cc_library( name = "adagrad_optimizer", srcs = ["adagrad_optimizer.cc"], hdrs = ["adagrad_optimizer.h"], copts = [ "-D_ENABLE_AVX", ], deps = [ ":adagrad_optimizer_internal_deps", ":avx_utils", ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_library( name = "batch_softmax_optimizer", srcs = ["batch_softmax_optimizer.cc"], hdrs = ["batch_softmax_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", "@com_google_glog//:glog", ], ) cc_library( name = "dynamic_wd_adagrad_optimizer_internal_deps", ) cc_library( name = "dynamic_wd_adagrad_optimizer", srcs = ["dynamic_wd_adagrad_optimizer.cc"], hdrs = ["dynamic_wd_adagrad_optimizer.h"], copts = [ "-D_ENABLE_AVX", ], deps = [ ":dynamic_wd_adagrad_optimizer_internal_deps", ":dynamic_wd_avx_utils", ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_library( name = "ftrl_optimizer", srcs = 
["ftrl_optimizer.cc"], hdrs = ["ftrl_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "ftrl_optimizer_test", srcs = ["ftrl_optimizer_test.cc"], deps = [ ":ftrl_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_test( name = "adagrad_optimizer_test", srcs = ["adagrad_optimizer_test.cc"], deps = [ ":adagrad_optimizer", ":avx_utils", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_test( name = "batch_softmax_optimizer_test", srcs = ["batch_softmax_optimizer_test.cc"], deps = [ ":batch_softmax_optimizer", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_test( name = "dynamic_wd_adagrad_optimizer_test", srcs = ["dynamic_wd_adagrad_optimizer_test.cc"], deps = [ ":dynamic_wd_adagrad_optimizer", ":dynamic_wd_avx_utils", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "sgd_optimizer", srcs = ["sgd_optimizer.cc"], hdrs = ["sgd_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "sgd_optimizer_test", srcs = ["sgd_optimizer_test.cc"], deps = [ ":optimizer_cc_proto", ":sgd_optimizer", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "adadelta_optimizer", srcs = ["adadelta_optimizer.cc"], hdrs = ["adadelta_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "adadelta_optimizer_test", srcs = ["adadelta_optimizer_test.cc"], deps = [ ":adadelta_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "adam_optimizer", srcs = ["adam_optimizer.cc"], hdrs = ["adam_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = 
"adam_optimizer_test", srcs = ["adam_optimizer_test.cc"], deps = [ ":adam_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "amsgrad_optimizer", srcs = ["amsgrad_optimizer.cc"], hdrs = ["amsgrad_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "amsgrad_optimizer_test", srcs = ["amsgrad_optimizer_test.cc"], deps = [ ":amsgrad_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "momentum_optimizer", srcs = ["momentum_optimizer.cc"], hdrs = ["momentum_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "momentum_optimizer_test", srcs = ["momentum_optimizer_test.cc"], deps = [ ":momentum_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "moving_average_optimizer", srcs = ["moving_average_optimizer.cc"], hdrs = ["moving_average_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "moving_average_optimizer_test", srcs = ["moving_average_optimizer_test.cc"], deps = [ ":moving_average_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "rmsprop_optimizer", srcs = ["rmsprop_optimizer.cc"], hdrs = ["rmsprop_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "rmsprop_optimizer_test", srcs = ["rmsprop_optimizer_test.cc"], deps = [ ":optimizer_cc_proto", ":rmsprop_optimizer", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "dc_optimizer", srcs = ["dc_optimizer.cc"], hdrs = ["dc_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_decorator", 
":optimizer_interface", ], ) cc_test( name = "dc_optimizer_test", srcs = ["dc_optimizer_test.cc"], deps = [ ":adadelta_optimizer", ":dc_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "optimizer_combination", srcs = ["optimizer_combination.cc"], hdrs = ["optimizer_combination.h"], deps = [ ":optimizer_interface", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "optimizer_combination_test", srcs = ["optimizer_combination_test.cc"], deps = [ ":adagrad_optimizer", ":optimizer_cc_proto", ":optimizer_combination", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "test_utils", testonly = 1, hdrs = ["test_utils.h"], deps = [ "@com_google_absl//absl/types:span", ], ) cc_library( name = "dynamic_wd_avx_utils", hdrs = ["dynamic_wd_avx_utils.h"], copts = [ "-D_ENABLE_AVX", ], deps = [ "@com_google_absl//absl/types:span", ], ) cc_test( name = "dynamic_wd_avx_test", testonly = 1, srcs = ["dynamic_wd_avx_test.cc"], copts = [ "-D_ENABLE_AVX", ], deps = [ ":dynamic_wd_avx_utils", "@com_google_absl//absl/random", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "group_ftrl_optimizer", srcs = ["group_ftrl_optimizer.cc"], hdrs = ["group_ftrl_optimizer.h"], deps = [ ":avx_utils", ":optimizer_cc_proto", ":optimizer_interface", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "group_ftrl_optimizer_test", srcs = ["group_ftrl_optimizer_test.cc"], deps = [ ":avx_utils", ":group_ftrl_optimizer", ":optimizer_cc_proto", ":test_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "group_adagrad_optimizer", srcs = ["group_adagrad_optimizer.cc"], hdrs = ["group_adagrad_optimizer.h"], deps = [ ":optimizer_cc_proto", ":optimizer_interface", ":avx_utils", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "group_adagrad_optimizer_test", srcs = ["group_adagrad_optimizer_test.cc"], 
deps = [ ":group_adagrad_optimizer", ":optimizer_cc_proto", ":test_utils", ":avx_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "avx_utils", hdrs = ["avx_utils.h"], copts = [ "-D_ENABLE_AVX", ], deps = [ "@com_google_absl//absl/types:span", ], ) cc_test( name = "avx_test", testonly = 1, srcs = ["avx_test.cc"], copts = [ "-D_ENABLE_AVX", ], deps = [ ":avx_utils", "@com_google_absl//absl/random", "@com_google_googletest//:gtest_main", ], ) cc_binary( name = "avx_benchmark", testonly = 1, srcs = ["avx_benchmark.cc"], copts = [ "-D_ENABLE_AVX", ], deps = [ ":avx_utils", "//monolith/native_training/runtime/allocator:block_allocator", "//monolith/native_training/runtime/common:cpu_info", "@com_github_google_benchmark//:benchmark", "@com_google_absl//absl/random", "@com_google_glog//:glog", "@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/adadelta_optimizer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/adadelta_optimizer.h"

#include <cmath>
#include <memory>

#include "absl/strings/str_format.h"

namespace monolith {
namespace hash_table {
namespace {

// Adadelta optimizer. The per-entry context is a flat float array laid out as
//   [ accum(0..D-1) | accum_update(0..D-1) ],  D = conf_.dim_size().
class AdadeltaOptimizer : public OptimizerInterface {
 public:
  explicit AdadeltaOptimizer(AdadeltaOptimizerConfig config)
      : conf_(std::move(config)) {}

  // Two float slots (accum, accum_update) per dimension.
  int64_t SizeBytes() const override {
    return 2 * conf_.dim_size() * sizeof(float);
  }

  int64_t UncompressedSizeBytes() const override { return SizeBytes(); }

  std::string DebugString() const override {
    return absl::StrFormat("Adadelta(D=%d)", DimSize());
  }

  int DimSize() const override { return conf_.dim_size(); }

  int SliceSize() const override { return 1; }

  // Zeroes both accumulators.
  void Init(void* ctx) const override {
    float* accum = static_cast<float*>(ctx);
    float* accum_update = accum + conf_.dim_size();
    for (int i = 0; i < conf_.dim_size(); ++i) {
      accum[i] = accum_update[i] = 0;
    }
  }

  // Applies one Adadelta step to num[0..D) in place and updates both
  // accumulators in |ctx|. Only learning_rates[0] is used; |global_step| is
  // ignored.
  void Optimize(void* ctx, absl::Span<float> num, absl::Span<const float> grad,
                absl::Span<const float> learning_rates,
                const int64_t global_step) const override {
    float* accum = static_cast<float*>(ctx);
    float* accum_update = accum + conf_.dim_size();
    float effective_lr = learning_rates[0];
    for (int i = 0; i < conf_.dim_size(); ++i) {
      // Weight decay is folded into the gradient.
      float cur_grad = grad[i] + conf_.weight_decay_factor() * num[i];
      // Running average of squared gradients.
      float new_accum = accum[i] * conf_.averaging_ratio() +
                        cur_grad * cur_grad * (1 - conf_.averaging_ratio());
      // The step uses the *previous* accum_update and the *new* accum.
      float update = std::sqrt(accum_update[i] + conf_.epsilon()) /
                     std::sqrt(new_accum + conf_.epsilon()) * cur_grad;
      float new_w = num[i] - update * effective_lr;
      // Running average of squared updates.
      float new_accum_update =
          accum_update[i] * conf_.averaging_ratio() +
          update * update * (1 - conf_.averaging_ratio());
      num[i] = new_w;
      accum[i] = new_accum;
      accum_update[i] = new_accum_update;
    }
  }

  // Serializes both accumulators into a single AdadeltaOptimizerDump.
  OptimizerDump Save(const void* ctx) const override {
    OptimizerDump dump;
    AdadeltaOptimizerDump* adadelta_dump = dump.add_dump()->mutable_adadelta();
    const float* accum = static_cast<const float*>(ctx);
    const float* accum_update = accum + conf_.dim_size();
    for (int i = 0; i < conf_.dim_size(); ++i) {
      adadelta_dump->add_accum(accum[i]);
      adadelta_dump->add_accum_update(accum_update[i]);
    }
    return dump;
  }

  // Restores both accumulators from dump.dump(0). The dump is read by index
  // without size checks, so it must hold at least dim_size() values per field.
  void Restore(void* ctx, OptimizerDump dump) const override {
    const AdadeltaOptimizerDump& adadelta_dump = dump.dump(0).adadelta();
    float* accum = static_cast<float*>(ctx);
    float* accum_update = accum + conf_.dim_size();
    for (int i = 0; i < conf_.dim_size(); ++i) {
      accum[i] = adadelta_dump.accum(i);
      accum_update[i] = adadelta_dump.accum_update(i);
    }
  }

 private:
  AdadeltaOptimizerConfig conf_;
};

}  // namespace

// Builds an AdadeltaOptimizer that owns a copy of |config|.
std::unique_ptr<OptimizerInterface> NewAdadeltaOptimizer(
    AdadeltaOptimizerConfig config) {
  return std::make_unique<AdadeltaOptimizer>(std::move(config));
}

}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/adadelta_optimizer.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADADELTA_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADADELTA_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewAdadeltaOptimizer( AdadeltaOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADADELTA_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/adadelta_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/adadelta_optimizer.h"

#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h"
#include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::Pointwise;
using ::testing::FloatNear;
using ::testing::ElementsAreArray;

// Single-dimension update, followed by a Save/Restore round trip.
TEST(AdadeltaOptimizer, Basic) {
  AdadeltaOptimizerConfig config;
  config.set_dim_size(1);
  auto opt = NewAdadeltaOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f});
  auto expected = {-0.0031607f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  EXPECT_NEAR(dump.dump(0).adadelta().accum(0), 10, 1e-4);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f});
  // Step the restored entry as well; otherwise Restore() is never exercised.
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {10.0f}, {0.01f});
  auto expected2 = {-0.0064035f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2));
  EXPECT_THAT(mem2.num(), ElementsAreArray(mem.num()));
}

// Two-dimension update, followed by a Save/Restore round trip.
TEST(AdadeltaOptimizer, ListUpdate) {
  AdadeltaOptimizerConfig config;
  config.set_dim_size(2);
  auto opt = NewAdadeltaOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected = {-0.0031607f, -0.0030151f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  EXPECT_NEAR(dump.dump(0).adadelta().accum(0), 10, 1e-4);
  EXPECT_NEAR(dump.dump(0).adadelta().accum(1), .1, 1e-4);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  // Step the restored entry as well; otherwise Restore() is never exercised.
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected2 = {-0.0064035f, -0.0061047f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2));
  EXPECT_THAT(mem2.num(), ElementsAreArray(mem.num()));
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/adagrad_optimizer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <utility>

#include "absl/strings/str_format.h"
#include "monolith/native_training/runtime/hash_table/optimizer/adagrad_optimizer.h"
#include "monolith/native_training/runtime/hash_table/optimizer/avx_utils.h"

namespace monolith {
namespace hash_table {
namespace {

// Adagrad optimizer. The per-entry context is a flat array of
// conf_.dim_size() floats (named "norm" below), one accumulator per dimension.
class AdagradOptimizer : public OptimizerInterface {
 public:
  explicit AdagradOptimizer(AdagradOptimizerConfig config)
      : conf_(std::move(config)) {}

  // One float accumulator per dimension.
  int64_t SizeBytes() const override {
    return conf_.dim_size() * sizeof(float);
  }

  int64_t UncompressedSizeBytes() const override { return SizeBytes(); }

  std::string DebugString() const override {
    return absl::StrFormat("Adagrad(D=%d)", DimSize());
  }

  int DimSize() const override { return conf_.dim_size(); }

  int SliceSize() const override { return 1; }

  // Fills every accumulator with conf_.initial_accumulator_value().
  void Init(void* ctx) const override {
    float* norm = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      norm[i] = conf_.initial_accumulator_value();
    }
  }

  // Delegates the update of num[0..D) and the accumulators to
  // AdagradOptimize() from avx_utils.h. Only learning_rates[0] is used;
  // |global_step| is ignored.
  void Optimize(void* ctx, absl::Span<float> num, absl::Span<const float> grad,
                absl::Span<const float> learning_rates,
                const int64_t global_step) const override {
    float* norm = static_cast<float*>(ctx);
    AdagradOptimize(num.data(), norm, grad.data(), conf_.dim_size(),
                    learning_rates[0], conf_.weight_decay_factor());
  }

  // Serializes the accumulators into a single AdagradOptimizerDump.
  OptimizerDump Save(const void* ctx) const override {
    OptimizerDump dump;
    AdagradOptimizerDump* adagrad_dump = dump.add_dump()->mutable_adagrad();
    const float* norm = static_cast<const float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      adagrad_dump->add_norm(norm[i]);
    }
    return dump;
  }

  // Restores the accumulators from dump.dump(0). The dump is read by index
  // without size checks, so it must hold at least dim_size() values.
  void Restore(void* ctx, OptimizerDump dump) const override {
    const AdagradOptimizerDump& adagrad_dump = dump.dump(0).adagrad();
    float* norm = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      norm[i] = adagrad_dump.norm(i);
    }
  }

 private:
  AdagradOptimizerConfig conf_;
};

}  // namespace

// Builds an AdagradOptimizer that owns a copy of |config|.
std::unique_ptr<OptimizerInterface> NewAdagradOptimizer(
    AdagradOptimizerConfig config) {
  return std::make_unique<AdagradOptimizer>(std::move(config));
}

}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/adagrad_optimizer.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADAGRAD_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADAGRAD_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewAdagradOptimizer( AdagradOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADAGRAD_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/adagrad_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/adagrad_optimizer.h"

#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h"
#include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::Pointwise;
using ::testing::FloatNear;
using ::testing::ElementsAreArray;

// One Adagrad step without weight decay, then a Save/Restore round trip in
// which the original and the restored entry must evolve identically.
TEST(AdagradOptimizer, Basic) {
  AdagradOptimizerConfig config;
  config.set_dim_size(2);
  config.set_initial_accumulator_value(1.0f);
  auto opt = NewAdagradOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f, 2.0f},
                {0.1f});
  auto expected = {-0.07071067, -0.08944272};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f, 1.0f},
                {0.1f});
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {1.0f, 1.0f},
                {0.1f});
  EXPECT_THAT(mem.num(), ElementsAreArray(mem2.num()));
}

// Two Adagrad steps with weight_decay_factor = 0.1, then the same
// Save/Restore round trip as above.
TEST(AdagradOptimizer, OptimizeWithWeightDecay) {
  AdagradOptimizerConfig config;
  config.set_dim_size(2);
  config.set_initial_accumulator_value(1.0f);
  config.set_weight_decay_factor(0.1f);
  auto opt = NewAdagradOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f, 2.0f},
                {0.1f});
  auto expected = {-0.07071067, -0.08944272};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f, 2.0f},
                {0.1f});
  auto expected2 = {-0.128173, -0.155943};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f, 1.0f},
                {0.1f});
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {1.0f, 1.0f},
                {0.1f});
  EXPECT_THAT(mem.num(), ElementsAreArray(mem2.num()));
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/adam_optimizer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_table/optimizer/adam_optimizer.h" namespace monolith { namespace hash_table { namespace { class AdamOptimizer : public OptimizerInterface { public: explicit AdamOptimizer(AdamOptimizerConfig config) : conf_(std::move(config)) {} int64_t SizeBytes() const override { return (2 * conf_.dim_size() + 2) * sizeof(float); } int64_t UncompressedSizeBytes() const override { return SizeBytes(); } std::string DebugString() const override { return absl::StrFormat("Adam(D=%d)", DimSize()); } int DimSize() const override { return conf_.dim_size(); } int SliceSize() const override { return 1; } void Init(void* ctx) const override { float* m = static_cast(ctx); float* v = m + conf_.dim_size(); float& beta1_power = v[conf_.dim_size()]; float& beta2_power = v[conf_.dim_size() + 1]; for (int i = 0; i < conf_.dim_size(); ++i) { m[i] = v[i] = 0; } beta1_power = conf_.beta1(); beta2_power = conf_.beta2(); } void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { float* m = static_cast(ctx); float* v = m + conf_.dim_size(); float& beta1_power = v[conf_.dim_size()]; float& beta2_power = v[conf_.dim_size() + 1]; float lr = learning_rates[0] * sqrt(1 - beta2_power) / (1 - beta1_power); for (int i = 0; i < conf_.dim_size(); ++i) { float cur_grad = grad[i] + conf_.weight_decay_factor() * num[i]; float new_m = m[i] + (cur_grad - m[i]) * (1 - conf_.beta1()); float new_v = v[i] + (cur_grad * cur_grad - v[i]) * (1 - conf_.beta2()); float new_w = num[i]; if (conf_.use_nesterov()) { new_w -= ((cur_grad * (1 - conf_.beta1()) + conf_.beta1() * new_m) * lr) / (sqrt(new_v) + conf_.epsilon()); } else { new_w -= (new_m * lr) / (sqrt(new_v) + conf_.epsilon()); } num[i] = new_w; m[i] = new_m; v[i] = new_v; } beta1_power *= conf_.beta1(); beta2_power *= conf_.beta2(); } OptimizerDump Save(const void* ctx) const override { 
OptimizerDump dump; AdamOptimizerDump* adam_dump = dump.add_dump()->mutable_adam(); const float* m = static_cast(ctx); const float* v = m + conf_.dim_size(); const float& beta1_power = v[conf_.dim_size()]; const float& beta2_power = v[conf_.dim_size() + 1]; for (int i = 0; i < conf_.dim_size(); ++i) { adam_dump->add_m(m[i]); adam_dump->add_v(v[i]); } adam_dump->set_beta1_power(beta1_power); adam_dump->set_beta2_power(beta2_power); return dump; } void Restore(void* ctx, OptimizerDump dump) const override { const AdamOptimizerDump& adam_dump = dump.dump(0).adam(); float* m = static_cast(ctx); float* v = m + conf_.dim_size(); float& beta1_power = v[conf_.dim_size()]; float& beta2_power = v[conf_.dim_size()]; for (int i = 0; i < conf_.dim_size(); ++i) { m[i] = adam_dump.m(i); v[i] = adam_dump.v(i); } beta1_power = adam_dump.beta1_power(); beta2_power = adam_dump.beta2_power(); } private: AdamOptimizerConfig conf_; }; } // namespace std::unique_ptr NewAdamOptimizer( AdamOptimizerConfig config) { return std::make_unique(std::move(config)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/adam_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADAM_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADAM_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewAdamOptimizer( AdamOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_ADAM_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/adam_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/adam_optimizer.h"

#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h"
#include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::Pointwise;
using ::testing::FloatNear;
using ::testing::ElementsAreArray;

// One-dimensional entry: checks a single step, then that an entry rebuilt via
// Save()/Restore() takes the same second step as the original entry.
TEST(AdamOptimizer, Basic) {
  AdamOptimizerConfig config;
  config.set_dim_size(1);
  auto opt = NewAdamOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f});
  auto expected = {-0.00990099f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  EXPECT_NEAR(dump.dump(0).adam().m(0), 1, 1e-4);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f});
  // Step the restored entry as well; otherwise Restore() is never exercised.
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {10.0f},
                {0.01f});
  auto expected2 = {-0.01983060f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2));
  EXPECT_THAT(mem2.num(), Pointwise(FloatNear(1e-6), expected2));
}

// Two-dimensional variant of the test above.
TEST(AdamOptimizer, ListUpdate) {
  AdamOptimizerConfig config;
  config.set_dim_size(2);
  auto opt = NewAdamOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected = {-0.00990099f, -0.00909091f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  EXPECT_NEAR(dump.dump(0).adam().m(0), 1, 1e-4);
  EXPECT_NEAR(dump.dump(0).adam().m(1), .1, 1e-4);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  // Step the restored entry as well; otherwise Restore() is never exercised.
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected2 = {-0.01983060f, -0.01842895f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2));
  EXPECT_THAT(mem2.num(), Pointwise(FloatNear(1e-6), expected2));
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/amsgrad_optimizer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/strings/str_format.h"
#include "monolith/native_training/runtime/hash_table/optimizer/amsgrad_optimizer.h"

namespace monolith {
namespace hash_table {
namespace {

// AMSGrad optimizer whose per-entry state lives in a caller-provided buffer.
//
// With D = conf_.dim_size(), the buffer passed as `ctx` holds 3 * D + 2
// floats (see SizeBytes()):
//   m[0..D)      moving average of the gradient
//   v[0..D)      moving average of the squared gradient
//   vhat[0..D)   element-wise running maximum of v
//   beta1_power  running product of beta1, seeded with beta1 in Init()
//   beta2_power  running product of beta2, seeded with beta2 in Init()
class AmsgradOptimizer : public OptimizerInterface {
 public:
  explicit AmsgradOptimizer(AmsgradOptimizerConfig config)
      : conf_(std::move(config)) {}

  // Bytes of optimizer state per entry: m, v, vhat and the two beta powers.
  int64_t SizeBytes() const override {
    return (3 * conf_.dim_size() + 2) * sizeof(float);
  }

  // The state is stored uncompressed, so both sizes are identical.
  int64_t UncompressedSizeBytes() const override { return SizeBytes(); }

  std::string DebugString() const override {
    return absl::StrFormat("Amsgrad(D=%d)", DimSize());
  }

  int DimSize() const override { return conf_.dim_size(); }

  int SliceSize() const override { return 1; }

  // Zeroes the three vectors and seeds the beta powers with beta1 / beta2.
  void Init(void* ctx) const override {
    float* m = static_cast<float*>(ctx);
    float* v = m + conf_.dim_size();
    float* vhat = v + conf_.dim_size();
    float& beta1_power = vhat[conf_.dim_size()];
    float& beta2_power = vhat[conf_.dim_size() + 1];
    for (int i = 0; i < conf_.dim_size(); ++i) {
      m[i] = v[i] = vhat[i] = 0;
    }
    beta1_power = conf_.beta1();
    beta2_power = conf_.beta2();
  }

  // Applies one AMSGrad step to `num` in place and advances the state in
  // `ctx`. Only learning_rates[0] is read; `global_step` is not used.
  void Optimize(void* ctx, absl::Span<float> num, absl::Span<const float> grad,
                absl::Span<const float> learning_rates,
                const int64_t global_step) const override {
    float* m = static_cast<float*>(ctx);
    float* v = m + conf_.dim_size();
    float* vhat = v + conf_.dim_size();
    float& beta1_power = vhat[conf_.dim_size()];
    float& beta2_power = vhat[conf_.dim_size() + 1];
    // Bias-corrected step size derived from the current beta powers.
    float lr = learning_rates[0] * sqrt(1 - beta2_power) / (1 - beta1_power);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      // Weight decay is folded into the gradient before the moment updates.
      float cur_grad = grad[i] + conf_.weight_decay_factor() * num[i];
      float new_m = m[i] + (cur_grad - m[i]) * (1 - conf_.beta1());
      float new_v = v[i] + (cur_grad * cur_grad - v[i]) * (1 - conf_.beta2());
      // vhat never decreases; it is what normalises the step below.
      float new_vhat = std::max(vhat[i], new_v);
      float new_w = num[i];
      if (conf_.use_nesterov()) {
        new_w -=
            ((cur_grad * (1 - conf_.beta1()) + conf_.beta1() * new_m) * lr) /
            (sqrt(new_vhat) + conf_.epsilon());
      } else {
        new_w -= (new_m * lr) / (sqrt(new_vhat) + conf_.epsilon());
      }
      num[i] = new_w;
      m[i] = new_m;
      v[i] = new_v;
      vhat[i] = new_vhat;
    }
    beta1_power *= conf_.beta1();
    beta2_power *= conf_.beta2();
  }

  // Serialises m, v, vhat and both beta powers into one AmsgradOptimizerDump.
  OptimizerDump Save(const void* ctx) const override {
    OptimizerDump dump;
    AmsgradOptimizerDump* amsgrad_dump = dump.add_dump()->mutable_amsgrad();
    const float* m = static_cast<const float*>(ctx);
    const float* v = m + conf_.dim_size();
    const float* vhat = v + conf_.dim_size();
    const float& beta1_power = vhat[conf_.dim_size()];
    const float& beta2_power = vhat[conf_.dim_size() + 1];
    for (int i = 0; i < conf_.dim_size(); ++i) {
      amsgrad_dump->add_m(m[i]);
      amsgrad_dump->add_v(v[i]);
      amsgrad_dump->add_vhat(vhat[i]);
    }
    amsgrad_dump->set_beta1_power(beta1_power);
    amsgrad_dump->set_beta2_power(beta2_power);
    return dump;
  }

  // Rebuilds the state in `ctx` from the first entry of `dump`.
  void Restore(void* ctx, OptimizerDump dump) const override {
    const AmsgradOptimizerDump& amsgrad_dump = dump.dump(0).amsgrad();
    float* m = static_cast<float*>(ctx);
    float* v = m + conf_.dim_size();
    float* vhat = v + conf_.dim_size();
    float& beta1_power = vhat[conf_.dim_size()];
    float& beta2_power = vhat[conf_.dim_size() + 1];
    for (int i = 0; i < conf_.dim_size(); ++i) {
      m[i] = amsgrad_dump.m(i);
      v[i] = amsgrad_dump.v(i);
      vhat[i] = amsgrad_dump.vhat(i);
    }
    beta1_power = amsgrad_dump.beta1_power();
    beta2_power = amsgrad_dump.beta2_power();
  }

 private:
  AmsgradOptimizerConfig conf_;
};

}  // namespace

std::unique_ptr<OptimizerInterface> NewAmsgradOptimizer(
    AmsgradOptimizerConfig config) {
  return std::make_unique<AmsgradOptimizer>(std::move(config));
}

}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/amsgrad_optimizer.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_AMSGRAD_OPTIMIZER
#define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_AMSGRAD_OPTIMIZER

#include <memory>

#include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h"

namespace monolith {
namespace hash_table {

// Builds an AMSGrad optimizer from `config`. The caller owns the returned
// optimizer.
std::unique_ptr<OptimizerInterface> NewAmsgradOptimizer(
    AmsgradOptimizerConfig config);

}  // namespace hash_table
}  // namespace monolith

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_AMSGRAD_OPTIMIZER

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/amsgrad_optimizer_test.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/amsgrad_optimizer.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; using ::testing::ElementsAreArray; TEST(AmsgradOptimizer, Basic) { AmsgradOptimizerConfig config; config.set_dim_size(1); auto opt = NewAmsgradOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected = {-0.00990099f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).amsgrad().m(0), 1, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected2 = {-0.01983060f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } TEST(AmsgradOptimizer, ListUpdate) { AmsgradOptimizerConfig config; config.set_dim_size(2); auto opt = NewAmsgradOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), 
mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto expected = {-0.00990099f, -0.00909091f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).amsgrad().m(0), 1, 1e-4); EXPECT_NEAR(dump.dump(0).amsgrad().m(1), .1, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto expected2 = {-0.01983060f, -0.01842895f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/avx_benchmark.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Created by david on 2020-11-27. 
// #include "absl/random/random.h" #include "benchmark/benchmark.h" #include "monolith/native_training/runtime/common/cpu_info.h" #include "monolith/native_training/runtime/hash_table/optimizer/avx_utils.h" namespace monolith { namespace hash_table { namespace { void BM_AdagradOptimize(benchmark::State& state) { // NOLINT size_t dim = state.range(0); float lr = 0.01f; std::vector norm(dim, 0), grad(dim, 0); absl::BitGen bit_gen; for (size_t i = 0; i < dim; ++i) { norm[i] = absl::Uniform(bit_gen, .1f, 1.f); grad[i] = absl::Uniform(bit_gen, -1.f, 1.f); } std::vector result(dim, 0); for (auto _ : state) { BaselineAdagradOptimize(result.data(), norm.data(), grad.data(), dim, lr, 0.01); } } void BM_AVXAdagradOptimize(benchmark::State& state) { // NOLINT RunCPUGuard(); size_t dim = state.range(0); float lr = 0.01f; std::vector norm(dim, 0), grad(dim, 0); absl::BitGen bit_gen; for (size_t i = 0; i < dim; ++i) { norm[i] = absl::Uniform(bit_gen, .1f, 1.f); grad[i] = absl::Uniform(bit_gen, -1.f, 1.f); } std::vector result(dim, 0); for (auto _ : state) { Avx256AdagradOptimize(result.data(), norm.data(), grad.data(), dim, lr, 0.01); } } BENCHMARK(BM_AdagradOptimize)->Arg(16)->Arg(64)->Arg(256); BENCHMARK(BM_AVXAdagradOptimize)->Arg(16)->Arg(64)->Arg(256); } // namespace } // namespace hash_table } // namespace monolith BENCHMARK_MAIN(); ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/avx_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Created by david on 2020-11-27. // #include #include "absl/random/random.h" #include "gtest/gtest.h" #include "gmock/gmock.h" #include "monolith/native_training/runtime/hash_table/optimizer/avx_utils.h" namespace monolith { namespace hash_table { namespace { void TestAdagradOptimize(size_t dim = 32) { float lr = 0.01f; std::vector norm(dim, 0), grad(dim, 0); absl::BitGen bit_gen; for (size_t i = 0; i < dim; ++i) { norm[i] = absl::Uniform(bit_gen, .1f, 1.f); grad[i] = absl::Uniform(bit_gen, -1.f, 1.f); } std::vector norm2(norm.begin(), norm.end()), grad2(grad.begin(), grad.end()); std::vector result(dim, 0), result_avx(dim, 0); BaselineAdagradOptimize(result.data(), norm.data(), grad.data(), dim, lr, 0.1f); #if defined(_ENABLE_AVX) && defined(__AVX__) Avx256AdagradOptimize(result_avx.data(), norm2.data(), grad2.data(), dim, lr, 0.1f); #else static_assert(false, "AVX is not available, please check and recompile!"); #endif for (size_t i = 0; i < dim; ++i) { EXPECT_NEAR(result[i], result_avx[i], 1e-6); } } TEST(AVX, Basic) { TestAdagradOptimize(1); TestAdagradOptimize(7); TestAdagradOptimize(8); TestAdagradOptimize(16); TestAdagradOptimize(32); TestAdagradOptimize(39); TestAdagradOptimize(224); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/avx_utils.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by david on 2020-11-27.
//

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_AVX_UTILS
#define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_AVX_UTILS

#include <cmath>
#include <cstddef>

#if defined(_ENABLE_AVX) && defined(__AVX__)
#include <immintrin.h>
#endif

namespace monolith {
namespace hash_table {

// Scalar Adagrad step over `len` elements, updating `num` (weights) and
// `norm` (accumulated squared gradient) in place:
//   g       = grad + w_decay * num
//   norm   += g * g
//   num    -= lr / sqrt(norm) * g
inline void BaselineAdagradOptimize(float* num, float* norm, const float* grad,
                                    size_t len, float lr, float w_decay) {
  for (size_t i = 0; i < len; ++i) {
    float g = grad[i] + w_decay * num[i];
    norm[i] += g * g;
    float effective_lr = lr / std::sqrt(norm[i]);
    num[i] -= effective_lr * g;
  }
}

// First half of the scalar group-FTRL step: updates `norm` and `zero` in
// place and returns the sum of squares of the updated `zero` entries (the
// caller takes the square root).
inline float BaselineGetGroupNorm(float* num, float* norm, const float* grad,
                                  float* zero, size_t len,
                                  float effective_lr) {
  float group_zt_norm = 0;
  for (size_t i = 0; i < len; ++i) {
    auto norm_new = norm[i] + grad[i] * grad[i];
    auto sigma = (std::sqrt(norm_new) - std::sqrt(norm[i])) / effective_lr;
    zero[i] += (grad[i] - sigma * num[i]);
    norm[i] = norm_new;
    group_zt_norm += zero[i] * zero[i];
  }
  return group_zt_norm;
}

// Second half of the scalar group-FTRL step: writes the new weights into
// `num` given the group norm computed by the first half.
//
// The whole group is zeroed when the norm does not exceed the L1 strength.
// The comparison is inclusive: at equality the scaling factor below is 0
// anyway, and with group_zt_norm == l1 == 0 it would otherwise be 0 / 0 and
// fill the weights with NaN.
inline void BaselineSetWeightsWithGroupNorm(float group_zt_norm, float* num,
                                            float* norm, float* zero,
                                            size_t len, float effective_lr,
                                            float l1_regularization_strength,
                                            float l2_regularization_strength,
                                            float beta) {
  if (group_zt_norm <= l1_regularization_strength) {
    for (size_t i = 0; i < len; ++i) {
      num[i] = 0;
    }
  } else {
    float normwise =
        (l1_regularization_strength - group_zt_norm) / group_zt_norm;
    for (size_t i = 0; i < len; ++i) {
      num[i] = effective_lr * zero[i] * normwise /
               (beta + std::sqrt(norm[i]) +
                l2_regularization_strength * effective_lr);
    }
  }
}

// Scalar group-FTRL step: both halves above, applied to one group of `len`
// elements.
inline void BaselineGroupFTRLOptimize(float* num, float* norm,
                                      const float* grad, float* zero,
                                      size_t len, float effective_lr,
                                      float l1_regularization_strength,
                                      float l2_regularization_strength,
                                      float beta) {
  float group_zt_norm =
      BaselineGetGroupNorm(num, norm, grad, zero, len, effective_lr);
  group_zt_norm = std::abs(std::sqrt(group_zt_norm));
  BaselineSetWeightsWithGroupNorm(group_zt_norm, num, norm, zero, len,
                                  effective_lr, l1_regularization_strength,
                                  l2_regularization_strength, beta);
}

// output[i] = a[i] + b[i] for i in [0, len).
inline void BaseReduceSum(const float* a, const float* b, float* output,
                          size_t len) {
  for (size_t i = 0; i < len; ++i) {
    output[i] = a[i] + b[i];
  }
}

#if defined(_ENABLE_AVX) && defined(__AVX__)
// AVX version of BaselineAdagradOptimize: processes 8 floats per iteration
// with unaligned loads/stores and hands the remaining (< 8) elements to the
// scalar implementation.
inline void Avx256AdagradOptimize(float* num, float* norm, const float* grad,
                                  size_t len, float lr, float w_decay) {
  const __m256 lamda = _mm256_set1_ps(w_decay);
  const __m256 _lr = _mm256_set1_ps(lr);

  // OPTIMIZE: Loads floating-point vector from an aligned memory address
  for (; len > 7; len -= 8, num += 8, norm += 8, grad += 8) {
    const __m256 _num = _mm256_loadu_ps(num);
    const __m256 _norm = _mm256_loadu_ps(norm);
    const __m256 _grad = _mm256_loadu_ps(grad);
    // g = grad + w_decay * num
    const __m256 updated_grad = _mm256_fmadd_ps(lamda, _num, _grad);
    // norm += g * g
    __m256 _norm_new = _mm256_fmadd_ps(updated_grad, updated_grad, _norm);
    _mm256_storeu_ps(norm, _norm_new);

    const __m256 _norm_new_sqrt = _mm256_sqrt_ps(_norm_new);
    const __m256 effective_lr = _mm256_div_ps(_lr, _norm_new_sqrt);
    // num -= effective_lr * g. The step must use the weight-decayed gradient,
    // as the scalar implementation does, not the raw gradient.
    __m256 _num_new = _mm256_fnmadd_ps(effective_lr, updated_grad, _num);
    _mm256_storeu_ps(num, _num_new);
  }

  if (len) {
    BaselineAdagradOptimize(num, norm, grad, len, lr, w_decay);
  }
}

// horizontal sum of mm256
inline float sum8(__m256 x) {
  // hiQuad = ( x7, x6, x5, x4 )
  const __m128 hiQuad = _mm256_extractf128_ps(x, 1);
  // loQuad = ( x3, x2, x1, x0 )
  const __m128 loQuad = _mm256_castps256_ps128(x);
  // sumQuad = ( x3 + x7, x2 + x6, x1 + x5, x0 + x4 )
  const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);
  // loDual = ( -, -, x1 + x5, x0 + x4 )
  const __m128 loDual = sumQuad;
  // hiDual = ( -, -, x3 + x7, x2 + x6 )
  const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);
  // sumDual = ( -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6 )
  const __m128 sumDual = _mm_add_ps(loDual, hiDual);
  // lo = ( -, -, -, x0 + x2 + x4 + x6 )
  const __m128 lo = sumDual;
  // hi = ( -, -, -, x1 + x3 + x5 + x7 )
  const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);
  // sum = ( -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 )
  const __m128 sum = _mm_add_ss(lo, hi);
  return _mm_cvtss_f32(sum);
}

// AVX version of BaselineGroupFTRLOptimize. Two passes over the group:
//   1. update `norm` / `zero` and accumulate the group norm,
//   2. write the new weights into `num`.
// Each pass handles 8 floats per iteration and delegates the tail (< 8
// elements) to the matching scalar helper.
inline void Avx256GroupFTRLOptimize(float* num, float* norm, const float* grad,
                                    float* zero, size_t len,
                                    float effective_lr,
                                    float l1_regularization_strength,
                                    float l2_regularization_strength,
                                    float beta) {
  const __m256 _lr = _mm256_set1_ps(effective_lr);
  const __m256 _zeros = _mm256_set1_ps(0.0);
  float group_zt_norm = 0.0;
  // Pass 2 restarts from the beginning of the group.
  float* numCopy = num;
  size_t lenCopy = len;
  float* normCopy = norm;
  float* zeroCopy = zero;
  for (; len > 7; len -= 8, num += 8, norm += 8, grad += 8, zero += 8) {
    const __m256 _num = _mm256_loadu_ps(num);
    const __m256 _norm = _mm256_loadu_ps(norm);
    const __m256 _grad = _mm256_loadu_ps(grad);
    const __m256 _zero = _mm256_loadu_ps(zero);
    // norm_new = norm + grad * grad
    const __m256 _new_norm = _mm256_fmadd_ps(_grad, _grad, _norm);
    const __m256 _norm_new_sqrt = _mm256_sqrt_ps(_new_norm);
    const __m256 _norm_sqrt = _mm256_sqrt_ps(_norm);
    // sigma = (sqrt(norm_new) - sqrt(norm)) / lr
    __m256 _sigma = _mm256_sub_ps(_norm_new_sqrt, _norm_sqrt);
    _sigma = _mm256_div_ps(_sigma, _lr);
    // zero += grad - sigma * num
    const __m256 _add_zero = _mm256_fnmadd_ps(_sigma, _num, _grad);
    const __m256 _new_zero = _mm256_add_ps(_zero, _add_zero);
    _mm256_storeu_ps(zero, _new_zero);
    _mm256_storeu_ps(norm, _new_norm);
    group_zt_norm += sum8(_mm256_fmadd_ps(_new_zero, _new_zero, _zeros));
  }
  if (len) {
    group_zt_norm +=
        BaselineGetGroupNorm(num, norm, grad, zero, len, effective_lr);
  }
  group_zt_norm = std::abs(std::sqrt(group_zt_norm));

  // Inclusive comparison, matching BaselineSetWeightsWithGroupNorm, so a zero
  // group norm with zero L1 strength zeroes the group rather than dividing
  // 0 by 0.
  if (group_zt_norm <= l1_regularization_strength) {
    for (; lenCopy > 7;
         lenCopy -= 8, numCopy += 8, normCopy += 8, zeroCopy += 8) {
      _mm256_storeu_ps(numCopy, _zeros);
    }
  } else {
    const __m256 _normwise = _mm256_set1_ps(
        (l1_regularization_strength - group_zt_norm) / group_zt_norm);
    const __m256 _l2_regularization =
        _mm256_set1_ps(l2_regularization_strength);
    const __m256 _beta = _mm256_set1_ps(beta);
    for (; lenCopy > 7;
         lenCopy -= 8, numCopy += 8, normCopy += 8, zeroCopy += 8) {
      const __m256 _norm = _mm256_loadu_ps(normCopy);
      const __m256 _zero = _mm256_loadu_ps(zeroCopy);
      const __m256 _sqrt_norm = _mm256_sqrt_ps(_norm);
      // denom = beta + sqrt(norm) + l2 * lr. The l2 term is added, as in the
      // scalar implementation.
      __m256 _denom = _mm256_fmadd_ps(_l2_regularization, _lr, _sqrt_norm);
      _denom = _mm256_add_ps(_denom, _beta);
      // numer = lr * zero * normwise
      __m256 _numer = _mm256_mul_ps(_lr, _zero);
      _numer = _mm256_mul_ps(_numer, _normwise);
      __m256 _new_num = _mm256_div_ps(_numer, _denom);
      _mm256_storeu_ps(numCopy, _new_num);
    }
  }
  if (lenCopy) {
    BaselineSetWeightsWithGroupNorm(
        group_zt_norm, numCopy, normCopy, zeroCopy, lenCopy, effective_lr,
        l1_regularization_strength, l2_regularization_strength, beta);
  }
}

// AVX version of BaseReduceSum: 8 floats per iteration, scalar tail.
inline void Avx256ReduceSum(const float* a, const float* b, float* output,
                            size_t len) {
  for (; len > 7; len -= 8, a += 8, b += 8, output += 8) {
    const __m256 _a = _mm256_loadu_ps(a);
    const __m256 _b = _mm256_loadu_ps(b);
    const __m256 _output = _mm256_add_ps(_a, _b);
    _mm256_storeu_ps(output, _output);
  }
  if (len) {
    BaseReduceSum(a, b, output, len);
  }
}
#endif

// Adagrad step; uses the AVX implementation when it was compiled in and the
// scalar one otherwise.
inline void AdagradOptimize(float* num, float* norm, const float* grad,
                            size_t len, float lr, float w_decay) {
#if defined(_ENABLE_AVX) && defined(__AVX__)
  Avx256AdagradOptimize(num, norm, grad, len, lr, w_decay);
#else
  BaselineAdagradOptimize(num, norm, grad, len, lr, w_decay);
#endif
}

// Element-wise sum; uses the AVX implementation when it was compiled in and
// the scalar one otherwise.
inline void ReduceSum(const float* a, const float* b, float* output,
                      size_t len) {
#if defined(_ENABLE_AVX) && defined(__AVX__)
  Avx256ReduceSum(a, b, output, len);
#else
  BaseReduceSum(a, b, output, len);
#endif
}

}  // namespace hash_table
}  // namespace monolith

#endif  //
MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_AVX_UTILS ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/batch_softmax_optimizer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/batch_softmax_optimizer.h" #include "absl/strings/str_format.h" #include "glog/logging.h" namespace monolith { namespace hash_table { namespace { class BatchSoftmaxOptimizer : public OptimizerInterface { public: explicit BatchSoftmaxOptimizer(BatchSoftmaxOptimizerConfig config) : config_(std::move(config)) { DCHECK_EQ(config_.dim_size(), 1); } int64_t SizeBytes() const override { return sizeof(int64_t); } int64_t UncompressedSizeBytes() const override { return SizeBytes(); } std::string DebugString() const override { return absl::StrFormat("BatchSoftmax(D=%d)", DimSize()); } int DimSize() const override { return config_.dim_size(); } int SliceSize() const override { return 1; } void Init(void* ctx) const override { auto* A = reinterpret_cast(ctx); *A = 0; } void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { float& B = num[0]; int64_t& A = *reinterpret_cast(ctx); float alpha = learning_rates[0]; B = (1 - alpha) * B + alpha * static_cast(global_step - A); if (global_step < 0) { LOG(FATAL) 
<< absl::StrFormat( "global_step=%ld is negative, please investigate!", global_step); } A = global_step; } OptimizerDump Save(const void* ctx) const override { OptimizerDump dump; BatchSoftmaxOptimizerDump* batch_softmax_dump = dump.add_dump()->mutable_batch_softmax(); int64_t A = *reinterpret_cast(ctx); batch_softmax_dump->set_global_step(A); return dump; } void Restore(void* ctx, OptimizerDump dump) const override { const BatchSoftmaxOptimizerDump& batch_softmax_dump = dump.dump(0).batch_softmax(); int64_t& A = *reinterpret_cast(ctx); A = batch_softmax_dump.global_step(); } private: BatchSoftmaxOptimizerConfig config_; }; } // namespace std::unique_ptr NewBatchSoftmaxOptimizer( BatchSoftmaxOptimizerConfig config) { return std::make_unique(std::move(config)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/batch_softmax_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_BATCH_SOFTMAX_OPTIMIZER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_BATCH_SOFTMAX_OPTIMIZER_H_ #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewBatchSoftmaxOptimizer( BatchSoftmaxOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_BATCH_SOFTMAX_OPTIMIZER_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/batch_softmax_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/batch_softmax_optimizer.h"

#include <cstdint>
#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h"
#include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Pointwise;

// Checks one update from the initial state, that Save() records the step of
// that update, and that an entry rebuilt via Restore() evolves exactly like
// the original entry on the following update.
TEST(BatchSoftmaxOptimizer, Basic) {
  BatchSoftmaxOptimizerConfig config;
  config.set_dim_size(1);
  auto opt = NewBatchSoftmaxOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  int64_t global_step = 1;
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {2.0f}, {0.1f},
                global_step);
  // 0.9 * 0 + 0.1 * (1 - 0)
  EXPECT_FLOAT_EQ(mem.num().front(), 0.1f);

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  EXPECT_EQ(dump.dump(0).batch_softmax().global_step(), global_step);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  // Advance to a later step so the update depends on the restored step.
  const int64_t next_step = global_step + 1;
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {2.0f}, {0.1f},
                next_step);
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {2.0f}, {0.1f},
                next_step);
  // 0.9 * 0.1 + 0.1 * (2 - 1)
  EXPECT_NEAR(mem.num().front(), 0.19f, 1e-6);
  EXPECT_THAT(mem.num(), ElementsAreArray(mem2.num()));
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/dc_optimizer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include "monolith/native_training/runtime/hash_table/optimizer/dc_optimizer.h" namespace monolith { namespace hash_table { namespace { class DcOptimizer : public OptimizerDecorator { public: explicit DcOptimizer(DcOptimizerConfig config, std::unique_ptr base_opt) : OptimizerDecorator(std::move(base_opt)), conf_(std::move(config)) { } void OptimizeWithLatestValue(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, absl::Span latest_value, const int64_t global_step) const override { std::vector compensated_g(conf_.dim_size()); // add in Float16 stuff later? for (int i = 0; i < conf_.dim_size(); ++i) { float new_grad = grad[i] + conf_.lambda_() * grad[i] * grad[i] * (num[i] - latest_value[i]); compensated_g[i] = new_grad; } base_opt_.get()->Optimize(ctx, num, compensated_g, learning_rates, global_step); } private: DcOptimizerConfig conf_; }; } // namespace std::unique_ptr NewDcOptimizer( DcOptimizerConfig config, std::unique_ptr base_opt) { return std::make_unique(std::move(config), std::move(base_opt)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/dc_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DC_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DC_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_decorator.h" namespace monolith { namespace hash_table { std::unique_ptr NewDcOptimizer( DcOptimizerConfig config, std::unique_ptr base_opt); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DC_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/dc_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" #include "monolith/native_training/runtime/hash_table/optimizer/dc_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/adadelta_optimizer.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; using ::testing::ElementsAreArray; TEST(DcOptimizer, Basic) { AdadeltaOptimizerConfig config1; config1.set_dim_size(1); auto opt1 = NewAdadeltaOptimizer(config1); DcOptimizerConfig config2; config2.set_dim_size(1); config2.set_lambda_(0.1f); auto opt2 = NewDcOptimizer(config2, std::move(opt1)); TestOptimizerEntry mem(opt2.get()); opt2->Init(mem.mutable_ctx()); float arr[] = {0.1f}; opt2->OptimizeWithLatestValue(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}, arr); auto expected = {-0.0031603f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt2->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).adadelta().accum(0), 8.1, 1e-4); TestOptimizerEntry mem2(opt2.get()); opt2->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); arr[0] = 0.0f; opt2->OptimizeWithLatestValue(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}, arr); auto expected2 = {-0.0065548f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } TEST(DcOptimizer, ListUpdate) { AdadeltaOptimizerConfig config1; config1.set_dim_size(2); auto opt1 = NewAdadeltaOptimizer(config1); DcOptimizerConfig config2; config2.set_dim_size(2); config2.set_lambda_(0.1f); auto opt2 = NewDcOptimizer(config2, std::move(opt1)); TestOptimizerEntry mem(opt2.get()); opt2->Init(mem.mutable_ctx()); float arr[] = {0.1f, 0.1f}; opt2->OptimizeWithLatestValue(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}, arr); auto expected = 
{-0.0031603f, -0.00301233f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt2->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).adadelta().accum(0), 8.1, 1e-4); EXPECT_NEAR(dump.dump(0).adadelta().accum(1), .09801, 1e-4); TestOptimizerEntry mem2(opt2.get()); opt2->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); arr[0] = 0.0f; arr[1] = 0.0f; opt2->OptimizeWithLatestValue(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}, arr); auto expected2 = {-0.0065548f, -0.00611400f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_adagrad_optimizer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_adagrad_optimizer.h" #include #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { std::unique_ptr NewDynamicWdAdagradOptimizer( DynamicWdAdagradOptimizerConfig config) { throw std::invalid_argument(absl::StrFormat( "optimizer is not implemented yet. 
%s", config.ShortDebugString())); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_adagrad_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DYNAMIC_WD_ADAGRAD_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DYNAMIC_WD_ADAGRAD_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewDynamicWdAdagradOptimizer( DynamicWdAdagradOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DYNAMIC_WD_ADAGRAD_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_adagrad_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_adagrad_optimizer.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; using ::testing::ElementsAreArray; } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_avx_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Created by david on 2020-11-27. 
// #include #include "absl/random/random.h" #include "gtest/gtest.h" #include "gmock/gmock.h" #include "monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_avx_utils.h" namespace monolith { namespace hash_table { namespace { void TestDynamicWdAdagradOptimize(size_t dim = 32, int step = 1, bool decouple_wd = false) { float lr = 0.01f; std::vector norm(dim, 0), grad(dim, 0); absl::BitGen bit_gen; for (size_t i = 0; i < dim; ++i) { norm[i] = absl::Uniform(bit_gen, .1f, 1.f); grad[i] = absl::Uniform(bit_gen, -1.f, 1.f); } std::vector norm2(norm.begin(), norm.end()), grad2(grad.begin(), grad.end()); std::vector result(dim, 0), result_avx(dim, 0); for (int i = 0; i < step; ++i) { BaselineDynamicWdAdagradOptimize(result.data(), norm.data(), grad.data(), dim, lr, 0.1f, decouple_wd); #if defined(_ENABLE_AVX) && defined(__AVX__) if (decouple_wd) { Avx256DynamicWdAdagradOptimizeDecoupleWd(result_avx.data(), norm2.data(), grad2.data(), dim, lr, 0.1f); } else { Avx256DynamicWdAdagradOptimize(result_avx.data(), norm2.data(), grad2.data(), dim, lr, 0.1f); } #else static_assert(false, "AVX is not available, please check and recompile!"); #endif } for (size_t i = 0; i < dim; ++i) { EXPECT_NEAR(result[i], result_avx[i], 1e-6); } } TEST(AVX, Basic) { TestDynamicWdAdagradOptimize(1); TestDynamicWdAdagradOptimize(7); TestDynamicWdAdagradOptimize(8); TestDynamicWdAdagradOptimize(16); TestDynamicWdAdagradOptimize(32); TestDynamicWdAdagradOptimize(39); TestDynamicWdAdagradOptimize(224); } TEST(AVX, DecoupleWd) { TestDynamicWdAdagradOptimize(1, /*step=*/ 5, true); TestDynamicWdAdagradOptimize(7, /*step=*/ 5, true); TestDynamicWdAdagradOptimize(8, /*step=*/ 5, true); TestDynamicWdAdagradOptimize(16, /*step=*/ 5, true); TestDynamicWdAdagradOptimize(32, /*step=*/ 5, true); TestDynamicWdAdagradOptimize(39, /*step=*/ 5, true); TestDynamicWdAdagradOptimize(224, /*step=*/ 5, true); } } // namespace } // namespace hash_table } // namespace monolith 
================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_avx_utils.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Created by david on 2020-11-27. // #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DYNAMIC_WD_AVX_UTILS #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DYNAMIC_WD_AVX_UTILS #if defined(_ENABLE_AVX) && defined(__AVX__) #include #endif namespace monolith { namespace hash_table { inline void BaselineDynamicWdAdagradOptimize(float* num, float* norm, const float* grad, size_t len, float lr, float w_decay, bool decouple_wd) { for (size_t i = 0; i < len; ++i) { float g = grad[i]; if (!decouple_wd) { g += w_decay * num[i]; } norm[i] += g * g; float effective_lr = lr / std::sqrt(norm[i]); float grad_update = effective_lr * g; if (decouple_wd) { grad_update += lr * w_decay * num[i]; } num[i] -= grad_update; } } inline void BaseReduceSum(const float* a, const float* b, float* output, size_t len) { for (size_t i = 0; i < len; ++i) { output[i] = a[i] + b[i]; } } #if defined(_ENABLE_AVX) && defined(__AVX__) inline void Avx256DynamicWdAdagradOptimize(float* num, float* norm, const float* grad, size_t len, float lr, float w_decay) { const __m256 lamda = _mm256_set1_ps(w_decay); float lrs[8] = {lr, lr, lr, lr, lr, lr, lr, lr}; const __m256 _lr = _mm256_loadu_ps(lrs); // 
OPTIMIZE: Loads floating-point vector from an aligned memory address for (; len > 7; len -= 8, num += 8, norm += 8, grad += 8) { const __m256 _num = _mm256_loadu_ps(num); const __m256 _norm = _mm256_loadu_ps(norm); const __m256 _grad = _mm256_loadu_ps(grad); const __m256 updated_grad = _mm256_fmadd_ps(lamda, _num, _grad); __m256 _norm_new = _mm256_fmadd_ps(updated_grad, updated_grad, _norm); _mm256_storeu_ps(norm, _norm_new); const __m256 _norm_new_sqrt = _mm256_sqrt_ps(_norm_new); const __m256 effective_lr = _mm256_div_ps(_lr, _norm_new_sqrt); __m256 _num_new = _mm256_fnmadd_ps(effective_lr, _grad, _num); _mm256_storeu_ps(num, _num_new); } if (len) { BaselineDynamicWdAdagradOptimize(num, norm, grad, len, lr, w_decay, false); } } inline void Avx256DynamicWdAdagradOptimizeDecoupleWd(float* num, float* norm, const float* grad, size_t len, float lr, float w_decay) { const __m256 lamda = _mm256_set1_ps(w_decay); float lrs[8] = {lr, lr, lr, lr, lr, lr, lr, lr}; const __m256 _lr = _mm256_loadu_ps(lrs); // OPTIMIZE: Loads floating-point vector from an aligned memory address for (; len > 7; len -= 8, num += 8, norm += 8, grad += 8) { const __m256 _num = _mm256_loadu_ps(num); const __m256 _norm = _mm256_loadu_ps(norm); const __m256 _grad = _mm256_loadu_ps(grad); __m256 _norm_new = _mm256_fmadd_ps(_grad, _grad, _norm); _mm256_storeu_ps(norm, _norm_new); const __m256 _norm_new_sqrt = _mm256_sqrt_ps(_norm_new); const __m256 effective_lr = _mm256_div_ps(_lr, _norm_new_sqrt); __m256 _num_new = _mm256_fnmadd_ps(effective_lr, _grad, _num); const __m256 effective_wd = _mm256_mul_ps(_lr, lamda); __m256 _num_after_wd = _mm256_fnmadd_ps(effective_wd, _num, _num_new); _mm256_storeu_ps(num, _num_after_wd); } if (len) { BaselineDynamicWdAdagradOptimize(num, norm, grad, len, lr, w_decay, true); } } inline void Avx256ReduceSum(const float* a, const float* b, float* output, size_t len) { for (; len > 7; len -= 8, a += 8, b += 8, output += 8) { const __m256 _a = _mm256_loadu_ps(a); const 
__m256 _b = _mm256_loadu_ps(b); const __m256 _output = _mm256_add_ps(_a, _b); _mm256_storeu_ps(output, _output); } if (len) { BaseReduceSum(a, b, output, len); } } #endif inline void DynamicWdAdagradOptimize(float* num, float* norm, const float* grad, size_t len, float lr, float w_decay, bool decouple_wd = false) { #if defined(_ENABLE_AVX) && defined(__AVX__) if (decouple_wd) { Avx256DynamicWdAdagradOptimizeDecoupleWd(num, norm, grad, len, lr, w_decay); } else { Avx256DynamicWdAdagradOptimize(num, norm, grad, len, lr, w_decay); } #else BaselineDynamicWdAdagradOptimize( num, norm, grad, len, lr, w_decay, decouple_wd); #endif } inline void ReduceSum(const float* a, const float* b, float* output, size_t len) { #if defined(_ENABLE_AVX) && defined(__AVX__) Avx256ReduceSum(a, b, output, len); #else BaseReduceSum(a, b, output, len); #endif } } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_DYNAMIC_WD_AVX_UTILS ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/ftrl_optimizer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_table/optimizer/ftrl_optimizer.h" namespace monolith { namespace hash_table { namespace { class FtrlOptimizer : public OptimizerInterface { public: explicit FtrlOptimizer(FtrlOptimizerConfig config) : conf_(std::move(config)) {} // We need both Zero and Norm in the opt state. int64_t SizeBytes() const override { return 2 * conf_.dim_size() * sizeof(float); } int64_t UncompressedSizeBytes() const override { return SizeBytes(); } std::string DebugString() const override { return absl::StrFormat("Ftrl(D=%d)", DimSize()); } int DimSize() const override { return conf_.dim_size(); } int SliceSize() const override { return 1; } void Init(void* ctx) const override { float* norm = static_cast(ctx); float* zero = norm + conf_.dim_size(); for (int i = 0; i < conf_.dim_size(); ++i) { norm[i] = conf_.initial_accumulator_value(); zero[i] = 0; } } // Please refer to this link for the algorithm: // https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { float* norm = static_cast(ctx); float* zero = norm + conf_.dim_size(); float effective_lr = learning_rates[0]; for (int i = 0; i < conf_.dim_size(); ++i) { auto norm_new = norm[i] + grad[i] * grad[i]; auto sigma = (std::sqrt(norm_new) - std::sqrt(norm[i])) / effective_lr; zero[i] += (grad[i] - sigma * num[i]); norm[i] = norm_new; num[i] = (std::abs(zero[i]) > conf_.l1_regularization_strength()) ? 
effective_lr * (std::signbit(zero[i]) * conf_.l1_regularization_strength() - zero[i]) / (std::sqrt(norm[i]) + conf_.beta() + conf_.l2_regularization_strength() * effective_lr) : 0.0; } } OptimizerDump Save(const void* ctx) const override { OptimizerDump dump; FtrlOptimizerDump* ftrl_dump = dump.add_dump()->mutable_ftrl(); const float* norm = static_cast(ctx); const float* zero = norm + conf_.dim_size(); for (int i = 0; i < conf_.dim_size(); ++i) { ftrl_dump->add_norm(norm[i]); ftrl_dump->add_zero(zero[i]); } return dump; } void Restore(void* ctx, OptimizerDump dump) const override { const FtrlOptimizerDump& ftrl_dump = dump.dump(0).ftrl(); float* norm = static_cast(ctx); float* zero = norm + conf_.dim_size(); for (int i = 0; i < conf_.dim_size(); ++i) { norm[i] = ftrl_dump.norm(i); zero[i] = ftrl_dump.zero(i); } } private: FtrlOptimizerConfig conf_; }; } // namespace std::unique_ptr NewFtrlOptimizer( FtrlOptimizerConfig config) { return std::make_unique(std::move(config)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/ftrl_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_FTRL_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_FTRL_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewFtrlOptimizer( FtrlOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_FTRL_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/ftrl_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/ftrl_optimizer.h"

#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h"
#include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::Pointwise;
using ::testing::FloatNear;
using ::testing::ElementsAreArray;

// One FTRL step on a single dimension, a save/restore round trip, then a
// second step on both the original and the restored entry.
TEST(FtrlOptimizer, Basic) {
  FtrlOptimizerConfig config;
  config.set_dim_size(1);
  auto opt = NewFtrlOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f});
  auto expected = {-0.009995f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  // EXPECT_NEAR takes (actual, expected, abs_error). The accumulator is a
  // float, so 100.1 is only representable to ~1.5e-6; use the 1e-4 tolerance
  // the other optimizer tests use for dumped accumulators.
  EXPECT_NEAR(dump.dump(0).ftrl().norm(0), 100.1, 1e-4);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();

  auto expected2 = {-0.0170643f};
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f});
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2));
  // The restored entry must reproduce the same step; previously it was
  // filled in but never exercised.
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {10.0f}, {0.01f});
  ASSERT_THAT(mem2.num(), Pointwise(FloatNear(1e-6), expected2));
}

// Same scenario as Basic with two dimensions and different gradients.
TEST(FtrlOptimizer, ListUpdate) {
  FtrlOptimizerConfig config;
  config.set_dim_size(2);
  auto opt = NewFtrlOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected = {-0.009995f, -0.00953463f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  EXPECT_NEAR(dump.dump(0).ftrl().norm(0), 100.1, 1e-4);
  EXPECT_NEAR(dump.dump(0).ftrl().norm(1), 1.1, 1e-4);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();

  auto expected2 = {-0.0170643f, -0.0164353f};
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2));
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  ASSERT_THAT(mem2.num(), Pointwise(FloatNear(1e-6), expected2));
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/group_adagrad_optimizer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <memory>
#include <vector>

#include "absl/strings/str_format.h"
#include "monolith/native_training/runtime/hash_table/optimizer/group_adagrad_optimizer.h"

namespace monolith {
namespace hash_table {
namespace {

// AdaGrad variant that keeps ONE accumulator for the whole embedding (the
// running sum of the largest squared gradient component per step) and applies
// a group-level L2 shrinkage that can zero the entire vector.
class GroupAdaGradOptimizer : public OptimizerInterface {
 public:
  explicit GroupAdaGradOptimizer(GroupAdaGradOptimizerConfig config)
      : conf_(std::move(config)) {}

  // Only need 4 byte, grad_square_sum = sum(g_max^2)
  int64_t SizeBytes() const override { return sizeof(float); }

  int64_t UncompressedSizeBytes() const override {
    return conf_.dim_size() * sizeof(float);
  }

  std::string DebugString() const override {
    return absl::StrFormat("GroupAdaGrad(D=%d)", DimSize());
  }

  int DimSize() const override { return conf_.dim_size(); }

  int SliceSize() const override { return 1; }

  // Sets the single accumulator to the configured initial value.
  void Init(void* ctx) const override {
    float* grad_square_sum = static_cast<float*>(ctx);
    *grad_square_sum = conf_.initial_accumulator_value();
  }

  // One update step. With g = grad + weight_decay_factor * w:
  //   accumulator += max_i(g_i^2)
  //   lr = learning_rates[0] / (beta + sqrt(accumulator))
  //   z  = g - w / lr
  //   w  = 0                                   if ||z|| <= l2
  //      = -lr * (||z|| - l2) / ||z|| * z      otherwise
  // Only learning_rates[0] is read; `global_step` is unused.
  void Optimize(void* ctx, absl::Span<float> num, absl::Span<const float> grad,
                absl::Span<const float> learning_rates,
                const int64_t global_step) const override {
    float* grad_square_sum = static_cast<float*>(ctx);
    float effective_lr = learning_rates[0];
    float max_grad_square = 0.0;
    std::vector<float> g_decayed;
    g_decayed.reserve(conf_.dim_size());
    for (int i = 0; i < conf_.dim_size(); ++i) {
      // weight_decay
      float g = grad[i] + conf_.weight_decay_factor() * num[i];
      if (g * g > max_grad_square) {
        max_grad_square = g * g;
      }
      g_decayed.push_back(g);
    }
    *grad_square_sum = *grad_square_sum + max_grad_square;

    // A learning rate of exactly zero would divide the weights by zero below
    // and turn them into inf/NaN. The update converges to "weights unchanged"
    // as the rate goes to zero, so keep the accumulator update and stop here.
    if (effective_lr == 0.0f) {
      return;
    }

    float lr = effective_lr / (conf_.beta() + std::sqrt(*grad_square_sum));
    float z_norm = 0.0;
    for (int i = 0; i < conf_.dim_size(); ++i) {
      num[i] = g_decayed[i] - num[i] / lr;
      z_norm += num[i] * num[i];
    }
    z_norm = std::sqrt(z_norm);
    // "<=" rather than "<": when z_norm equals the threshold the scaled
    // result is zero anyway, and when both are 0 the else-branch would
    // compute 0 / 0 and write NaN into the weights.
    if (z_norm <= conf_.l2_regularization_strength()) {
      for (int i = 0; i < conf_.dim_size(); ++i) {
        num[i] = 0;
      }
    } else {
      float coeffi =
          -lr * (z_norm - conf_.l2_regularization_strength()) / z_norm;
      for (int i = 0; i < conf_.dim_size(); ++i) {
        num[i] = coeffi * num[i];
      }
    }
  }

  // Serialises the single accumulator.
  OptimizerDump Save(const void* ctx) const override {
    OptimizerDump dump;
    GroupAdaGradOptimizerDump* group_adagrad_dump =
        dump.add_dump()->mutable_group_adagrad();
    const float* grad_square_sum = static_cast<const float*>(ctx);
    group_adagrad_dump->set_grad_square_sum(*grad_square_sum);
    return dump;
  }

  // Loads the accumulator from dump.dump(0).
  void Restore(void* ctx, OptimizerDump dump) const override {
    const GroupAdaGradOptimizerDump& group_adagrad_dump =
        dump.dump(0).group_adagrad();
    float* grad_square_sum = static_cast<float*>(ctx);
    *grad_square_sum = group_adagrad_dump.grad_square_sum();
  }

 private:
  GroupAdaGradOptimizerConfig conf_;
};

}  // namespace

// Creates a GroupAdaGrad optimizer configured by `config`.
std::unique_ptr<OptimizerInterface> NewGroupAdaGradOptimizer(
    GroupAdaGradOptimizerConfig config) {
  return std::make_unique<GroupAdaGradOptimizer>(std::move(config));
}

}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/group_adagrad_optimizer.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_GROUP_ADAGRAD_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_GROUP_ADAGRAD_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewGroupAdaGradOptimizer( GroupAdaGradOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_GROUP_ADAGRAD_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/group_adagrad_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/group_adagrad_optimizer.h"

#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h"
#include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Pointwise;

// Single dimension: one step, save/restore, then a second step on the
// restored entry.
TEST(GroupAdaGradOptimizer, Basic) {
  GroupAdaGradOptimizerConfig config;
  config.set_dim_size(1);
  config.set_l2_regularization_strength(1.0);
  config.set_beta(1.0);
  config.set_initial_accumulator_value(0.0);
  config.set_weight_decay_factor(0.0);
  auto opt = NewGroupAdaGradOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f});
  auto expected = {-0.008182f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  // EXPECT_NEAR takes (actual, expected, abs_error): the accumulator after
  // one step with gradient 10 is 10^2 = 100.
  EXPECT_NEAR(dump.dump(0).group_adagrad().grad_square_sum(), 100.0, 1e-6);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {10.0f}, {0.01f});
  auto expected2 = {-0.014125f};
  ASSERT_THAT(mem2.num(), Pointwise(FloatNear(1e-6), expected2));
}

// Two dimensions: the accumulator tracks the largest squared component.
TEST(GroupAdaGradOptimizer, ListUpdate) {
  GroupAdaGradOptimizerConfig config;
  config.set_dim_size(2);
  config.set_l2_regularization_strength(0.5);
  config.set_beta(1.0);
  config.set_initial_accumulator_value(0.0);
  config.set_weight_decay_factor(0.0);
  auto opt = NewGroupAdaGradOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected = {-0.008639f, -0.000864f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));

  // Test dump & restore
  OptimizerDump dump = opt->Save(mem.ctx());
  // max(10^2, 1^2) = 100
  EXPECT_NEAR(dump.dump(0).group_adagrad().grad_square_sum(), 100.0, 1e-6);
  TestOptimizerEntry mem2(opt.get());
  opt->Restore(mem2.mutable_ctx(), dump);
  *mem2.mutable_num() = mem.num();
  opt->Optimize(mem2.mutable_ctx(), mem2.mutable_num_span(), {1.0f, 5.0f},
                {0.01f});
  auto expected2 = {-0.009096f, -0.004778f};
  ASSERT_THAT(mem2.num(), Pointwise(FloatNear(1e-6), expected2));
}

// With no L2 regularisation the update is a plain scaled gradient step.
TEST(GroupAdaGradOptimizer, ZeroLambda) {
  GroupAdaGradOptimizerConfig config;
  config.set_dim_size(2);
  config.set_l2_regularization_strength(0);
  config.set_beta(1.0);
  config.set_initial_accumulator_value(0.0);
  config.set_weight_decay_factor(0.0);
  auto opt = NewGroupAdaGradOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected = {-0.009091f, -0.000909f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));
}

// A very large L2 threshold zeroes the whole vector.
TEST(GroupAdaGradOptimizer, SetZero) {
  GroupAdaGradOptimizerConfig config;
  config.set_dim_size(2);
  config.set_l2_regularization_strength(1000);
  config.set_beta(1.0);
  config.set_initial_accumulator_value(0.0);
  config.set_weight_decay_factor(0.0);
  auto opt = NewGroupAdaGradOptimizer(config);
  TestOptimizerEntry mem(opt.get());
  opt->Init(mem.mutable_ctx());
  opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f},
                {0.01f});
  auto expected = {0.0f, 0.0f};
  ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected));
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/hash_table/optimizer/group_ftrl_optimizer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/group_ftrl_optimizer.h" #include #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { std::unique_ptr NewGroupFtrlOptimizer( GroupFtrlOptimizerConfig config) { throw std::invalid_argument(absl::StrFormat( "optimizer is not implemented yet. %s", config.ShortDebugString())); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/group_ftrl_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_GROUP_FTRL_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_GROUP_FTRL_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewGroupFtrlOptimizer( GroupFtrlOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_GROUP_FTRL_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/group_ftrl_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/group_ftrl_optimizer.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; using ::testing::ElementsAreArray; } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/momentum_optimizer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/momentum_optimizer.h" #include #include #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { namespace { class MomentumOptimizer : public OptimizerInterface { public: explicit MomentumOptimizer(MomentumOptimizerConfig config) : conf_(std::move(config)) {} int64_t SizeBytes() const override { return (conf_.dim_size()) * sizeof(float); } int64_t UncompressedSizeBytes() const override { return SizeBytes(); } std::string DebugString() const override { return absl::StrFormat("Momentum(D=%d)", DimSize()); } int DimSize() const override { return conf_.dim_size(); } int SliceSize() const override { return 1; } void Init(void* ctx) const override { float* n = static_cast(ctx); for (int i = 0; i < conf_.dim_size(); ++i) { n[i] = 0; } } void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { float* n = static_cast(ctx); float g_total = 0; for (int i = 0; i < conf_.dim_size(); ++i) { float dx = learning_rates[0] * (grad[i] + conf_.weight_decay_factor() * num[i]); float new_n = n[i]; float new_w = num[i]; if (conf_.use_nesterov()) { float prev_n = new_n; new_n = conf_.momentum() * new_n - dx; new_w += -conf_.momentum() * prev_n + (1 + conf_.momentum()) * new_n; } else { new_n = conf_.momentum() * new_n - dx; new_w += new_n; } n[i] = new_n; num[i] = new_w; g_total += new_n; } } OptimizerDump Save(const void* ctx) const override { OptimizerDump dump; MomentumOptimizerDump* momentum_dump = dump.add_dump()->mutable_momentum(); const float* n = static_cast(ctx); for (int i = 0; i < conf_.dim_size(); ++i) { momentum_dump->add_n(n[i]); } return dump; } void Restore(void* ctx, OptimizerDump dump) const override { const MomentumOptimizerDump& momentum_dump = dump.dump(0).momentum(); float* n = static_cast(ctx); for (int i = 0; i < conf_.dim_size(); ++i) { n[i] = momentum_dump.n(i); } } private: MomentumOptimizerConfig conf_; 
}; } // namespace std::unique_ptr NewMomentumOptimizer( MomentumOptimizerConfig config) { return std::make_unique(std::move(config)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/momentum_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_MOMENTUM_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_MOMENTUM_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewMomentumOptimizer( MomentumOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_MOMENTUM_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/momentum_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/momentum_optimizer.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; using ::testing::ElementsAreArray; TEST(MomentumOptimizer, Basic) { MomentumOptimizerConfig config; config.set_dim_size(1); auto opt = NewMomentumOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected = {-0.1f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).momentum().n(0), -0.1, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected2 = {-0.29f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } TEST(MomentumOptimizer, ListUpdate) { MomentumOptimizerConfig config; config.set_dim_size(2); auto opt = NewMomentumOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto expected = {-0.1f, -0.01f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); 
// Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).momentum().n(0), -0.1, 1e-4); EXPECT_NEAR(dump.dump(0).momentum().n(1), -0.01, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto expected2 = {-0.29f, -0.029f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/moving_average_optimizer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/moving_average_optimizer.h" #include #include #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { namespace { class MovingAverageOptimizer : public OptimizerInterface { public: explicit MovingAverageOptimizer(MovingAverageOptimizerConfig config) : conf_(std::move(config)) {} int64_t SizeBytes() const override { return 0; } int64_t UncompressedSizeBytes() const override { return SizeBytes(); } std::string DebugString() const override { return absl::StrFormat("MovingAverage(D=%d)", DimSize()); } int DimSize() const override { return conf_.dim_size(); } int SliceSize() const override { return 1; } void Init(void* ctx) const override {} void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { for (int i = 0; i < conf_.dim_size(); ++i) { float new_w = conf_.momentum() * num[i] + (1 - conf_.momentum()) * grad[i]; num[i] = new_w; } } OptimizerDump Save(const void* ctx) const override { OptimizerDump dump; return dump; } void Restore(void* ctx, OptimizerDump dump) const override {} private: MovingAverageOptimizerConfig conf_; }; } // namespace std::unique_ptr NewMovingAverageOptimizer( MovingAverageOptimizerConfig config) { return std::make_unique(std::move(config)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/moving_average_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_MOVING_AVERAGE_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_MOVING_AVERAGE_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewMovingAverageOptimizer( MovingAverageOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_MOVING_AVERAGE_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/moving_average_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/moving_average_optimizer.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; using ::testing::ElementsAreArray; TEST(MovingAverageOptimizer, Basic) { MovingAverageOptimizerConfig config; config.set_dim_size(1); auto opt = NewMovingAverageOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected = {1.0f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); // EXPECT_NEAR(dump.dump(0).momentum().n(0), -0.1, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected2 = {1.9f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } TEST(MovingAverageOptimizer, ListUpdate) { MovingAverageOptimizerConfig config; config.set_dim_size(2); auto opt = NewMovingAverageOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto expected = {1.0f, 0.1f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); // EXPECT_NEAR(dump.dump(0).momentum().n(0), -0.1, 1e-4); // EXPECT_NEAR(dump.dump(0).momentum().n(1), -0.01, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto 
expected2 = {1.9f, 0.19f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/optimizer.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; package monolith.hash_table; message AdagradOptimizerConfig { optional int32 dim_size = 1; optional float learning_rate = 2 [default = 0.001]; optional float initial_accumulator_value = 3 [default = 0.1]; optional int32 hessian_compression_times = 4 [default = 1]; optional float weight_decay_factor = 5 [default = 0.]; optional int64 warmup_steps = 6 [default = 0]; } message AdagradOptimizerDump { repeated float norm = 1; } message DynamicWdAdagradOptimizerConfig { optional int32 dim_size = 1; optional float learning_rate = 2 [default = 0.001]; optional float initial_accumulator_value = 3 [default = 0.1]; optional int32 hessian_compression_times = 4 [default = 1]; optional float weight_decay_factor = 5 [default = 0.]; optional int64 warmup_steps = 6 [default = 0]; optional bool decouple_weight_decay = 7 [default = false]; optional bool enable_dynamic_wd = 8 [default = false]; optional float dynamic_wd_temperature = 9 [default = 1.0]; optional bool flip_direction = 10 [default = false]; } message DynamicWdAdagradOptimizerDump { 
  // (Remaining fields of DynamicWdAdagradOptimizerDump, whose header precedes
  // this span.)
  repeated float norm = 1;
  optional int64 last_update_step = 2;
}

// Config message for the SGD optimizer.
message SgdOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional int64 warmup_steps = 6 [default = 0];
}

// Dump message for the SGD optimizer; it has no fields.
message SgdOptimizerDump {
}

// Config message for the FTRL optimizer.
message FtrlOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float beta = 3 [default = 0.0];
  optional float initial_accumulator_value = 4 [default = 0.1];
  optional float l1_regularization_strength = 5 [default = 0.0];
  optional float l2_regularization_strength = 6 [default = 0.0];
  optional int64 warmup_steps = 7 [default = 0];
}

// Dump message for the FTRL optimizer.
message FtrlOptimizerDump {
  repeated float zero = 1;
  repeated float norm = 2;
}

// Config message for the GroupFtrl optimizer. Same fields as
// FtrlOptimizerConfig, but `beta` defaults to 1.0 and
// `initial_accumulator_value` defaults to 0.0.
message GroupFtrlOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float beta = 3 [default = 1.0];
  optional float initial_accumulator_value = 4 [default = 0.0];
  optional float l1_regularization_strength = 5 [default = 0.0];
  optional float l2_regularization_strength = 6 [default = 0.0];
  optional int64 warmup_steps = 7 [default = 0];
}

// Dump message for the GroupFtrl optimizer.
message GroupFtrlOptimizerDump {
  repeated float zero = 1;
  repeated float norm = 2;
}

// Config message for the GroupAdaGrad optimizer.
message GroupAdaGradOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float beta = 3 [default = 0.0];
  optional float initial_accumulator_value = 4 [default = 0.1];
  optional float l2_regularization_strength = 5 [default = 0.0];
  optional float weight_decay_factor = 6 [default = 0.0];
  optional int64 warmup_steps = 7 [default = 0];
}

// Dump message for the GroupAdaGrad optimizer: a single scalar accumulator
// per entry.
message GroupAdaGradOptimizerDump {
  optional float grad_square_sum = 1;
}

// Config message for the Adadelta optimizer. Tag 6 is not used.
message AdadeltaOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float weight_decay_factor = 3 [default = 0.];
  optional float averaging_ratio = 4 [default = 0.9];
  optional float epsilon = 5 [default = 0.01];
  optional int64 warmup_steps = 7 [default = 0];
}

// Dump message for the Adadelta optimizer.
message AdadeltaOptimizerDump {
  repeated float accum = 1;
  repeated float accum_update = 2;
}

// Config message for the Adam optimizer.
message AdamOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float beta1 = 3 [default = 0.9];
  optional float beta2 = 4 [default = 0.99];
  optional bool use_beta1_warmup = 5 [default = false];
  optional float weight_decay_factor = 6 [default = 0.];
  optional bool use_nesterov = 7 [default = false];
  optional float epsilon = 8 [default = 0.01];
  optional int64 warmup_steps = 9 [default = 0];
}

// Dump message for the Adam optimizer.
message AdamOptimizerDump {
  repeated float m = 1;
  repeated float v = 2;
  optional float beta1_power = 3;
  optional float beta2_power = 4;
}

// Config message for the AMSGrad optimizer. Tag 5 is not used.
message AmsgradOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float beta1 = 3 [default = 0.9];
  optional float beta2 = 4 [default = 0.99];
  optional float weight_decay_factor = 6 [default = 0.];
  optional bool use_nesterov = 7 [default = false];
  optional float epsilon = 8 [default = 0.01];
  optional int64 warmup_steps = 9 [default = 0];
}

// Dump message for the AMSGrad optimizer.
message AmsgradOptimizerDump {
  repeated float m = 1;
  repeated float v = 2;
  repeated float vhat = 3;
  optional float beta1_power = 4;
  optional float beta2_power = 5;
}

// Config message for the Momentum optimizer.
message MomentumOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float weight_decay_factor = 3 [default = 0.];
  optional bool use_nesterov = 4 [default = false];
  optional float momentum = 5 [default = 0.9];
  optional int64 warmup_steps = 6 [default = 0];
}

// Dump message for the Momentum optimizer: one accumulator value per
// dimension.
message MomentumOptimizerDump {
  repeated float n = 1;
}

// Config message for the MovingAverage optimizer. There is no matching dump
// message in this file.
message MovingAverageOptimizerConfig {
  optional int32 dim_size = 1;
  optional float momentum = 2 [default = 0.9];
}

// Config message for the BatchSoftmax optimizer.
message BatchSoftmaxOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.1];
}

// Dump message for the BatchSoftmax optimizer.
message BatchSoftmaxOptimizerDump {
  optional int64 global_step = 1;
}

// Config message for the RMSprop optimizer.
message RmspropOptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float weight_decay_factor = 3 [default = 0.];
  optional float momentum = 4 [default = 0.9];
}

// Dump message for the RMSprop optimizer.
message RmspropOptimizerDump {
  repeated float n = 1;
}

// Config message for the RMSprop V2 optimizer; same fields and defaults as
// RmspropOptimizerConfig.
message RmspropV2OptimizerConfig {
  optional int32 dim_size = 1;
  optional float learning_rate = 2 [default = 0.01];
  optional float weight_decay_factor = 3 [default = 0.];
  optional float momentum = 4 [default = 0.9];
}

// Dump message for the RMSprop V2 optimizer.
message RmspropV2OptimizerDump {
  repeated float n = 1;
}

// Config message for the "dc" optimizer. There is no matching dump message in
// this file.
message DcOptimizerConfig {
  optional int32 dim_size = 1;
  optional float lambda_ = 2 [default = 0.];
}

// Top-level optimizer selection: at most one concrete config is set through
// the `type` oneof. Tag 4 belongs to `stochastic_rounding_float16`, which
// lives outside the oneof, so the oneof tags skip it.
message OptimizerConfig {
  oneof type {
    AdagradOptimizerConfig adagrad = 1;
    SgdOptimizerConfig sgd = 2;
    FtrlOptimizerConfig ftrl = 3;
    DynamicWdAdagradOptimizerConfig dynamic_wd_adagrad = 5;
    AdadeltaOptimizerConfig adadelta = 6;
    AdamOptimizerConfig adam = 7;
    AmsgradOptimizerConfig amsgrad = 8;
    MomentumOptimizerConfig momentum = 9;
    MovingAverageOptimizerConfig moving_average = 10;
    RmspropOptimizerConfig rmsprop = 11;
    RmspropV2OptimizerConfig rmspropv2 = 12;
    DcOptimizerConfig dc = 13;
    GroupFtrlOptimizerConfig group_ftrl = 14;
    BatchSoftmaxOptimizerConfig batch_softmax = 15;
    GroupAdaGradOptimizerConfig group_adagrad = 16;
  }
  optional bool stochastic_rounding_float16 = 4;  // Default false.
}

// Serialized state of one optimizer. Tag numbers here do not mirror those of
// OptimizerConfig (for example dynamic_wd_adagrad is 4 here and 5 there,
// group_ftrl is 13 here and 14 there), and there are no entries for
// moving_average or dc.
message SingleOptimizerDump {
  oneof type {
    AdagradOptimizerDump adagrad = 1;
    SgdOptimizerDump sgd = 2;
    FtrlOptimizerDump ftrl = 3;
    DynamicWdAdagradOptimizerDump dynamic_wd_adagrad = 4;
    AdadeltaOptimizerDump adadelta = 6;
    AdamOptimizerDump adam = 7;
    AmsgradOptimizerDump amsgrad = 8;
    MomentumOptimizerDump momentum = 9;
    RmspropOptimizerDump rmsprop = 11;
    RmspropV2OptimizerDump rmspropv2 = 12;
    GroupFtrlOptimizerDump group_ftrl = 13;
    BatchSoftmaxOptimizerDump batch_softmax = 14;
    GroupAdaGradOptimizerDump group_adagrad = 15;
  }
}

// TODO(leqi.zou): Consider using an Arena to improve performance.
message OptimizerDump { repeated SingleOptimizerDump dump = 1; } ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/optimizer_combination.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_combination.h" #include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" namespace monolith { namespace hash_table { namespace { namespace proto2 = google::protobuf; class CombinedOptimizer : public OptimizerInterface { public: CombinedOptimizer(std::unique_ptr opt1, std::unique_ptr opt2) : opt1_(std::move(opt1)), size_bytes1_(opt1_->SizeBytes()), dim_size1_(opt1_->DimSize()), dump_size1_(GetOptDumpSize(opt1_.get())), slice_size1_(opt1_->SliceSize()), opt2_(std::move(opt2)) {} int64_t SizeBytes() const override { return opt1_->SizeBytes() + opt2_->SizeBytes(); } int64_t UncompressedSizeBytes() const override { return opt1_->UncompressedSizeBytes() + opt2_->UncompressedSizeBytes(); } std::string DebugString() const override { return absl::StrFormat("%s|%s", opt1_->DebugString(), opt2_->DebugString()); } int DimSize() const override { return opt1_->DimSize() + opt2_->DimSize(); } int SliceSize() const override { return opt1_->SliceSize() + opt2_->SliceSize(); } void Init(void* ctx) const override { void* ctx2 = 
static_cast(ctx) + size_bytes1_; opt1_->Init(ctx); opt2_->Init(ctx2); } void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { void* ctx2 = static_cast(ctx) + size_bytes1_; auto num2 = num.subspan(dim_size1_); auto grad2 = grad.subspan(dim_size1_); auto learning_rates2 = learning_rates.subspan(slice_size1_); opt1_->Optimize(ctx, num, grad, learning_rates, global_step); opt2_->Optimize(ctx2, num2, grad2, learning_rates2, global_step); } OptimizerDump Save(const void* ctx) const override { OptimizerDump combined_dump; OptimizerDump dump1 = opt1_->Save(ctx); const void* ctx2 = static_cast(ctx) + size_bytes1_; OptimizerDump dump2 = opt2_->Save(ctx2); absl::c_move(*dump1.mutable_dump(), proto2::RepeatedFieldBackInserter( combined_dump.mutable_dump())); absl::c_move(*dump2.mutable_dump(), proto2::RepeatedFieldBackInserter( combined_dump.mutable_dump())); return combined_dump; } void Restore(void* ctx, OptimizerDump dump) const override { OptimizerDump dump1; for (int i = 0; i < dump_size1_; ++i) { *dump1.add_dump() = std::move(*dump.mutable_dump(i)); } OptimizerDump dump2; for (int i = dump_size1_; i < dump.dump_size(); ++i) { *dump2.add_dump() = std::move(*dump.mutable_dump(i)); } opt1_->Restore(ctx, std::move(dump1)); void* ctx2 = static_cast(ctx) + size_bytes1_; opt2_->Restore(ctx2, std::move(dump2)); } private: int GetOptDumpSize(OptimizerInterface* opt) { auto mem = std::make_unique(opt->SizeBytes()); opt->Init(mem.get()); OptimizerDump dump = opt->Save(mem.get()); return dump.dump_size(); } std::unique_ptr opt1_; const int64_t size_bytes1_; const int dim_size1_; const int dump_size1_; const int slice_size1_; std::unique_ptr opt2_; }; } // namespace std::unique_ptr CombineOptimizers( std::unique_ptr opt1, std::unique_ptr opt2) { return std::make_unique(std::move(opt1), std::move(opt2)); } } // namespace hash_table } // namespace monolith ================================================ FILE: 
monolith/native_training/runtime/hash_table/optimizer/optimizer_combination.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_COMBINATION #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_COMBINATION #include #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { // A entry may be optimized by different optimizers so we need to combine two // optimizers. std::unique_ptr CombineOptimizers( std::unique_ptr opt1, std::unique_ptr opt2); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_COMBINATION ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/optimizer_combination_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_combination.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/adagrad_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; TEST(CombineOptimizers, Basic) { AdagradOptimizerConfig config1; config1.set_dim_size(1); config1.set_initial_accumulator_value(1); auto opt1 = NewAdagradOptimizer(config1); AdagradOptimizerConfig config2; config2.set_dim_size(2); config2.set_initial_accumulator_value(2); auto opt2 = NewAdagradOptimizer(config2); auto combined_opt = CombineOptimizers(std::move(opt1), std::move(opt2)); OptimizerDump dump; { TestOptimizerEntry mem(combined_opt.get()); combined_opt->Init(mem.mutable_ctx()); combined_opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f, 2.0f, 3.0f}, {1.0f, 2.0f}); auto expected = {-0.70710677, -1.6329931, -1.8090681}; EXPECT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); dump = combined_opt->Save(mem.ctx()); } TestOptimizerEntry mem(combined_opt.get()); combined_opt->Restore(mem.mutable_ctx(), dump); combined_opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f, 2.0f, 3.0f}, {1.0f, 2.0f}); auto expected = {-0.57735026, -1.264911, -1.3416407}; EXPECT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); } } // namespace } // 
namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/optimizer_decorator.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_DECORATOR #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_DECORATOR #include #include "absl/types/span.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "optimizer_interface.h" namespace monolith { namespace hash_table { class OptimizerDecorator : public OptimizerInterface { public: virtual ~OptimizerDecorator() = default; explicit OptimizerDecorator(std::unique_ptr base_opt) : base_opt_(std::move(base_opt)) {} int64_t SizeBytes() const override { return base_opt_.get()->SizeBytes(); } int64_t UncompressedSizeBytes() const override { return base_opt_.get()->UncompressedSizeBytes(); } std::string DebugString() const override { return base_opt_.get()->DebugString(); } int DimSize() const override { return base_opt_.get()->DimSize(); } int SliceSize() const override { return base_opt_.get()->SliceSize(); } void Init(void* ctx) const override { return base_opt_.get()->Init(ctx); } void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step = 0) const override { return 
base_opt_.get()->Optimize(ctx, num, grad, learning_rates); } OptimizerDump Save(const void* ctx) const override { return base_opt_.get()->Save(ctx); } void Restore(void* ctx, OptimizerDump dump) const override { return base_opt_.get()->Restore(ctx, dump); } virtual void OptimizeWithLatestValue(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, absl::Span latest_value, const int64_t global_step = 0) const = 0; protected: std::unique_ptr base_opt_; }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_DECORATOR ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/optimizer_factory.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_factory.h" #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_table/optimizer/adadelta_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/adagrad_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/adam_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/amsgrad_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/batch_softmax_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/dynamic_wd_adagrad_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/ftrl_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/group_adagrad_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/group_ftrl_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/momentum_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/moving_average_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/rmsprop_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/sgd_optimizer.h" #include "monolith/native_training/runtime/hash_table/optimizer/stochastic_rounding.h" namespace monolith { namespace hash_table { std::unique_ptr NewOptimizerFromConfig( OptimizerConfig config) { std::unique_ptr opt = nullptr; switch (config.type_case()) { case OptimizerConfig::kAdagrad: opt = NewAdagradOptimizer(std::move(*config.mutable_adagrad())); break; case OptimizerConfig::kSgd: opt = NewSgdOptimizer(std::move(*config.mutable_sgd())); break; case OptimizerConfig::kFtrl: opt = NewFtrlOptimizer(std::move(*config.mutable_ftrl())); break; case OptimizerConfig::kDynamicWdAdagrad: opt = NewDynamicWdAdagradOptimizer( std::move(*config.mutable_dynamic_wd_adagrad())); break; case OptimizerConfig::kAdadelta: opt = 
NewAdadeltaOptimizer(std::move(*config.mutable_adadelta())); break; case OptimizerConfig::kAdam: opt = NewAdamOptimizer(std::move(*config.mutable_adam())); break; case OptimizerConfig::kAmsgrad: opt = NewAmsgradOptimizer(std::move(*config.mutable_amsgrad())); break; case OptimizerConfig::kMomentum: opt = NewMomentumOptimizer(std::move(*config.mutable_momentum())); break; case OptimizerConfig::kMovingAverage: opt = NewMovingAverageOptimizer( std::move(*config.mutable_moving_average())); break; case OptimizerConfig::kRmsprop: opt = NewRmspropOptimizer(std::move(*config.mutable_rmsprop())); break; case OptimizerConfig::kRmspropv2: opt = NewRmspropV2Optimizer(std::move(*config.mutable_rmspropv2())); break; case OptimizerConfig::kGroupFtrl: opt = NewGroupFtrlOptimizer(std::move(*config.mutable_group_ftrl())); break; case OptimizerConfig::kGroupAdagrad: opt = NewGroupAdaGradOptimizer(std::move(*config.mutable_group_adagrad())); break; case OptimizerConfig::kBatchSoftmax: opt = NewBatchSoftmaxOptimizer(std::move(*config.mutable_batch_softmax())); break; default: throw std::invalid_argument(absl::StrFormat( "optimizer is not implemented yet. %s", config.ShortDebugString())); } if (config.stochastic_rounding_float16()) { opt = std::make_unique( std::move(opt)); } return std::move(opt); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/optimizer_factory.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_FACTORY #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_FACTORY #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewOptimizerFromConfig( OptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_OPTIMIZER_FACTORY ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// Abstract interface implemented by every optimizer used by the hash table.
//
// All methods are const: an optimizer object carries configuration only, and
// the mutable per-entry optimizer state is kept in the caller-provided |ctx|
// buffer that is passed to Init/Optimize/Save/Restore.
class OptimizerInterface {
 public:
  virtual ~OptimizerInterface() = default;

  // How many bytes are required for the optimizer
  virtual int64_t SizeBytes() const = 0;

  // How many bytes are required for the optimizer if not compressed.
  virtual int64_t UncompressedSizeBytes() const = 0;

  // Human-readable description of the optimizer.
  virtual std::string DebugString() const = 0;

  // The dim that this optimizer can support.
  virtual int DimSize() const = 0;

  // The slice size that this optimizer holds.
  virtual int SliceSize() const = 0;

  // Init optimizer ctx.
  // NOTE(review): |ctx| is presumably a buffer of at least SizeBytes() bytes
  // (the test helper allocates exactly that much) -- confirm against callers.
  virtual void Init(void* ctx) const = 0;

  // optimize the num based on gradients and the optimizer's data.
  // |num|, |grad| are float arrays whose length is at least DimSize().
  // |num| is updated in place.
  virtual void Optimize(void* ctx, absl::Span<float> num,
                        absl::Span<const float> grad,
                        absl::Span<const float> learning_rates,
                        const int64_t global_step = 0) const = 0;

  // Save and restore the entry.
  // Save() serialises the state held in |ctx| into an OptimizerDump;
  // Restore() writes the state described by |dump| back into |ctx|.
  virtual OptimizerDump Save(const void* ctx) const = 0;
  virtual void Restore(void* ctx, OptimizerDump dump) const = 0;
};
namespace {

// RMSProp optimizer. The per-entry state in |ctx| is DimSize() floats: one
// running average |n| of squared gradients per dimension.
class RmspropOptimizer : public OptimizerInterface {
 public:
  explicit RmspropOptimizer(RmspropOptimizerConfig config)
      : conf_(std::move(config)) {}

  // One float of state per dimension.
  int64_t SizeBytes() const override {
    return (conf_.dim_size()) * sizeof(float);
  }

  // State is stored uncompressed, so both sizes are identical.
  int64_t UncompressedSizeBytes() const override { return SizeBytes(); }

  std::string DebugString() const override {
    return absl::StrFormat("Rmsprop(D=%d)", DimSize());
  }

  int DimSize() const override { return conf_.dim_size(); }

  int SliceSize() const override { return 1; }

  // Zeroes the squared-gradient accumulator.
  void Init(void* ctx) const override {
    float* n = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      n[i] = 0;
    }
  }

  // Per dimension:
  //   dx = grad + weight_decay_factor * w
  //   n  = momentum * n + (1 - momentum) * dx^2
  //   w -= learning_rate / (sqrt(n) + 1) * dx
  // NOTE(review): |learning_rates| and |global_step| are not used here; the
  // step size comes from conf_.learning_rate(). RmspropV2Optimizer reads
  // learning_rates[0] instead -- confirm the difference is intended.
  void Optimize(void* ctx, absl::Span<float> num, absl::Span<const float> grad,
                absl::Span<const float> learning_rates,
                const int64_t global_step) const override {
    float* n = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      const float& cur_grad = grad[i];
      float new_n = n[i];
      float new_w = num[i];
      // Gradient plus the weight-decay term.
      double dx =
          cur_grad + static_cast<double>(conf_.weight_decay_factor()) * new_w;
      new_n = static_cast<double>(conf_.momentum()) * new_n +
              (1 - static_cast<double>(conf_.momentum())) * dx * dx;
      // The denominator is offset by a constant 1.
      double eta = static_cast<double>(conf_.learning_rate()) /
                   (std::sqrt(new_n) + 1);
      new_w -= eta * dx;
      n[i] = new_n;
      num[i] = new_w;
    }
  }

  // Writes the accumulator values into a single RmspropOptimizerDump.
  OptimizerDump Save(const void* ctx) const override {
    OptimizerDump dump;
    RmspropOptimizerDump* rmsprop_dump = dump.add_dump()->mutable_rmsprop();
    const float* n = static_cast<const float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      rmsprop_dump->add_n(n[i]);
    }
    return dump;
  }

  // Reads the accumulator back from dump.dump(0).
  // NOTE(review): the dump is indexed up to dim_size() without a size check;
  // it is assumed to have been produced by Save() with the same dim_size.
  void Restore(void* ctx, OptimizerDump dump) const override {
    const RmspropOptimizerDump& rmsprop_dump = dump.dump(0).rmsprop();
    float* n = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      n[i] = rmsprop_dump.n(i);
    }
  }

 private:
  RmspropOptimizerConfig conf_;
};

// Second RMSProp variant. Same state layout as RmspropOptimizer, but the
// accumulator update has no (1 - momentum) factor and the step size is taken
// from learning_rates[0].
class RmspropV2Optimizer : public OptimizerInterface {
 public:
  explicit RmspropV2Optimizer(RmspropV2OptimizerConfig config)
      : conf_(std::move(config)) {}

  // One float of state per dimension.
  int64_t SizeBytes() const override {
    return (conf_.dim_size()) * sizeof(float);
  }

  // State is stored uncompressed, so both sizes are identical.
  int64_t UncompressedSizeBytes() const override { return SizeBytes(); }

  std::string DebugString() const override {
    return absl::StrFormat("RmspropV2(D=%d)", DimSize());
  }

  int DimSize() const override { return conf_.dim_size(); }

  int SliceSize() const override { return 1; }

  // Zeroes the squared-gradient accumulator.
  void Init(void* ctx) const override {
    float* n = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      n[i] = 0;
    }
  }

  // Currently always dispatches to the scalar implementation.
  void Optimize(void* ctx, absl::Span<float> num, absl::Span<const float> grad,
                absl::Span<const float> learning_rates,
                const int64_t global_step) const override {
    // TODO(eric.wei): implement RMSPropV2 using AVX
    OptimizeNormal(ctx, num, grad, learning_rates, global_step);
  }

  // Scalar update, per dimension:
  //   dx = grad + weight_decay_factor * w
  //   n  = momentum * n + dx^2
  //   w -= learning_rates[0] / (sqrt(n) + 1) * dx
  // |global_step| is not used.
  void OptimizeNormal(void* ctx, absl::Span<float> num,
                      absl::Span<const float> grad,
                      absl::Span<const float> learning_rates,
                      const int64_t global_step) const {
    float* n = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      const float& cur_grad = grad[i];
      float new_n = n[i];
      float new_w = num[i];
      // Gradient plus the weight-decay term.
      double dx =
          cur_grad + static_cast<double>(conf_.weight_decay_factor()) * new_w;
      new_n = static_cast<double>(conf_.momentum()) * new_n + dx * dx;
      // The denominator is offset by a constant 1.
      double eta =
          static_cast<double>(learning_rates[0]) / (std::sqrt(new_n) + 1);
      new_w -= eta * dx;
      n[i] = new_n;
      num[i] = new_w;
    }
  }

  // Writes the accumulator values into a single RmspropV2OptimizerDump.
  OptimizerDump Save(const void* ctx) const override {
    OptimizerDump dump;
    RmspropV2OptimizerDump* rmspropv2_dump =
        dump.add_dump()->mutable_rmspropv2();
    const float* n = static_cast<const float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      rmspropv2_dump->add_n(n[i]);
    }
    return dump;
  }

  // Reads the accumulator back from dump.dump(0).
  // NOTE(review): the dump is indexed up to dim_size() without a size check;
  // it is assumed to have been produced by Save() with the same dim_size.
  void Restore(void* ctx, OptimizerDump dump) const override {
    const RmspropV2OptimizerDump& rmspropv2_dump = dump.dump(0).rmspropv2();
    float* n = static_cast<float*>(ctx);
    for (int i = 0; i < conf_.dim_size(); ++i) {
      n[i] = rmspropv2_dump.n(i);
    }
  }

 private:
  RmspropV2OptimizerConfig conf_;
};

}  // namespace

// Factory for the RMSProp optimizer; takes ownership of |config|.
std::unique_ptr<OptimizerInterface> NewRmspropOptimizer(
    RmspropOptimizerConfig config) {
  return std::make_unique<RmspropOptimizer>(std::move(config));
}

// Factory for the RMSProp V2 optimizer; takes ownership of |config|.
std::unique_ptr<OptimizerInterface> NewRmspropV2Optimizer(
    RmspropV2OptimizerConfig config) {
  return std::make_unique<RmspropV2Optimizer>(std::move(config));
}
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_RMSPROP_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_RMSPROP_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewRmspropOptimizer( RmspropOptimizerConfig config); std::unique_ptr NewRmspropV2Optimizer( RmspropV2OptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_RMSPROP_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/rmsprop_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/optimizer/rmsprop_optimizer.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; using ::testing::ElementsAreArray; TEST(RmspropOptimizer, Basic) { RmspropOptimizerConfig config; config.set_dim_size(1); auto opt = NewRmspropOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected = {-0.024025f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).rmsprop().n(0), 10, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f}, {0.01f}); auto expected2 = {-0.042686f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } TEST(RmspropV2Optimizer, ListUpdate) { RmspropV2OptimizerConfig config; config.set_dim_size(2); auto opt = NewRmspropV2Optimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto expected = {-0.0090909f, -0.005f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); // Test dump & restore OptimizerDump dump = opt->Save(mem.ctx()); EXPECT_NEAR(dump.dump(0).rmspropv2().n(0), 100, 1e-4); EXPECT_NEAR(dump.dump(0).rmspropv2().n(1), 1, 1e-4); TestOptimizerEntry mem2(opt.get()); opt->Restore(mem2.mutable_ctx(), dump); *mem2.mutable_num() = mem.num(); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {10.0f, 1.0f}, {0.01f}); auto expected2 = {-0.0158549f, 
-0.0092045f}; ASSERT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected2)); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/sgd_optimizer.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/sgd_optimizer.h" #include #include #include "absl/strings/str_format.h" namespace monolith { namespace hash_table { namespace { class SgdOptimizer : public OptimizerInterface { public: explicit SgdOptimizer(SgdOptimizerConfig config) : conf_(std::move(config)) {} int64_t SizeBytes() const override { return 0; } int64_t UncompressedSizeBytes() const override { return SizeBytes(); } std::string DebugString() const override { return absl::StrFormat("Sgd(D=%d)", DimSize()); } int DimSize() const override { return conf_.dim_size(); } int SliceSize() const override { return 1; } void Init(void* ctx) const override {} void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const { float effective_lr = learning_rates[0]; for (int i = 0; i < conf_.dim_size(); ++i) { num[i] -= effective_lr * grad[i]; } } OptimizerDump Save(const void* ctx) const override { OptimizerDump dump; dump.add_dump()->mutable_sgd(); return dump; } void Restore(void* ctx, 
OptimizerDump dump) const override { // Do nothing. } private: SgdOptimizerConfig conf_; }; } // namespace std::unique_ptr NewSgdOptimizer(SgdOptimizerConfig config) { return std::make_unique(std::move(config)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/sgd_optimizer.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_SGD_OPTIMIZER #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_SGD_OPTIMIZER #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewSgdOptimizer(SgdOptimizerConfig config); } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_SGD_OPTIMIZER ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/sgd_optimizer_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/sgd_optimizer.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { using ::testing::Pointwise; using ::testing::FloatNear; TEST(SgdOptimizer, Basic) { SgdOptimizerConfig config; config.set_dim_size(1); auto opt = NewSgdOptimizer(config); TestOptimizerEntry mem(opt.get()); opt->Init(mem.mutable_ctx()); opt->Optimize(mem.mutable_ctx(), mem.mutable_num_span(), {1.0f}, {0.1f}); auto expected = {-0.1}; EXPECT_THAT(mem.num(), Pointwise(FloatNear(1e-6), expected)); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/stochastic_rounding.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/optimizer/stochastic_rounding.h" namespace monolith { namespace hash_table { thread_local std::vector StochasticRoundingFloat16OptimizerDecorator::rng_ = {0, 1}; } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/stochastic_rounding.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_STOCHASTIC_ROUNDING #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_STOCHASTIC_ROUNDING #include #include "absl/types/span.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" #include "third_party/half_sourceforge_net/half.hpp" namespace monolith { namespace hash_table { inline float stochastic_round(float vf, float p) { unsigned int half_up = half_float::detail::float2half(vf); unsigned int half_down = half_float::detail::float2half(vf); float vf_up = half_float::detail::half2float(half_up); float vf_down = half_float::detail::half2float(half_down); if (p <= (vf - vf_down) / (vf_up - vf_down)) { return vf_up; } else { return vf_down; } } class StochasticRoundingFloat16OptimizerDecorator : public OptimizerInterface { public: explicit StochasticRoundingFloat16OptimizerDecorator( std::unique_ptr optimizer) : optimizer_(std::move(optimizer)) {} // optimize the num based on gradients and the optimizer's data. // |num|, |grad| are float arrays whose length is at least DimSize(). // Result `num` will be stochastically rounded. void Optimize(void* ctx, absl::Span num, absl::Span grad, absl::Span learning_rates, const int64_t global_step = 0) const override { optimizer_->Optimize(ctx, num, grad, learning_rates); for (size_t i = 0; i < num.size(); ++i) { num[i] = stochastic_round(num[i], rand()); } } // Forward all other class methods. int64_t SizeBytes() const override { return optimizer_->SizeBytes(); } int64_t UncompressedSizeBytes() const override { return optimizer_->UncompressedSizeBytes(); } std::string DebugString() const override { return optimizer_->DebugString(); } // The dim that this optimizer can support. int DimSize() const override { return optimizer_->DimSize(); } // The slice size that this optimizer holds. 
int SliceSize() const override { return optimizer_->SliceSize(); } // Init optimizer ctx. // |num| is at least DimSize() long. void Init(void* ctx) const override { optimizer_->Init(ctx); } // Save and restore the entry. OptimizerDump Save(const void* ctx) const override { return optimizer_->Save(ctx); } void Restore(void* ctx, OptimizerDump dump) const override { optimizer_->Restore(ctx, dump); } private: std::unique_ptr optimizer_; static thread_local std::vector rng_; static void update_rng() { rng_[0] = (36969 * (rng_[0] & 65535) + (rng_[0] >> 16)) & 4294967295; rng_[1] = (18000 * (rng_[1] & 65535) + (rng_[1] >> 16)) & 4294967295; } /* * multiply-with-carry generator to generate a float number in [0, 1], * for stochastic rounding from fp32 to fp16 * * About Marsaglia's MWC generator, see * [http://www.cs.yorku.ca/~oz/marsaglia-rng.html] */ static float rand() { update_rng(); return static_cast((((rng_[0] & 65535) << 16) + rng_[1]) & 4294967295) / 4294967296; } }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_STOCHASTIC_ROUNDING ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/stochastic_rounding_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer.pb.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_factory.h" #include "monolith/native_training/runtime/hash_table/optimizer/stochastic_rounding.h" #include "monolith/native_training/runtime/hash_table/optimizer/test_utils.h" namespace monolith { namespace hash_table { namespace { TEST(StochasticRoundingFloat16OptimizerDecorator, Basic) { OptimizerConfig config; config.mutable_sgd()->set_dim_size(1); // Float32 optimizer. // By default, config.stochastic_rounding_float16() == false auto opt = NewOptimizerFromConfig(config); // Float16 optimizer. config.set_stochastic_rounding_float16(true); auto opt_float16 = NewOptimizerFromConfig(config); EXPECT_NE(typeid(*(opt.get())), typeid(*(opt_float16.get()))); } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/optimizer/test_utils.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_TEST_UTILS #define MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_TEST_UTILS #include "absl/types/span.h" #include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h" namespace monolith { namespace hash_table { // The pre allocated memory for the optimizer. class TestOptimizerEntry { public: explicit TestOptimizerEntry(OptimizerInterface* opt) : opt_(opt) { ctx_ = NewCtx(); num_ = NewNum(); } const void* ctx() { return ctx_.get(); } void* mutable_ctx() { return ctx_.get(); } std::vector* mutable_num() { return &num_; } const std::vector& num() { return num_; } absl::Span mutable_num_span() { return absl::MakeSpan(num_); } private: std::unique_ptr NewCtx() { return std::make_unique(opt_->SizeBytes()); } std::vector NewNum() { return std::vector(opt_->DimSize(), 0.0f); } OptimizerInterface* opt_; std::unique_ptr ctx_; std::vector num_; }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_TEST_UTILS ================================================ FILE: monolith/native_training/runtime/hash_table/quantized_entry_accessor.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_QUANTIZED_ENTRY_ACCESSOR_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_QUANTIZED_ENTRY_ACCESSOR_H_ #include #include "monolith/native_training/runtime/hash_table/compressor/fake_quantizer.h" #include "monolith/native_training/runtime/hash_table/compressor/float_compressor.pb.h" #include "monolith/native_training/runtime/hash_table/entry_accessor_decorator.h" namespace monolith { namespace hash_table { struct SegmentQatConfig { explicit SegmentQatConfig(int dim_size, bool enable_qat = false, float r = 1.0f) : dim_size(dim_size), enable_qat(enable_qat), r(r) {} int dim_size; bool enable_qat; float r; }; // Makes the entry accessor support quantized aware training. class QuantizedEntryAccessor : public EntryAccessorDecorator { public: QuantizedEntryAccessor(std::unique_ptr accessor, std::vector segment_qat_configs) : EntryAccessorDecorator(std::move(accessor)), configs_(std::move(segment_qat_configs)) { for (const auto &config : configs_) { if (config.enable_qat) { fake_quantizers_.emplace_back( std::make_unique(config.r)); } else { fake_quantizers_.emplace_back(nullptr); } } } void Fill(const void *ctx, absl::Span num) const override { int dim_size = entry_accessor_->DimSize(); std::vector quantized(dim_size); auto *ctx_float = static_cast(ctx); // Simulate the quantization on weights int index = 0; for (size_t i = 0; i < fake_quantizers_.size(); ++i) { for (int j = 0; j < configs_[i].dim_size; ++j) { quantized[index] = fake_quantizers_[i] == nullptr ? 
ctx_float[index] : fake_quantizers_[i]->Quantize(ctx_float[index]); ++index; } } absl::c_copy(absl::MakeConstSpan(quantized), num.begin()); } void Optimize(void *ctx, absl::Span grad, absl::Span learning_rates, const int64_t global_step) const override { // Apply gradients to real weights entry_accessor_->Optimize(ctx, grad, learning_rates, global_step); } private: std::vector configs_; std::vector> fake_quantizers_; }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_OPTIMIZER_QUANTIZED_ENTRY_ACCESSOR_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/quantized_entry_accessor_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// NOTE(review): template argument lists (e.g. on absl::Span, std::vector,
// std::make_unique, reinterpret_cast) and the header name of the first
// #include appear to have been stripped from this copy of the file; the
// tokens below are kept exactly as found.
#include
#include "absl/strings/str_format.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "monolith/native_training/runtime/hash_table/entry_accessor_decorator.h"
#include "monolith/native_training/runtime/hash_table/quantized_entry_accessor.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::_;
using ::testing::ElementsAre;
using ::testing::ExplainMatchResult;
using ::testing::FloatEq;
using ::testing::Invoke;
using ::testing::Le;
using ::testing::NiceMock;
using ::testing::Return;
using ::testing::WithArgs;

// gMock implementation of EntryAccessorInterface. Default behaviors for the
// methods the test relies on are installed by MakeMockEntryAccessor below.
class MockEntryAccessor : public EntryAccessorInterface {
 public:
  MOCK_CONST_METHOD0(SizeBytes, int64_t());
  MOCK_CONST_METHOD0(UncompressedSizeBytes, int64_t());
  MOCK_CONST_METHOD0(DebugString, std::string());
  MOCK_CONST_METHOD0(DimSize, int());
  MOCK_CONST_METHOD0(SliceSize, int());
  MOCK_CONST_METHOD1(Init, void(void* ctx));
  MOCK_CONST_METHOD2(Fill, void(const void* ctx, absl::Span));
  MOCK_CONST_METHOD2(Assign, void(absl::Span num, void* ctx));
  MOCK_CONST_METHOD2(AssignAdd, void(absl::Span num, void* ctx));
  MOCK_CONST_METHOD2(Save, EntryDump(const void* ctx, uint32_t));
  MOCK_CONST_METHOD3(Restore, void(void* ctx, uint32_t*, EntryDump));
  MOCK_CONST_METHOD4(Optimize, void(void* ctx, absl::Span, absl::Span,
                                    const int64_t));
};

// Knobs for the fake behavior installed on the mock:
//   init_value: value Init() writes into every embedding slot.
//   dim:        number of embedding slots reported by DimSize().
struct MockEntryAccessorFakeOption {
  float init_value = 1.0f;
  int dim = 10;
};

// Installs default actions on |mock| that emulate a tiny accessor whose entry
// is laid out as `dim` embedding floats followed by `dim` optimizer floats:
//   * DimSize()   returns option.dim.
//   * SizeBytes() returns room for 2 * dim floats.
//   * Init()      sets every embedding float to option.init_value and every
//                 optimizer float to 1.0f.
//   * Optimize()  adds grad^2 to the optimizer floats and subtracts grad from
//                 the embedding; learning_rates and global_step are ignored.
void MakeMockEntryAccessor(
    MockEntryAccessor* mock,
    MockEntryAccessorFakeOption option = MockEntryAccessorFakeOption()) {
  ON_CALL(*mock, DimSize()).WillByDefault(Return(option.dim));
  ON_CALL(*mock, SizeBytes())
      .WillByDefault(Return(option.dim * sizeof(float) * 2));
  ON_CALL(*mock, Init(_))
      .WillByDefault(Invoke([option](void* ctx) {
        // Initialize embedding
        auto* w = reinterpret_cast(ctx);
        for (int i = 0; i < option.dim; ++i) {
          w[i] = option.init_value;
        }

        // Initialize optimizer
        auto* norm = w + option.dim;
        for (int i = 0; i < option.dim; ++i) {
          norm[i] = 1.0f;
        }
      }));
  ON_CALL(*mock, Optimize(_, _, _, _))
      .WillByDefault(WithArgs<0, 1, 2, 3>(Invoke([option](
          void* ctx, absl::Span grad, absl::Span learning_rates,
          const int64_t global_step) {
        auto* embedding = reinterpret_cast(ctx);
        auto* norm = embedding + option.dim;
        for (int i = 0; i < option.dim; ++i) {
          norm[i] += grad[i] * grad[i];
          embedding[i] -= grad[i];
        }
      })));
}

// Wraps the mock in a QuantizedEntryAccessor with two QAT-enabled segments
// (first half r = 1.0, second half r = 0.5) and checks that:
//   1. after Init (all weights 1.0), Fill returns values within one step
//      (r / 128) of 1.0 for the first half and of 0.5 for the second half
//      (presumably because the second quantizer limits values to its range r;
//      verify against FakeQuantizer);
//   2. after one Optimize call with grad = 1.0 the weights become 0, and Fill
//      returns exactly 0 for every dim.
TEST(QuantizedEntryAccessor, FixedRange) {
  auto accessor = std::make_unique>();
  MakeMockEntryAccessor(accessor.get());
  int dim_size = accessor->DimSize();
  auto config1 = SegmentQatConfig(dim_size / 2, true, 1.0f);
  auto config2 = SegmentQatConfig(dim_size / 2, true, 0.5f);
  const float kStep1 = config1.r / 128, kStep2 = config2.r / 128;
  QuantizedEntryAccessor quantized_accessor(std::move(accessor),
                                            {config1, config2});
  auto ctx = std::make_unique(quantized_accessor.SizeBytes());
  quantized_accessor.Init(ctx.get());

  std::vector num(dim_size), grad(dim_size, 1.0f);
  quantized_accessor.Fill(ctx.get(), absl::MakeSpan(num));
  for (int i = 0; i < dim_size / 2; ++i) {
    EXPECT_THAT(std::abs(num[i] - 1.0f), Le(kStep1));
  }
  for (int i = dim_size / 2; i < dim_size; ++i) {
    EXPECT_THAT(std::abs(num[i] - 0.5f), Le(kStep2));
  }

  // The mock's Optimize ignores the learning rate, so {.0f} has no effect.
  quantized_accessor.Optimize(ctx.get(), absl::MakeConstSpan(grad), {.0f}, 0);
  quantized_accessor.Fill(ctx.get(), absl::MakeSpan(num));
  for (int i = 0; i < dim_size; ++i) {
    EXPECT_FLOAT_EQ(num[i], 0.0f);
  }
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith
================================================ FILE: monolith/native_training/runtime/hash_table/retriever/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") load("@rules_proto//proto:defs.bzl", "proto_library") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") package(default_visibility = ["//monolith/native_training/runtime/hash_table:__subpackages__"]) cc_library( name = "retriever_interface", hdrs = ["retriever_interface.h"], deps = [ "@com_google_absl//absl/base",
"@com_google_absl//absl/types:span", ], ) cc_library( name = "retriever_base", hdrs = ["retriever_base.h"], deps = [ ":retriever_interface", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", ], ) cc_library( name = "raw_retriever", srcs = ["raw_retriever.cc"], hdrs = ["raw_retriever.h"], deps = [ ":retriever_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "raw_retriever_test", srcs = ["raw_retriever_test.cc"], deps = [ ":raw_retriever", "@com_google_absl//absl/random", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "fake_quant_retriever", srcs = ["fake_quant_retriever.cc"], hdrs = ["fake_quant_retriever.h"], deps = [ ":retriever_base", "//monolith/native_training/runtime/hash_table/compressor:fake_quantizer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "fake_quant_retriever_test", srcs = ["fake_quant_retriever_test.cc"], deps = [ ":fake_quant_retriever", "@com_google_absl//absl/random", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "retriever_combination", srcs = ["retriever_combination.cc"], hdrs = ["retriever_combination.h"], deps = [ ":retriever_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:str_format", ], ) cc_test( name = "retriever_combination_test", srcs = ["retriever_combination_test.cc"], deps = [ ":raw_retriever", ":fake_quant_retriever", ":retriever_combination", "@com_google_absl//absl/random", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "hash_net_retriever", srcs = ["hash_net_retriever.cc"], hdrs = ["hash_net_retriever.h"], deps = [ ":retriever_base", "//monolith/native_training/runtime/hash_table/compressor:hash_net_quantizer", "@com_google_absl//absl/algorithm:container", ], ) cc_test( name = "hash_net_retriever_test", srcs = ["hash_net_retriever_test.cc"], deps = [ ":hash_net_retriever", 
"@com_google_absl//absl/random", "@com_google_googletest//:gtest_main", ], ) ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/fake_quant_retriever.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/retriever/fake_quant_retriever.h" #include #include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_table/retriever/retriever_base.h" namespace monolith { namespace hash_table { namespace { class FakeQuantRetriever final : public RetrieverBase { public: FakeQuantRetriever(int dim_size, const FakeQuantizer& fake_quantizer) : RetrieverBase(dim_size), fake_quantizer_(fake_quantizer) {} void Retrieve(const void* ctx, absl::Span num) const override { absl::c_copy(GetNum(ctx), num.begin()); for (int i = 0; i < dim_size_; ++i) { num[i] = fake_quantizer_.Quantize(num[i]); } } void Backward(absl::Span num, absl::Span grad, int64_t global_step) const override {} std::string DebugString() const override { return absl::StrFormat("FakeQuant(D=%d)", RetrieverBase::DimSize()); } private: FakeQuantizer fake_quantizer_; }; } // namespace std::unique_ptr NewFakeQuantRetriever( int dim_size, const FakeQuantizer& fake_quantizer) { return std::make_unique(dim_size, fake_quantizer); } } // namespace hash_table } // 
namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/fake_quant_retriever.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_FAKE_QUANT_RETRIEVER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_FAKE_QUANT_RETRIEVER_H_ #include #include "monolith/native_training/runtime/hash_table/compressor/fake_quantizer.h" #include "monolith/native_training/runtime/hash_table/retriever/retriever_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewFakeQuantRetriever(int dim_size, const FakeQuantizer& fake_quantizer); } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_FAKE_QUANT_RETRIEVER_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/fake_quant_retriever_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/retriever/fake_quant_retriever.h" #include #include "absl/random/random.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace monolith { namespace hash_table { namespace { using ::testing::ElementsAre; using ::testing::Le; TEST(FakeQuantRetriever, Basic) { int dim_size = 10; float r = 1.0f; const float kStep = r / 128; FakeQuantizer fake_quantizer(1.0f); auto retriever = NewFakeQuantRetriever(dim_size, fake_quantizer); std::vector entry(dim_size); absl::BitGen bit_gen; for (auto& val : entry) { val = absl::Uniform(bit_gen, -1.f, 1.f); } std::vector num(dim_size, 0); retriever->Retrieve(entry.data(), absl::MakeSpan(num)); for (int i = 0; i < dim_size; ++i) { EXPECT_THAT(std::abs(entry[i] - num[i]), Le(kStep)); } } } // namespace } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/hash_net_retriever.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/retriever/hash_net_retriever.h" #include #include #include "absl/algorithm/container.h" #include "monolith/native_training/runtime/hash_table/compressor/hash_net_quantizer.h" #include "monolith/native_training/runtime/hash_table/retriever/retriever_base.h" namespace monolith { namespace hash_table { class HashNetRetriever final : public RetrieverBase { public: HashNetRetriever(int dim_size, std::unique_ptr hash_net_quantizer) : RetrieverBase(dim_size), hash_net_quantizer_(std::move(hash_net_quantizer)) {} void Retrieve(const void* ctx, absl::Span num) const override { absl::c_copy(GetNum(ctx), num.begin()); for (int i = 0; i < dim_size_; ++i) { num[i] = hash_net_quantizer_->Forward(num[i]); } } void Backward(absl::Span num, absl::Span grad, int64_t global_step) const override { for (int i = 0; i < dim_size_; ++i) { hash_net_quantizer_->Backward(num[i], &grad[i], global_step); } } std::string DebugString() const override { return absl::StrFormat("HashNet(D=%d)", RetrieverBase::DimSize()); } private: std::unique_ptr hash_net_quantizer_; }; std::unique_ptr NewHashNetRetriever( int dim_size, std::unique_ptr hash_net_quantizer) { return std::make_unique(dim_size, std::move(hash_net_quantizer)); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/hash_net_retriever.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_HASH_NET_RETRIEVER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_HASH_NET_RETRIEVER_H_ #include #include "monolith/native_training/runtime/hash_table/compressor/hash_net_quantizer.h" #include "monolith/native_training/runtime/hash_table/retriever/retriever_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewHashNetRetriever( int dim_size, std::unique_ptr hash_net_quantizer); } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_HASH_NET_RETRIEVER_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/hash_net_retriever_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/retriever/hash_net_retriever.h"

// NOTE(review): the header name of the next directive and the template
// argument lists below (std::make_unique, std::vector, absl::Uniform) appear
// to have been stripped from this copy; tokens are kept exactly as found.
#include
#include "absl/random/random.h"
#include "gtest/gtest.h"

namespace monolith {
namespace hash_table {
namespace {

// Exercises a hash-net retriever end to end:
//   1. Retrieve() on a fresh quantizer must return amplitude * tanh(x) for
//      every stored value x.
//   2. A direct Backward() call on the quantizer at global_step 100 must
//      rewrite a unit gradient to 0.35840667 * amplitude.
//   3. After that call, Retrieve() must use the quantizer's current scale,
//      i.e. return amplitude * tanh(scale * x).
TEST(HashNetRetriever, Basic) {
  FloatCompressorConfig_OneBit config;
  config.set_dim_size(10);
  config.set_step_size(100);
  float amplitude = config.amplitude();
  auto hash_net_quantizer = std::make_unique(config);
  // Keep a raw pointer before ownership moves into the retriever so the test
  // can call the quantizer directly later on.
  HashNetQuantizer* quantizer = hash_net_quantizer.get();
  auto retriever =
      NewHashNetRetriever(config.dim_size(), std::move(hash_net_quantizer));

  std::vector entry(config.dim_size());
  absl::BitGen bit_gen;
  for (auto& val : entry) {
    val = absl::Uniform(bit_gen, -1.f, 1.f);
  }

  std::vector num(config.dim_size(), 0);
  retriever->Retrieve(entry.data(), absl::MakeSpan(num));
  for (int i = 0; i < config.dim_size(); ++i) {
    EXPECT_FLOAT_EQ(num[i], amplitude * std::tanh(entry[i]));
  }

  float grad = 1.0f;
  int64_t global_step = 100;
  quantizer->Backward(1.0f, &grad, global_step);
  EXPECT_FLOAT_EQ(grad, 0.35840667f * amplitude);

  // Read the scale after Backward(); the expectation below uses it, whereas
  // the first Retrieve() above was checked without a scale factor.
  float scale = quantizer->GetScale();
  retriever->Retrieve(entry.data(), absl::MakeSpan(num));
  for (int i = 0; i < config.dim_size(); ++i) {
    EXPECT_FLOAT_EQ(num[i], amplitude * std::tanh(scale * entry[i]));
  }
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith
================================================ FILE: monolith/native_training/runtime/hash_table/retriever/raw_retriever.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/hash_table/retriever/raw_retriever.h" #include #include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_table/retriever/retriever_base.h" namespace monolith { namespace hash_table { namespace { class RawRetriever final : public RetrieverBase { public: explicit RawRetriever(int dim_size) : RetrieverBase(dim_size) {} void Retrieve(const void* ctx, absl::Span num) const override { absl::c_copy(GetNum(ctx), num.begin()); } void Backward(absl::Span num, absl::Span grad, int64_t global_step) const override {} std::string DebugString() const override { return absl::StrFormat("Raw(D=%d)", RetrieverBase::DimSize()); } }; } // namespace std::unique_ptr NewRawRetriever(int dim_size) { return std::make_unique(dim_size); } } // namespace hash_table } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/raw_retriever.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RAW_RETRIEVER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RAW_RETRIEVER_H_ #include #include "monolith/native_training/runtime/hash_table/retriever/retriever_interface.h" namespace monolith { namespace hash_table { std::unique_ptr NewRawRetriever(int dim_size); } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RAW_RETRIEVER_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/raw_retriever_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/hash_table/retriever/raw_retriever.h"

// NOTE(review): the header name of the next directive and the template
// argument lists below (std::vector, absl::Uniform) appear to have been
// stripped from this copy; tokens are kept exactly as found.
#include
#include "absl/random/random.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace monolith {
namespace hash_table {
namespace {

using ::testing::ElementsAre;

// Fills an entry with random values drawn via absl::Uniform(bit_gen, -1.f,
// 1.f) and checks that the raw retriever hands back exactly the same values.
TEST(RawRetriever, Basic) {
  int dim_size = 10;
  auto retriever = NewRawRetriever(dim_size);

  std::vector entry(dim_size);
  absl::BitGen bit_gen;
  for (auto& val : entry) {
    val = absl::Uniform(bit_gen, -1.f, 1.f);
  }

  std::vector num(dim_size, 0);
  retriever->Retrieve(entry.data(), absl::MakeSpan(num));
  for (int i = 0; i < dim_size; ++i) {
    // Exact equality is intended: the values are copied, not recomputed.
    EXPECT_EQ(entry[i], num[i]);
  }
}

}  // namespace
}  // namespace hash_table
}  // namespace monolith
================================================ FILE: monolith/native_training/runtime/hash_table/retriever/retriever_base.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_BASE_H_
#define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_BASE_H_

#include "monolith/native_training/runtime/hash_table/retriever/retriever_interface.h"

namespace monolith {
namespace hash_table {

// Common base for retrievers: stores the dimension passed to the constructor,
// reports a size of sizeof(float) * dim_size bytes, and gives subclasses a
// read-only span view over the entry memory.
//
// NOTE(review): the template argument lists on absl::Span and static_cast
// appear to have been stripped from this copy; tokens are kept exactly as
// found.
class RetrieverBase : public RetrieverInterface {
 public:
  explicit RetrieverBase(int dim_size)
      : dim_size_(dim_size), size_bytes_(sizeof(float) * dim_size) {}

  // sizeof(float) * dim_size, computed once in the constructor.
  int64_t SizeBytes() const override { return size_bytes_; }

  int DimSize() const override { return dim_size_; }

 protected:
  // Returns a const span of dim_size_ elements starting at |ctx|.
  // The span does not own the memory; |ctx| must stay valid while it is used.
  // (The cast target was lost in extraction - presumably `const float*`,
  // given the sizeof(float) sizing above; TODO confirm.)
  absl::Span GetNum(const void* ctx) const {
    const auto* ctx_float = static_cast(ctx);
    return absl::MakeConstSpan(ctx_float, ctx_float + dim_size_);
  }

  int dim_size_;
  int64_t size_bytes_;
};

}  // namespace hash_table
}  // namespace monolith

#endif  // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_BASE_H_
================================================ FILE: monolith/native_training/runtime/hash_table/retriever/retriever_combination.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include "monolith/native_training/runtime/hash_table/retriever/retriever_combination.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/strings/str_format.h"
#include "monolith/native_training/runtime/hash_table/retriever/retriever_base.h"

namespace monolith {
namespace hash_table {
namespace {

// Retriever that concatenates two retrievers: the first one handles the first
// retriever1->DimSize() values of an entry, the second one handles the rest.
// Both children are owned by the combination.
//
// NOTE(review): template arguments (std::unique_ptr<RetrieverInterface>,
// absl::Span<float>, the byte-pointer cast, ...) were restored from context
// because they are missing in this copy of the file; confirm against the
// original source.
class CombinedRetriever final : public RetrieverBase {
 public:
  // The base class is initialized first, so reading the DimSize() of both
  // arguments here happens before they are moved into the members.
  CombinedRetriever(std::unique_ptr<RetrieverInterface> retriever1,
                    std::unique_ptr<RetrieverInterface> retriever2)
      : RetrieverBase(retriever1->DimSize() + retriever2->DimSize()),
        retriever1_(std::move(retriever1)),
        retriever2_(std::move(retriever2)) {}

  // Lets each child fill its own slice of |num|. The second child's data
  // starts retriever1_->SizeBytes() bytes after |ctx|.
  void Retrieve(const void* ctx, absl::Span<float> num) const override {
    const int dim_size1 = retriever1_->DimSize();
    const void* ctx2 =
        static_cast<const char*>(ctx) + retriever1_->SizeBytes();
    // Hand every child exactly the slice it is responsible for, mirroring
    // what Backward() does.
    retriever1_->Retrieve(ctx, num.subspan(0, dim_size1));
    retriever2_->Retrieve(ctx2, num.subspan(dim_size1));
  }

  // Forwards the matching slices of |num| and |grad| to each child.
  void Backward(absl::Span<const float> num, absl::Span<float> grad,
                int64_t global_step) const override {
    const int dim_size1 = retriever1_->DimSize();
    retriever1_->Backward(num.subspan(0, dim_size1),
                          grad.subspan(0, dim_size1), global_step);
    retriever2_->Backward(num.subspan(dim_size1), grad.subspan(dim_size1),
                          global_step);
  }

  std::string DebugString() const override {
    return absl::StrFormat("%s|%s", retriever1_->DebugString(),
                           retriever2_->DebugString());
  }

 private:
  // No dim_size_ member here: RetrieverBase already stores the combined
  // dimension, and a second, uninitialized member would only shadow it.
  std::unique_ptr<RetrieverInterface> retriever1_;
  std::unique_ptr<RetrieverInterface> retriever2_;
};

}  // namespace

// Factory: builds a retriever that applies |retriever1| to the leading values
// of an entry and |retriever2| to the values that follow. Takes ownership of
// both.
std::unique_ptr<RetrieverInterface> CombineRetrievers(
    std::unique_ptr<RetrieverInterface> retriever1,
    std::unique_ptr<RetrieverInterface> retriever2) {
  return std::make_unique<CombinedRetriever>(std::move(retriever1),
                                             std::move(retriever2));
}

}  // namespace hash_table
}  // namespace monolith
================================================ FILE: monolith/native_training/runtime/hash_table/retriever/retriever_combination.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_COMBINATION_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_COMBINATION_H_ #include #include "monolith/native_training/runtime/hash_table/retriever/retriever_interface.h" namespace monolith { namespace hash_table { std::unique_ptr CombineRetrievers(std::unique_ptr retriever1, std::unique_ptr retriever2); } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_COMBINATION_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/retriever/retriever_combination_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// Combines a raw retriever (first 10 dims) with a fake-quant retriever (last
// 20 dims) and checks that one Retrieve() call fills both parts of the output.
TEST(CombinedRetriever, Basic) {
  int dim_size1 = 10, dim_size2 = 20;
  int dim_size = dim_size1 + dim_size2;
  float r = 1.0f;
  // Tolerance allowed on the fake-quantized part of the output.
  // NOTE(review): presumably one quantization step for range |r| — confirm
  // against FakeQuantizer.
  const float kStep = r / 128;
  FakeQuantizer fake_quantizer(1.0f);
  auto retriever1 = NewRawRetriever(dim_size1);
  auto retriever2 = NewFakeQuantRetriever(dim_size2, fake_quantizer);
  auto retriever =
      CombineRetrievers(std::move(retriever1), std::move(retriever2));

  // Source entry: random values drawn from [-1, 1).
  std::vector entry(dim_size);
  absl::BitGen bit_gen;
  for (auto& val : entry) {
    val = absl::Uniform(bit_gen, -1.f, 1.f);
  }

  std::vector num(dim_size, 0);
  retriever->Retrieve(entry.data(), absl::MakeSpan(num));

  // The raw part must be copied through unchanged.
  for (int i = 0; i < dim_size1; ++i) {
    EXPECT_EQ(entry[i], num[i]);
  }
  // The fake-quant part may differ from the source by at most kStep.
  for (int i = dim_size1; i < dim_size; ++i) {
    EXPECT_THAT(std::abs(entry[i] - num[i]), Le(kStep));
  }
}
// See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_INTERFACE_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_INTERFACE_H_ #include "absl/types/span.h" namespace monolith { namespace hash_table { class RetrieverInterface { public: virtual ~RetrieverInterface() = default; // How many bytes could be accessed by the retriever virtual int64_t SizeBytes() const = 0; // The dim that this retriever can support. virtual int DimSize() const = 0; // Retrieve the num data accessed by the retriever. // |num| is a float array whose length is DimSize(). virtual void Retrieve(const void* ctx, absl::Span num) const = 0; // Back propagation virtual void Backward(absl::Span num, absl::Span grad, int64_t global_step) const = 0; virtual std::string DebugString() const = 0; }; } // namespace hash_table } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_HASH_TABLE_RETRIEVER_RETRIEVER_INTERFACE_H_ ================================================ FILE: monolith/native_training/runtime/hash_table/utils.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// Computes prefix-sum offsets for a fused, shard-major layout of keys and
// embeddings.
//
// Segments are indexed as idx = num_tables * shard_id + table_id. For every
// segment, slot_size_vec[idx] is its number of keys and
// table_dims[table_id] * slot_size_vec[idx] is its number of embedding floats.
//
// Template parameter:
//   compute_keys_per_table - when true, |keys_per_table| is zeroed and then
//     filled with the per-table key totals; when false it is never touched
//     (so it may be null).
//   NOTE(review): the template parameter list was reconstructed from its use
//   in the body; confirm it matches the callers (e.g. whether a default value
//   is expected).
//
// Outputs:
//   key_offsets / emb_offsets - num_tables * num_shards + 1 entries each;
//     entry idx is the start of segment idx, the last entry is the total.
//   emb_splits - num_shards entries; embedding floats belonging to each shard.
//
// Returns {total number of keys, total number of embedding floats}.
template <bool compute_keys_per_table>
std::pair<int, int> ComputeFusedOffsets(
    const int* slot_size_vec,  // num_tables * num_shards
    const int* table_dims,     // num_tables
    int num_tables, int num_shards,
    int* key_offsets,     // num_tables * num_shards + 1
    int* emb_offsets,     // num_tables * num_shards + 1
    int* keys_per_table,  // num_tables
    int* emb_splits       // num_shards
) {
  if (compute_keys_per_table) {
    std::fill(keys_per_table, keys_per_table + num_tables, 0);
  }

  int total_keys = 0;
  int total_embs = 0;
  int prev_total_emb = 0;
  key_offsets[0] = emb_offsets[0] = 0;

  for (int shard_id = 0; shard_id < num_shards; shard_id++) {
    for (int table_id = 0; table_id < num_tables; table_id++) {
      const int idx = num_tables * shard_id + table_id;
      const int slot_sz = slot_size_vec[idx];
      const int segment_dim = table_dims[table_id] * slot_sz;
      if (compute_keys_per_table) keys_per_table[table_id] += slot_sz;
      total_keys += slot_sz;
      total_embs += segment_dim;
      key_offsets[idx + 1] = key_offsets[idx] + slot_sz;
      emb_offsets[idx + 1] = emb_offsets[idx] + segment_dim;
    }
    // Everything accumulated since the previous shard boundary belongs to
    // this shard.
    emb_splits[shard_id] = total_embs - prev_total_emb;
    prev_total_emb = total_embs;
  }
  return std::make_pair(total_keys, total_embs);
}
["//monolith/native_training/runtime:__subpackages__"]) cc_library( name = "hopscotch_hash_set", srcs = ["hopscotch_hash_set.cc"], hdrs = ["hopscotch_hash_set.h"], deps = [ "//monolith/native_training/runtime/concurrency:micro_one_bit_spin_lock", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", "@com_google_absl//absl/synchronization", ], ) # disable this test since it is not runnable on TCE image. # cc_test( # name = "hopscotch_hash_set_test", # srcs = ["hopscotch_hash_set_test.cc"], # deps = [ # ":hopscotch_hash_set", # "@gperftools//:libtcmalloc", # "@com_google_googletest//:gtest_main", # ], # ) ================================================ FILE: monolith/native_training/runtime/hopscotch/hopscotch_hash_set.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// Returns the smallest power of two that is >= n, with two guarded edges:
//   * n == 0 yields 1. The unguarded bit trick wraps around and returns 0,
//     and the callers derive masks as NextPowerOfTwo(x) - 1, which would turn
//     into an all-ones mask.
//   * n > 2^31 saturates at 2^31, the largest power of two representable in
//     uint32_t, instead of overflowing to 0.
inline static uint32_t NextPowerOfTwo(uint32_t n) {
  constexpr uint32_t kMaxPowerOfTwo = 1u << 31;
  if (n == 0) return 1;
  if (n > kMaxPowerOfTwo) return kMaxPowerOfTwo;
  // Smear the highest set bit of (n - 1) into every lower position, then add
  // one to reach the next power of two.
  --n;
  n |= n >> 1;
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
  return n + 1;
}
kHopscotchHashHopRange : __builtin_ctz(start_hop_info); if (move_new_free_dist < move_free_dist) { if (new_lock != lock) { new_lock->Lock(); } if (start_hop_info == table_[move_bucket].hop_info) { // new_free_bucket -> free_bucket and empty new_free_bucket int new_free_bucket = move_bucket + move_new_free_dist; table_[*free_bucket].key = table_[new_free_bucket].key; table_[*free_bucket].hash = table_[new_free_bucket].hash; table_[move_bucket].hop_info |= 1u << move_free_dist; table_[move_bucket].hop_info &= ~(1u << move_new_free_dist); *free_bucket = new_free_bucket; *free_dist -= move_free_dist - move_new_free_dist; if (new_lock != lock) { new_lock->Unlock(); } return; } if (new_lock != lock) { new_lock->Unlock(); } } ++move_bucket; } *free_bucket = -1; *free_dist = 0; } template size_t HopscotchHashSet::insert(Key key) { // we do lazy init here to save memory if (!init_) { init_lock_.Lock(); if (!init_) DoInit(); init_ = true; init_lock_.Unlock(); } size_t dropped_keys = 0; if (unlikely(size() > capacity_)) { clear_lock_.Lock(); if (likely(size() > capacity_)) { for (int i = 0; i < locks_.size(); ++i) locks_[i].Lock(); dropped_keys = size(); this->DoClear(); for (int i = 0; i < locks_.size(); ++i) locks_[i].Unlock(); } clear_lock_.Unlock(); } uint32_t hash = HashFunc(key); auto lock = &locks_[hash & lock_mask_]; lock->Lock(); int bucket = hash & bucket_mask_; uint32_t hop_info = table_[bucket].hop_info; // check if already exists while (0 != hop_info) { int i = FirstLsbBitIndex(hop_info); int current = bucket + i; if (key == table_[current].key) { lock->Unlock(); return dropped_keys; } hop_info &= ~(1U << i); } // looking for free bucket int free_bucket = bucket, free_dist = 0; for (; free_dist < kHopscotchHashInsertRange; ++free_dist, ++free_bucket) { if (kHopscotchHashEmpty == table_[free_bucket].hash && kHopscotchHashEmpty == __sync_val_compare_and_swap(&table_[free_bucket].hash, kHopscotchHashEmpty, hash)) { break; } } // insert the new key 
num_elements_.fetch_add(1, std::memory_order_relaxed); if (free_dist < kHopscotchHashInsertRange) { do { if (free_dist < kHopscotchHashHopRange) { table_[free_bucket].key = key; table_[free_bucket].hash = hash; table_[bucket].hop_info |= 1u << free_dist; lock->Unlock(); return dropped_keys; } FindCloserFreeBucket(lock, &free_bucket, &free_dist); } while (-1 != free_bucket); } else { // insert failed, insert into extra_ map extra_lock_.Lock(); extra_.insert(key); extra_lock_.Unlock(); } lock->Unlock(); return dropped_keys; } template std::vector HopscotchHashSet::GetAndClear() { if (!init_) return {}; clear_lock_.Lock(); for (int i = 0; i < locks_.size(); ++i) locks_[i].Lock(); std::vector results(size()); size_t index = 0; for (auto&& entry : table_) { if (entry.hash) { results[index++] = entry.key; } entry.hash = 0; entry.key = kEmptyKey; entry.hop_info = 0; } for (auto&& key : extra_) { results[index++] = key; } extra_.clear(); num_elements_.store(0, std::memory_order_seq_cst); for (int i = 0; i < locks_.size(); ++i) locks_[i].Unlock(); clear_lock_.Unlock(); return results; } template void HopscotchHashSet::DoClear() { for (size_t i = 0; i < table_.size(); ++i) { table_[i].hash = 0; table_[i].key = kEmptyKey; table_[i].hop_info = 0; } num_elements_.store(0, std::memory_order_seq_cst); extra_.clear(); } template class HopscotchHashSet; template class HopscotchHashSet>; template <> FID GetEmptyValue() { return -1; } template <> std::pair GetEmptyValue>() { return std::make_pair(-1, nullptr); } } // namespace hopscotch } // namespace monolith ================================================ FILE: monolith/native_training/runtime/hopscotch/hopscotch_hash_set.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// thread safe hopscotch hash set (insert only)
// paper:
// http://people.csail.mit.edu/shanir/publications/disc2008_submission_98.pdf
//
// The bucket table is allocated lazily by the first insert(). Keys that do
// not fit into the table are kept in a separate overflow set.
template class HopscotchHashSet {
 public:
  // |capacity|: once size() exceeds it, the next insert() clears the set.
  // |concurrency_level|: rounded up to a power of two to size the array of
  // stripe locks.
  explicit HopscotchHashSet(uint32_t capacity, uint32_t concurrency_level);

  // thread safe insert, return number keys cleared
  // (0 unless this call triggered a capacity clear).
  size_t insert(Key key);

  // Returns the stored keys and empties the set. Returns an empty vector if
  // the set was never initialised (no insert() happened yet).
  std::vector GetAndClear();

  // Current element count (relaxed atomic read).
  size_t size() const { return num_elements_.load(std::memory_order_relaxed); }

  uint32_t capacity() const { return capacity_; }

 private:
  // Hash of |key| with the two lowest bits forced to 1, so the result can
  // never equal kHopscotchHashEmpty (0) or kHopscotchHashBusy (1).
  uint32_t HashFunc(Key key) { return hash_func_(key) | 3; }

  // Tries to move the free slot |*free_bucket| closer to the home bucket.
  // On failure sets |*free_bucket| to -1 and |*free_dist| to 0.
  void FindCloserFreeBucket(const concurrency::MicroOneBitSpinLock* lock,
                            int* free_bucket, int* free_dist);

  // Allocates the table and the locks; called once from insert().
  void DoInit();

  // clear the hash table, not thread safe
  void DoClear();

 private:
  // How many slots after the home bucket are probed for a free slot.
  static constexpr uint32_t kHopscotchHashInsertRange = 4096;
  // Maximum distance from the home bucket; matches the 32 bits of hop_info.
  static constexpr uint32_t kHopscotchHashHopRange = 32;
  // Hash value marking an unused slot.
  static constexpr uint32_t kHopscotchHashEmpty = 0;
  static constexpr uint32_t kHopscotchHashBusy = 1;

  // Key value written into slots when they are cleared.
  Key kEmptyKey = GetEmptyValue();
  absl::Hash hash_func_;

  // for those keys not insert into table
  absl::flat_hash_set extra_;
  concurrency::MicroOneBitSpinLock extra_lock_;  // guards extra_
  concurrency::MicroOneBitSpinLock clear_lock_;  // serialises clears
  concurrency::MicroOneBitSpinLock init_lock_;   // guards lazy init

  std::vector> table_;
  // Stripe locks, selected by hash & lock_mask_.
  std::vector locks_;
  uint32_t lock_mask_;
  uint32_t bucket_mask_;

  std::atomic_int running_threads_;  // number of thread doing insertion
  std::atomic_int num_elements_;     // total number of elements
  uint32_t capacity_;
  bool init_;  // true once DoInit() has run
};
#include "monolith/native_training/runtime/hopscotch/hopscotch_hash_set.h" #include #include #include #include #include #include #include "gtest/gtest.h" #include "google/malloc_extension.h" namespace monolith { namespace hopscotch { namespace { using FID = int64_t; constexpr int kMaxNumKeys = 2097152; constexpr int kConcurrencyLevel = 200; constexpr int kSeed = 2233333; uint64_t GetTime() { struct timeval tv; gettimeofday(&tv, NULL); return tv.tv_sec * 1000000ULL + tv.tv_usec; } static size_t MemoryUsage() { size_t result = 0; if (MallocExtension::instance()->GetNumericProperty( "generic.current_allocated_bytes", &result)) { return result; } return 0; } static size_t memory_last = 0; static uint64_t time_last = 0; void Reset() { memory_last = MemoryUsage(); time_last = GetTime(); } void Report() { printf("time:%6.1f ms, memory:%6.1f M\n", (GetTime() - time_last) / 1000.0, (MemoryUsage() - memory_last) / (1024.0 * 1024)); Reset(); } TEST(HOPSCOTCH_HASH_SET, simple_test) { HopscotchHashSet hash_set(1000, 1); std::vector keys; for (int i = 0; i < 1000; ++i) { keys.emplace_back(std::rand()); hash_set.insert(keys.back()); } auto all = hash_set.GetAndClear(); ASSERT_EQ(all.size(), 1000); std::sort(all.begin(), all.end()); std::sort(keys.begin(), keys.end()); for (int i = 0; i < 1000; ++i) { ASSERT_EQ(keys[i], all[i]); } } template void TestOneMap(MapType* map) { srand(kSeed); for (int i = 0; i < kMaxNumKeys; ++i) { map->insert(std::rand()); } } // test google:dense_hash_set // 2096080 // time: 94.3 ms, memory: 32.0 M // // test std::set // 2096080 // time:1050.2 ms, memory: 96.0 M // // test std::unordered_set // 2096080 // time: 333.8 ms, memory: 48.4 M // // test hopscotch_hash_set // 2096080 // time: 188.3 ms, memory: 64.1 M TEST(HOPSCOTCH_HASH_SET, compare_test) { Reset(); std::cout << "test google:dense_hash_set" << std::endl; google::dense_hash_set dense_hash_set; dense_hash_set.set_empty_key(-1); TestOneMap(&dense_hash_set); std::cout << dense_hash_set.size() << 
std::endl; Report(); std::cout << "test std::set" << std::endl; std::set std_set; TestOneMap(&std_set); std::cout << std_set.size() << std::endl; Report(); std::cout << "test std::unordered_set" << std::endl; std::unordered_set std_unordered_set; TestOneMap(&std_unordered_set); std::cout << std_unordered_set.size() << std::endl; Report(); std::cout << "test hopscotch_hash_set" << std::endl; HopscotchHashSet hash_set(kMaxNumKeys, 1); TestOneMap(&hash_set); std::cout << hash_set.size() << std::endl; Report(); } TEST(HOPSCOTCH_HASH_SET, multithread_test) { HopscotchHashSet hash_set(kMaxNumKeys, 1000); srand(kSeed); for (int num_thread = 1; num_thread <= 10; ++num_thread) { std::cout << "test for " << num_thread << " threads" << std::endl; std::vector keys; for (int i = 0; i < kMaxNumKeys; ++i) { keys.emplace_back(std::rand()); } std::vector writers(num_thread); for (int i = 0; i < num_thread; ++i) { writers[i] = std::thread( [&](int index) { for (int j = index; j < kMaxNumKeys; j += num_thread) { EXPECT_EQ(0, hash_set.insert(keys[j])); } }, i); } for (int i = 0; i < num_thread; ++i) { writers[i].join(); } auto all = hash_set.GetAndClear(); std::sort(all.begin(), all.end()); std::sort(keys.begin(), keys.end()); keys.erase(std::unique(keys.begin(), keys.end()), keys.end()); std::cout << "insert finished. 
// Inserts 20 * kMaxNumKeys + 10000 distinct keys from 10 threads into a set
// whose capacity is kMaxNumKeys, and expects exactly 20 capacity-triggered
// clears to be reported by insert().
TEST(HOPSCOTCH_HASH_SET, overflow_test) {
  HopscotchHashSet hash_set(kMaxNumKeys, 1000);
  const int num_thread = 10;
  const int num_keys = kMaxNumKeys * 20 + 10000;
  // Keys are 0..num_keys-1, so every insert adds a new element.
  std::vector keys(num_keys);
  for (int i = 0; i < num_keys; ++i) {
    keys[i] = i;
  }
  // One list per thread, so the writers never share a container.
  std::vector> dropped_keys(num_thread);
  std::vector writers(num_thread);
  for (int i = 0; i < num_thread; ++i) {
    writers[i] = std::thread(
        [&](int index) {
          // Thread |index| inserts keys index, index + num_thread, ...
          for (int j = index; j < num_keys; j += num_thread) {
            int result = hash_set.insert(keys[j]);
            // A non-zero result means this insert cleared the set.
            if (result != 0) {
              dropped_keys[index].emplace_back(result);
            }
          }
        },
        i);
  }
  for (int i = 0; i < num_thread; ++i) {
    writers[i].join();
  }
  // Total number of clears observed across all threads.
  int clear_times = 0;
  for (int i = 0; i < num_thread; ++i) {
    clear_times += dropped_keys[i].size();
  }
  EXPECT_EQ(clear_times, 20);
}
tf_gpu_kernel_library_allow_except( name = "embedding_hash_table_tf_bridge", srcs = ["embedding_hash_table_tf_bridge.cc"], hdrs = ["embedding_hash_table_tf_bridge.h"], deps = [ ":hash_filter_tf_bridge", "//monolith/native_training/runtime/common:metrics", "//monolith/native_training/runtime/hash_filter:filter", "//monolith/native_training/runtime/hash_filter:probabilistic_filter", "//monolith/native_training/runtime/hash_filter:sliding_hash_filter", "//monolith/native_training/runtime/hash_table:embedding_hash_table_factory", "//monolith/native_training/runtime/hopscotch:hopscotch_hash_set", "@com_google_absl//absl/memory", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs", ], ) cc_library( name = "hash_filter_tf_bridge", srcs = ["hash_filter_tf_bridge.cc"], hdrs = ["hash_filter_tf_bridge.h"], deps = [ ":file_utils", "//monolith/native_training/data/training_instance:reader_util", "//monolith/native_training/runtime/hash_filter:filter", "//monolith/native_training/runtime/hash_table:embedding_hash_table_cc_proto", "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs", ], ) cc_library( name = "touched_key_set_tf_bridge", srcs = [], hdrs = ["touched_key_set_tf_bridge.h"], deps = [ "//monolith/native_training/runtime/hopscotch:hopscotch_hash_set", "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs", ], ) cc_library( name = "parameter_sync_tf_bridge", srcs = ["parameter_sync_tf_bridge.cc"], hdrs = ["parameter_sync_tf_bridge.h"], deps = [ ":embedding_hash_table_tf_bridge", ":multi_hash_table", "//monolith/native_training/runtime/parameter_sync:dummy_sync_client", "//monolith/native_training/runtime/parameter_sync:dummy_sync_server", 
"//monolith/native_training/runtime/parameter_sync:sync_client_manager", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs", ], ) cc_library( name = "file_utils", srcs = ["file_utils.cc"], hdrs = ["file_utils.h"], deps = [ "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_protobuf//:protobuf_lite", "@com_googlesource_code_re2//:re2", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs", ], ) tf_cc_test( name = "file_utils_test", srcs = ["file_utils_test.cc"], deps = [ ":file_utils", "//monolith/native_training/data/training_instance:reader_util", "@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/core:test", ], ) tf_kernel_library( name = "clip_ops", srcs = [ "clip_by_global_norm.h", "clip_by_global_norm_op.cc", ], copts = [ "-D_ENABLE_AVX", ], gpu_srcs = [ "clip_by_global_norm.h", "clip_by_global_norm.cu.cc", "global_norm.cu.cc", "clip_by_global_norm_fused.cu.cc", "alloc_utils.h", ], linkopts = [], deps = [ "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:gpu_device_array_for_custom_op", "@org_tensorflow//tensorflow/core/kernels:gpu_prim_hdrs", ], ) cc_library( name = "multi_hash_table", hdrs = ["multi_hash_table.h"], deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@org_tensorflow//tensorflow/core:framework_headers_lib", ], ) tf_gpu_kernel_library_allow_except( name = "hash_table_ops", srcs = [ "gpu_multi_hash_table.h", "hash_table/misc_ops.cc", "hash_table_lookup_op.cc", "hash_table_op.cc", "hash_table_restore_op.cc", "hash_table_save_op.cc", "hash_table_update_op.cc", "multi_hash_table.h", "multi_hash_table_lookup_op.cc", "multi_hash_table_op.cc", "multi_hash_table_save_restore_ops.cc", "multi_hash_table_update_op.cc", ], deps = [ ":embedding_hash_table_tf_bridge", ":file_utils", 
":hash_filter_tf_bridge", ":multi_hash_table", ":parameter_sync_tf_bridge", "//monolith/native_training/data/training_instance:reader_util", "//monolith/native_training/runtime/concurrency:queue", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", ], ) cc_library( name = "monolith_internal_ops", alwayslink = 1, ) cc_library( name = "monolith_ops_additional_deps", deps = select({ "@org_tensorflow//tensorflow:framework_shared_object": ["@org_tensorflow//tensorflow/core/platform/hadoop:hadoop_file_system"], "//conditions:default": [], }), ) cc_library( name = "monolith_ops", deps = [ ":clip_ops", ":deep_insight_ops", ":distribution_ops", ":file_ops", ":gen_seq_mask_op", ":hash_filter_ops", ":hash_table_ops", ":inbatch_auc_loss_ops", ":logging_ops", ":monolith_internal_ops", ":monolith_ops_additional_deps", ":parameter_sync_ops", ":remote_predict_op", ":touched_key_set_ops", "//monolith/native_training/data:pb_data_ops", "//monolith/native_training/data/training_instance:pb_datasource_ops", "//monolith/native_training/layers:layer_tf_ops", "//monolith/native_training/optimizers:training_ops", ], alwayslink = 1, ) # if framework_shared_object is true, # we shouldn't link the ops into tensorflow because # we don't separate the ops/kernels implementation. # Instead, we use dynamic load to solve this problem. 
selects.config_setting_group( name = "monolith_ops_for_tf_condition", match_any = ["@org_tensorflow//tensorflow:framework_shared_object", ":serving_gpu"], ) cc_library( name = "monolith_ops_for_tf", deps = select({ ":monolith_ops_for_tf_condition": [], "//conditions:default": [ ":monolith_ops", ], }), alwayslink = 1, ) tf_kernel_library( name = "monolith_ops_for_load", deps = select({ "@org_tensorflow//tensorflow:framework_shared_object": [":monolith_ops"], "//conditions:default": [], }), ) tf_gen_op_wrapper_py( name = "gen_monolith_ops_base", out = "gen_monolith_ops_base.py", deps = [":monolith_ops"], ) py_library( name = "gen_monolith_ops", srcs = ["gen_monolith_ops.py"], data = [":libtfkernel_monolith_ops_for_load.so"], deps = [ ":gen_monolith_ops_base", "//monolith:utils", "@org_tensorflow//tensorflow:tensorflow_py", ], ) tf_kernel_library( name = "distribution_ops", srcs = [ "alloc_utils.h", "fused_embedding_to_layout.cc", "fused_embedding_to_layout.h", "fused_reorder_by_indices.cc", "map_id_to_embedding_op.cc", "reduce_op.cc", "split_by_indices_op.cc", "static_reshape_op.cc", "unique_mapping_ops.cc", "normalize_merged_split_op.cc", ], copts = [ "-D_ENABLE_AVX", ], gpu_srcs = [ "map_id_to_embedding.cu.cc", "reduce_op.cu.cc", "alloc_utils.h", "fused_embedding_to_layout.h", "fused_embedding_to_layout.cu.cc", "aligned_concat_split.cu.cc", ], # TODO: Figure out how to link "@org_tensorflow//tensorflow/core/kernels:cwise_lib_hdrs" for fill_functor.h deps = [ "//idl:example_cc_proto", "//monolith/native_training/data/training_instance:data_reader", "//monolith/native_training/data/training_instance:parse_instance_lib", "//monolith/native_training/runtime/hash_table:embedding_hash_table_factory", "//monolith/native_training/runtime/ops:traceme", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@org_tensorflow//tensorflow/core:framework_headers_lib", 
"@org_tensorflow//tensorflow/core/kernels:gpu_device_array_for_custom_op", ], ) tf_gpu_kernel_library_allow_except( name = "hash_filter_ops", srcs = [ "hash_filter_intercept_gradient_op.cc", "hash_filter_op.cc", "hash_filter_restore_op.cc", "hash_filter_save_op.cc", ], deps = [ ":file_utils", ":hash_filter_tf_bridge", "//monolith/native_training/runtime/hash_filter", "//monolith/native_training/runtime/hash_filter:dummy_hash_filter", "//monolith/native_training/runtime/hash_filter:probabilistic_filter", "//monolith/native_training/runtime/hash_filter:sliding_hash_filter", "@org_tensorflow//tensorflow/core:framework_headers_lib", ], ) cc_library( name = "file_ops", srcs = [ "file_ops.cc", ], deps = [ "@org_tensorflow//tensorflow/core:framework_headers_lib", "//monolith/native_training/runtime/hash_table:embedding_hash_table_cc_proto", ], alwayslink = 1, ) cc_library( name = "touched_key_set_ops", srcs = [ "touched_key_set_insert_op.cc", "touched_key_set_op.cc", "touched_key_set_steal_op.cc", ], deps = [ ":touched_key_set_tf_bridge", "@org_tensorflow//tensorflow/core:framework_headers_lib", ], alwayslink = 1, ) cc_library( name = "gen_seq_mask_op", srcs = [ "gen_seq_mask.cc", ], deps = [ "@org_tensorflow//tensorflow/core:framework_headers_lib", ], alwayslink = 1, ) cc_library( name = "inbatch_auc_loss_ops", srcs = [ "inbatch_auc_loss.cc", ], deps = [ "@org_tensorflow//tensorflow/core:framework_headers_lib", ], alwayslink = 1, ) cc_library( name = "remote_predict_op_lib", hdrs = ["remote_predict_op.h"], deps = [ ":agent_heartbeat", ":tracelib", "//monolith/native_training/runtime/common:metrics", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_glog//:glog", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto", ], ) cc_library( name = "prediction_service_grpc", srcs = [ "prediction_service_grpc.cc", ], hdrs = [ "prediction_service_grpc.h", ], deps = [ 
"@com_google_absl//absl/status", "@com_google_absl//absl/time", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto", ], ) cc_library( name = "remote_predict_op_grpc", srcs = ["remote_predict_op_grpc.cc"], deps = [ ":prediction_service_grpc", ":remote_predict_op_lib", ], alwayslink = 1, ) alias( name = "remote_predict_op", actual = ":remote_predict_op_grpc", ) tf_gpu_kernel_library_allow_except( name = "parameter_sync_ops", srcs = ["parameter_sync_ops.cc"], deps = [ ":parameter_sync_tf_bridge", "@com_github_grpc_grpc//:grpc++_reflection", "@org_tensorflow//tensorflow/core:framework_headers_lib", ], ) proto_library( name = "logging_ops_proto", srcs = ["logging_ops.proto"], ) cc_proto_library( name = "logging_ops_cc_proto", visibility = ["//visibility:public"], deps = [":logging_ops_proto"], ) py_proto_library( name = "logging_ops_py_proto", srcs = ["logging_ops.proto"], visibility = ["//visibility:public"], ) cc_library( name = "logging_ops", srcs = [ "logging_ops.cc", ], deps = [ ":logging_ops_cc_proto", "//monolith/native_training/runtime/common:metrics", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/core:framework_headers_lib", ], alwayslink = 1, ) cc_library( name = "deep_insight_client_tf_bridge", hdrs = ["deep_insight_client_tf_bridge.h"], deps = [ ":file_metric_writer", "//monolith/native_training/runtime/deep_insight", "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/core:framework_headers_lib", "@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs", ], ) cc_library( name = "deep_insight_ops", srcs = [ "deep_insight_ops.cc", ], deps = [ ":deep_insight_client_tf_bridge", ], alwayslink = 1, ) cc_library( name = "agent_heartbeat", srcs = [ "agent_heartbeat.cc", ], hdrs = [ "agent_heartbeat.h", ], deps = [ ":net_utils", "//monolith/agent_service:agent_service_cc_proto_grpc", "@com_github_grpc_grpc//:grpc++", 
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_glog//:glog", "@org_tensorflow//tensorflow/core/platform:logging", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto", ], ) tf_cc_test( name = "agent_heartbeat_test", srcs = ["agent_heartbeat_test.cc"], extra_copts = [ "-DTEST_USE_GRPC", ], deps = [ ":agent_heartbeat", ":prediction_service_grpc", "@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/core:test", ], ) cc_library( name = "net_utils", srcs = ["net_utils.cc"], hdrs = ["net_utils.h"], ) cc_test( name = "net_utils_test", srcs = ["net_utils_test.cc"], deps = [ ":net_utils", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "serving_deps_with_framework_shared_object", srcs = ["@org_tensorflow//tensorflow:libtensorflow_framework.so.2"], deps = [ "@org_tensorflow//tensorflow/core:distributed_tensorflow_dependencies", "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_runtime", ], ) cc_library( name = "file_metric_writer", srcs = ["file_metric_writer.cc"], hdrs = ["file_metric_writer.h"], deps = [ "//monolith/native_training/runtime/concurrency:queue", "//monolith/native_training/runtime/concurrency:thread_pool", "@com_google_absl//absl/strings:str_format", "@com_google_glog//:glog", "@org_tensorflow//tensorflow/core:framework_headers_lib", ], ) tf_cc_test( name = "file_metric_writer_test", srcs = ["file_metric_writer_test.cc"], deps = [ ":file_metric_writer", "@com_google_glog//:glog", "@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:lib", "@org_tensorflow//tensorflow/core:tensorflow", ], ) # Expose monolith ops for tf serving # we may need to change it to tf_gpu_kernel_library_allow_except later cc_library( name = 
"serving_ops_cc", srcs = [ ], visibility = [ "//visibility:public", ], deps = [ ":monolith_ops", ] + select({ "@org_tensorflow//tensorflow:framework_shared_object": [":serving_deps_with_framework_shared_object"], "//conditions:default": [], }), alwayslink = 1, ) config_setting( name = "serving_gpu", define_values = {"using_cuda": "true"}, ) ================================================ FILE: monolith/native_training/runtime/ops/agent_heartbeat.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/ops/agent_heartbeat.h"

namespace tensorflow {
namespace monolith_tf {

// Name of the environment variable that holds the local agent's gRPC port.
const char *const kAgentPortEnvVar = "PORT2";

// Builds a gRPC stub talking to the agent service on
// localhost:<$kAgentPortEnvVar> over an insecure channel.
// Aborts the process (LOG(FATAL)) when the environment variable is not set.
std::unique_ptr<monolith::serving::agent_service::AgentService::Stub>
NewAgentStub() {
  const char *agent_port = getenv(kAgentPortEnvVar);
  if (agent_port == nullptr) {
    LOG(FATAL) << "missing env " << kAgentPortEnvVar;
    return nullptr;
  }
  auto channel = grpc::CreateChannel("localhost:" + std::string(agent_port),
                                     grpc::InsecureChannelCredentials());
  return monolith::serving::agent_service::AgentService::NewStub(channel);
}

// If any address in `addrs` starts with `host`, replaces the whole list with
// the first such address; otherwise leaves `addrs` untouched.
void RemoveOtherAddrsIfThereIsLocalAddr(
    const std::string &host,
    google::protobuf::RepeatedPtrField<std::string> *addrs) {
  std::string local_shard;
  for (const std::string &addr : *addrs) {
    // Prefix test; compare() only looks at the first host.size() characters
    // instead of searching the whole address like find() does on a miss.
    if (addr.compare(0, host.size(), host) == 0) {
      local_shard = addr;
      break;
    }
  }
  if (!local_shard.empty()) {
    addrs->Clear();
    addrs->Add(std::move(local_shard));
  }
}

// Returns 1 when at least one model key contains ":ps" (keys of the form
// `RealModel:ps:1`), and 0 otherwise (legacy keys of the form `ps:1`).
int GetApiVersion(
    const absl::flat_hash_map<std::string, std::vector<std::string>>
        &model_addrs) {
  // Iterate by reference: copying each entry would duplicate the key string
  // and the whole address vector just to inspect the key.
  for (const auto &entry : model_addrs) {
    const std::string &model = entry.first;
    if (model.find(":ps") != std::string::npos) {
      return 1;
    }
  }
  return 0;
}

// Returns "<model_name>:<server_type>:<index>".
std::string GetModelKey(absl::string_view model_name,
                        absl::string_view server_type, int index) {
  return absl::StrCat(model_name, ":", server_type, ":", index);
}

// Returns "<model_name>:ps:<index>".
std::string GetModelPsKey(absl::string_view model_name, int index) {
  return GetModelKey(model_name, "ps", index);
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/agent_heartbeat.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_AGENT_HEARTBEAT_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_AGENT_HEARTBEAT_H_ #include #include #include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "glog/logging.h" #include "grpcpp/channel.h" #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" #include "monolith/agent_service/agent_service.grpc.pb.h" #include "monolith/agent_service/agent_service.pb.h" #include "monolith/native_training/runtime/ops/net_utils.h" #include "tensorflow/core/platform/default/logging.h" namespace tensorflow { namespace monolith_tf { extern const char *const kAgentPortEnvVar; std::unique_ptr NewAgentStub(); void RemoveOtherAddrsIfThereIsLocalAddr( const std::string &host, google::protobuf::RepeatedPtrField *addrs); int GetApiVersion(const absl::flat_hash_map< std::string, std::vector> &model_addrs); std::string GetModelKey(absl::string_view model_name, absl::string_view server_type, int index); std::string GetModelPsKey(absl::string_view model_name, int index); // Provide getting PredictionServiceType by task, // while update cache data periodically by calling agent service. 
template class AgentHeartbeat { public: using AgentService = monolith::serving::agent_service::AgentService; AgentHeartbeat() : AgentHeartbeat(NewAgentStub(), absl::Seconds(15)) {} ~AgentHeartbeat() { stopped_.Notify(); heartbeat_thread_->join(); } explicit AgentHeartbeat( std::unique_ptr agent_stub, absl::Duration heartbeat_interval) : agent_stub_(std::move(agent_stub)), heartbeat_interval_(heartbeat_interval) { // Manual update once UpdateAddrs(); { absl::ReaderMutexLock l(&mu_); api_version_ = GetApiVersion(model_addrs_); } heartbeat_thread_ = std::make_unique(HeartbeatFunc, this); } static const AgentHeartbeat &GetInstance() { static AgentHeartbeat *instance = new AgentHeartbeat(); return *instance; } AgentHeartbeat(AgentHeartbeat const &) = delete; void operator=(AgentHeartbeat const &) = delete; // Old API encodes in this way: // API version 0: // model key: `ps:1` // model_name: `ps_1` // // API version 1: // model key: RealModel:ps:1 // model name: RealModel:ps:1 int api_version() const { return api_version_; } // Old APIs. Going to be deprecated. std::shared_ptr GetPredictionServiceByIdx( int idx) const { return GetPredictionService(absl::StrCat("ps:", idx)); } std::shared_ptr GetPredictionService( absl::string_view model_key) const { absl::ReaderMutexLock l(&mu_); auto iter = service_by_model_.find(model_key); if (iter == service_by_model_.end()) { LOG(ERROR) << "model key doesn't exist: " << model_key; return nullptr; } return iter->second; } void TestOnly_UpdateAddrs() { UpdateAddrs(); } absl::flat_hash_map> TestOnly_GetModelAddrs() { absl::ReaderMutexLock l(&mu_); return model_addrs_; } private: static void HeartbeatFunc(AgentHeartbeat *agent) { absl::Time now = absl::Now(); while (!agent->stopped_.WaitForNotificationWithTimeout( now + agent->heartbeat_interval_ - absl::Now())) { now = absl::Now(); agent->UpdateAddrs(); } } // Updates the current model addresses. 
void GetAddrs(monolith::serving::agent_service::ServerType server_type, absl::flat_hash_map>& new_model_addrs) { grpc::ClientContext context; context.set_deadline(std::chrono::system_clock::now() + absl::ToChronoSeconds(absl::Seconds(5))); monolith::serving::agent_service::HeartBeatRequest req; req.set_server_type(server_type); monolith::serving::agent_service::HeartBeatResponse resp; grpc::Status status = agent_stub_->HeartBeat(&context, req, &resp); if (!status.ok()) { LOG(ERROR) << "agent_service->HeartBeat error, code: " << status.error_code() << ", msg: " << status.error_message(); return; } const std::string my_host_ip = GetMyHostIp(); for (auto &kv : *resp.mutable_addresses()) { const std::string &model = kv.first; auto *resp_addrs = kv.second.mutable_address(); std::vector addr_list; addr_list.reserve(resp_addrs->size()); for (const std::string &addr : *resp_addrs) { addr_list.push_back(addr); } new_model_addrs.insert({model, std::move(addr_list)}); } } void UpdateAddrs() { absl::flat_hash_map> new_model_addrs; GetAddrs(monolith::serving::agent_service::PS, new_model_addrs); GetAddrs(monolith::serving::agent_service::DENSE, new_model_addrs); bool same; { absl::ReaderMutexLock l(&mu_); same = (new_model_addrs == model_addrs_); } if (!same) { absl::flat_hash_map> new_service_by_model; for (auto &kv : new_model_addrs) { new_service_by_model.emplace( kv.first, std::make_shared(kv.second)); } { absl::MutexLock l(&mu_); model_addrs_.swap(new_model_addrs); service_by_model_.swap(new_service_by_model); } } } absl::Notification stopped_; std::unique_ptr agent_stub_; absl::Duration heartbeat_interval_; int api_version_; // This is for the public API to use. absl::flat_hash_map> model_addrs_; mutable absl::Mutex mu_; absl::flat_hash_map> service_by_model_ GUARDED_BY(mu_); std::unique_ptr heartbeat_thread_; }; // Gets the model key. 
std::string GetModelPsKey(absl::string_view model_name, int index); std::string GetModelKey(absl::string_view model_name, absl::string_view server_type, int index); } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_AGENT_HEARTBEAT_H_ ================================================ FILE: monolith/native_training/runtime/ops/agent_heartbeat_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifdef TEST_USE_GRPC #include "monolith/native_training/runtime/ops/prediction_service_grpc.h" #else #include "monolith/native_training/runtime/ops/prediction_service_archon.h" #endif #include "gmock/gmock.h" #include "grpcpp/server.h" #include "grpcpp/server_builder.h" #include "gtest/gtest.h" #include "monolith/agent_service/agent_service_mock.grpc.pb.h" #include "monolith/native_training/runtime/ops/agent_heartbeat.h" namespace tensorflow { namespace monolith_tf { namespace { namespace tf_serving = ::tensorflow::serving; using DoneCallback = std::function; #ifdef TEST_USE_GRPC using PredictionServiceType = PredictionServiceGrpc; #else using PredictionServiceType = PredictionServiceArchon; #endif using ::monolith::serving::agent_service::AddressList; using ::monolith::serving::agent_service::AgentService; using ::monolith::serving::agent_service::HeartBeatResponse; using ::monolith::serving::agent_service::MockAgentServiceStub; using ::testing::DoAll; using ::testing::ElementsAre; using ::testing::InSequence; using ::testing::Pair; using ::testing::Return; using ::testing::SetArgPointee; using ::testing::UnorderedElementsAre; const absl::Duration kNoHeartbeat = absl::Hours(1000); TEST(AgentHeartbeatTest, Basic) { auto stub = std::make_unique(); { InSequence s; HeartBeatResponse resp1; (*resp1.mutable_addresses())["model_name"].add_address("localhost:1"); EXPECT_CALL(*stub, HeartBeat) .WillOnce(DoAll(SetArgPointee<2>(resp1), Return(grpc::Status::OK))); HeartBeatResponse resp2; (*resp2.mutable_addresses())["model_name"].add_address("localhost:2"); EXPECT_CALL(*stub, HeartBeat) .WillRepeatedly(DoAll(SetArgPointee<2>(resp2), Return(grpc::Status::OK))); } AgentHeartbeat agent(std::move(stub), kNoHeartbeat); for (const auto &p : agent.TestOnly_GetModelAddrs()) { printf("model_name = %s\n", p.first.c_str()); for (const auto addr : p.second) { printf("%s ", addr.c_str()); } puts(""); } EXPECT_THAT( agent.TestOnly_GetModelAddrs(), UnorderedElementsAre(Pair("model_name", 
ElementsAre("localhost:1")))); agent.TestOnly_UpdateAddrs(); EXPECT_THAT( agent.TestOnly_GetModelAddrs(), UnorderedElementsAre(Pair("model_name", ElementsAre("localhost:2")))); } TEST(AgentHeartbeatTest, HeartBeat) { auto stub = std::make_unique(); { InSequence s; HeartBeatResponse resp1; (*resp1.mutable_addresses())["model_name"].add_address("localhost:1"); EXPECT_CALL(*stub, HeartBeat) .WillOnce(DoAll(SetArgPointee<2>(resp1), Return(grpc::Status::OK))); HeartBeatResponse resp2; (*resp2.mutable_addresses())["model_name"].add_address("localhost:2"); EXPECT_CALL(*stub, HeartBeat) .WillRepeatedly( DoAll(SetArgPointee<2>(resp2), Return(grpc::Status::OK))); } AgentHeartbeat agent(std::move(stub), absl::ZeroDuration()); // Waits for heartbeat update. absl::SleepFor(absl::Seconds(0.2)); EXPECT_THAT( agent.TestOnly_GetModelAddrs(), UnorderedElementsAre(Pair("model_name", ElementsAre("localhost:2")))); } TEST(AgentHeartbeatTest, DefaultInstance) { setenv(kAgentPortEnvVar, "1234", 1); AgentHeartbeat::GetInstance(); } class MockPredictionService : public tf_serving::PredictionService::Service { public: MOCK_METHOD(grpc::Status, Predict, (grpc::ServerContext *, const tf_serving::PredictRequest *, tf_serving::PredictResponse *)); }; std::unique_ptr StartServer( tf_serving::PredictionService::Service *service, int *port) { grpc::ServerBuilder builder; builder.AddListeningPort(absl::StrCat(GetMyHostIp(), ":0"), grpc::InsecureServerCredentials(), port); builder.RegisterService(service); return builder.BuildAndStart(); } TEST(AgentHeartbeatTest, StubTest) { MockPredictionService service; EXPECT_CALL(service, Predict); int port; auto server = StartServer(&service, &port); auto stub = std::make_unique(); HeartBeatResponse resp; (*resp.mutable_addresses())["model_name"].add_address( absl::StrCat(GetMyHostIp(), ":", port)); EXPECT_CALL(*stub, HeartBeat) .WillRepeatedly(DoAll(SetArgPointee<2>(resp), Return(grpc::Status::OK))); AgentHeartbeat agent(std::move(stub), kNoHeartbeat); 
std::shared_ptr predict = agent.GetPredictionService("model_name"); tf_serving::PredictRequest predict_req; tf_serving::PredictResponse predict_resp; absl::Notification notify; predict->Predict( &predict_req, &predict_resp, [¬ify](absl::Status s, DoneCallback &&op_done) { notify.Notify(); }, 1000, [] {}); notify.WaitForNotification(); } TEST(AgentHeartbeatTest, ApiVersion) { auto stub = std::make_unique(); HeartBeatResponse resp; (*resp.mutable_addresses())["ps:0"].add_address("local_host:0"); EXPECT_CALL(*stub, HeartBeat) .WillRepeatedly(DoAll(SetArgPointee<2>(resp), Return(grpc::Status::OK))); AgentHeartbeat agent(std::move(stub), kNoHeartbeat); EXPECT_THAT(agent.api_version(), 0); } TEST(AgentHeartbeatTest2, ApiVersion) { auto stub = std::make_unique(); HeartBeatResponse resp; (*resp.mutable_addresses())["model_name:ps_0"].add_address("local_host:0"); EXPECT_CALL(*stub, HeartBeat) .WillRepeatedly(DoAll(SetArgPointee<2>(resp), Return(grpc::Status::OK))); AgentHeartbeat agent(std::move(stub), kNoHeartbeat); EXPECT_THAT(agent.api_version(), 1); } } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/aligned_concat_split.cu.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "monolith/native_training/runtime/ops/alloc_utils.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace monolith_tf {

// Writes `total` floats to `out`. For every output index the kernel finds the
// input whose offset range contains it and stores input[i] * scale, or 0.0f
// when the index falls past that input's element count (the gap up to the
// next offset).
//
// `input_ptrs_da` is a table of 2N+1 pointer-sized entries: entries [0, N)
// are the input data pointers; the remaining N+1 entries are reinterpreted
// as ints holding N sizes, then N offsets, then one final offset that
// Compute() below sets to `total`.
// `_scale` is dereferenced on the device.
__global__ void flat_concat(
    GpuDeviceArrayStruct input_ptrs_da,  // length = 2N+1
    int total, const float* _scale, float* out) {
  float scale = *_scale;
  auto _input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs_da);
  // Stage the whole table in dynamic shared memory; the launch passes
  // sizeof(const float*) * (2N+1) bytes for it.
  extern __shared__ const float* input_ptrs[];
  for (int i = threadIdx.x; i < input_ptrs_da.size; i += blockDim.x)
    input_ptrs[i] = _input_ptrs[i];
  __syncthreads();
  auto N = (input_ptrs_da.size - 1) / 2;
  // The int view starts right after the N data pointers.
  auto sizes = reinterpret_cast(input_ptrs + N);
  auto offsets = sizes + N;
  auto tid = threadIdx.x + blockIdx.x * blockDim.x;
  auto stride = blockDim.x * gridDim.x;
  // `id` only grows for a given thread, so work_id never has to move back.
  int work_id = 0;
  for (int id = tid; id < total; id += stride) {
    while (offsets[work_id + 1] <= id) work_id++;
    int i = id - offsets[work_id];
    if (i < sizes[work_id]) {
      out[id] = input_ptrs[work_id][i] * scale;
    } else {
      out[id] = 0.0f;
    }
  }
}

// Flatten each input and then concatenate them. This op also ensures that the
// start position of each input in the concat output is suitably aligned (as per
// Tensorflow/Eigen's requirement), so that we can perform a split without
// copying the underlying memory
class AlignedFlatConcat : public OpKernel {
 public:
  // Reads the number of inputs to concatenate from attr "N".
  explicit AlignedFlatConcat(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("N", &N_));
  }

  // Inputs 0..N-1 are the tensors to concatenate; input N is the scale read
  // by the kernel. Output 0 is a 1-D tensor of `total` elements.
  void Compute(OpKernelContext* context) override {
    const auto& gpu_device = context->eigen_gpu_device();
    // Two ints are packed into each pointer-sized slot of the table below.
    static_assert(sizeof(int) * 2 == sizeof(const float*));
    GpuDeviceArrayOnHost input_ptrs(context, 2 * N_ + 1);
    OP_REQUIRES_OK(context, input_ptrs.Init());
    // Used here only to compute aligned offsets; allocate() is never called.
    FusedAlignedOutputAllocator fao_alloc(context);
    // Layout: [0, N) sizes, [N, 2N) aligned start offsets, [2N] total.
    std::vector offsets_sizes(2 * N_ + 2);
    for (int i = 0; i < N_; ++i) {
      auto sz = context->input(i).NumElements();
      input_ptrs.Set(i, context->input(i).flat().data());
      offsets_sizes[i] = sz;
      offsets_sizes[N_ + i] = fao_alloc.get_aligned_total();
      fao_alloc.add_slice(sz);
    }
    int total = fao_alloc.get_aligned_total();
    offsets_sizes[2 * N_] = total;
    // Copy the int table into table slots N..2N, one pointer-sized chunk at a
    // time.
    auto data = reinterpret_cast(offsets_sizes.data());
    for (int i = 0; i <= N_; ++i) input_ptrs.Set(N_ + i, data[i]);
    OP_REQUIRES_OK(context, input_ptrs.Finalize());
    // NOTE(review): this limit is checked only after the table has been
    // built and finalized — confirm whether it should run first.
    OP_REQUIRES(context, 2 * N_ + 1 <= 2048,
                errors::Unknown("Total size of ", 2 * N_ + 1,
                                " is greater than 2048 so is not supported. "
                                "Please contact the developers."));
    Tensor* out;
    OP_REQUIRES_OK(context, context->allocate_output(0, {total}, &out));
    auto config = GetGpuLaunchConfig(total, gpu_device);
    TF_CHECK_OK(GpuLaunchKernel(
        flat_concat, config.block_count, config.thread_per_block,
        sizeof(const float*) * (2 * N_ + 1), gpu_device.stream(),
        input_ptrs.data(), total, context->input(N_).flat().data(),
        out->flat().data()));
  }

 private:
  int N_;
};

// Inverse of AlignedFlatConcat: output i is a slice of the flat input
// (input N) with the shape of input i. Slices are taken in input order at the
// allocator's aligned offsets; per the notes in alloc_utils.h the slices do
// not copy the underlying memory.
class AlignedFlatSplit : public OpKernel {
 public:
  // Reads the number of outputs from attr "N".
  explicit AlignedFlatSplit(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("N", &N_));
  }

  void Compute(OpKernelContext* context) override {
    FusedAlignedOutputAllocator fao_alloc(context);
    const auto& flat = context->input(N_);
    for (int i = 0; i < N_; ++i) {
      context->set_output(i,
                          fao_alloc.get_slice(context->input(i).shape(), flat));
    }
  }

 private:
  int N_;
};

// The aligned total is only known at run time, so the output is declared as a
// vector of unknown length.
REGISTER_OP("MonolithAlignedFlatConcat")
    .Input("inputs: N * float")
    .Input("scale: float")
    .Output("concat: float")
    .Attr("N: int")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->Vector(c->UnknownDim()));
      return tensorflow::Status::OK();
    });

REGISTER_KERNEL_BUILDER(
    Name("MonolithAlignedFlatConcat").Device(DEVICE_GPU), AlignedFlatConcat);

// Output i takes the shape of input i; the last input ("flat") has no
// matching output, hence num_inputs() - 1.
REGISTER_OP("MonolithAlignedFlatSplit")
    .Input("inputs: N * float")  // for shape inference only, data not used
    .Input("flat: float")
    .Output("concat: N * float")
    .Attr("N: int")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      for (int i = 0; i < c->num_inputs() - 1; ++i) {
        c->set_output(i, c->input(i));
      }
      return tensorflow::Status::OK();
    });

REGISTER_KERNEL_BUILDER(Name("MonolithAlignedFlatSplit").Device(DEVICE_GPU),
                        AlignedFlatSplit);

}  // namespace monolith_tf
}  // namespace tensorflow
#endif

================================================
FILE: monolith/native_training/runtime/ops/alloc_utils.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace monolith_tf {

/**
 * A tiny, fast allocator that allocates aligned output tensors,
 * each having a different shape.
 *
 * Useful when you have > 500 output tensors for your Op
 * and the calls to allocate_output become the bottleneck.
 *
 * How to use:
 * First, initialize your allocator with the desired alignment.
 * Note that the alignment is specified in terms of the number of elements of
 * the corresponding dtype, not in bytes.
 *   FusedAlignedOutputAllocator alloc;
 *
 * Then, tell the allocator the total size of your output by calling add_slice
 * in a loop
 *   for (int i = 0; i < num_outputs; i++) {
 *     alloc.add_slice(num_elements_in_this_output);
 *   }
 *
 * Then, call .allocate. This will be a single call to ctx->allocate_temp
 *   alloc.allocate(YOUR_DTYPE);
 *
 * Finally, get each output in the same order as you called add_slice.
 * You need to specify the shape for each output.
 *   for (int i = 0; i < num_outputs; i++) {
 *     ctx->set_output(i, alloc.get_slice({DIM_SIZE1, ...}));
 *   }
 *
 * There's also a get_unaligned_total that may come in handy
 * if you want to get the total size of your output without padding.
 */
template class FusedAlignedOutputAllocator {
 public:
  explicit FusedAlignedOutputAllocator(OpKernelContext* ctx): ctx_(ctx) { }

  // Registers one more slice of `num_elements` elements: the unpadded total
  // grows by num_elements, the padded total by num_elements rounded up to the
  // alignment.
  inline void add_slice(int64 num_elements) {
    total_ += num_elements;
    aligned_total_ += round_up_to_align(num_elements);
  }

  // Allocates one flat temp tensor of `aligned_total_` elements and rewinds
  // aligned_total_ to 0 so that it can serve as the read cursor for
  // get_slice(). On allocation failure OP_REQUIRES_OK reports the error on
  // ctx_ and the cursor is not rewound.
  inline void allocate(DataType dtype) {
    // allocate_temp may seem suspicious here, but it's properly reference counted
    // (including its slice), so we don't need to worry about its lifetime problem
    OP_REQUIRES_OK(ctx_, ctx_->allocate_temp(dtype, {aligned_total_}, &flat_out_));
    aligned_total_ = 0;
  }

  // Returns the next slice of `flat`, reshaped to `shape`, and advances the
  // cursor by the slice's padded size. Must be called in the same order as
  // add_slice() so that the offsets line up.
  inline Tensor get_slice(const TensorShape& shape, const Tensor& flat) {
    int64 num_elements = shape.num_elements();
    Tensor reshaped;
    // note: CopyFrom and Slice doesn't copy the underlying memory
    // The bool result of CopyFrom is deliberately discarded.
    (void)reshaped.CopyFrom(flat.Slice(aligned_total_, aligned_total_ + num_elements), shape);
    aligned_total_ += round_up_to_align(num_elements);
    return reshaped;
  }

  // Same as above, slicing the tensor obtained by allocate().
  inline Tensor get_slice(const TensorShape& shape) {
    return get_slice(shape, flat_out_);
  }

  // Sum of all registered slice sizes, without padding.
  inline int64 get_unaligned_total() const { return total_; }

  // Before allocate(): the padded sum of all registered slices. During
  // get_slice() calls: the padded offset of the next slice.
  inline int64 get_aligned_total() const { return aligned_total_; }

 private:
  OpKernelContext* ctx_;
  int64 aligned_total_ = 0;  // padded total, later reused as the slice cursor
  int64 total_ = 0;          // unpadded total
  Tensor flat_out_;          // backing buffer filled in by allocate()

  // Rounds `a` up with the bit mask ~(alignment - 1); alignment == 0 disables
  // padding. The mask yields a multiple of `alignment` only when `alignment`
  // is a power of two.
  static constexpr int64 round_up_to_align(int64 a) {
    if (alignment == 0) return a;
    constexpr int64 temp = alignment - 1;
    constexpr int64 temp2 = ~temp;
    return (a + temp) & temp2;
  }
};

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/clip_by_global_norm.cu.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "monolith/native_training/runtime/ops/clip_by_global_norm.h"

#include "monolith/native_training/runtime/ops/alloc_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace monolith {
namespace {

// Multiplies every element of every input by `scale` and writes it to the
// output with the same index. `offsets_da` holds the running element counts
// of the inputs, with the grand total (`size`) as its last entry. A global
// element index idx belongs to input i when offsets[i] <= idx < offsets[i+1].
__global__ void element_wise_mul(
    GpuDeviceArrayStruct input_ptrs_da,
    GpuDeviceArrayStruct output_ptrs_da,
    GpuDeviceArrayStruct offsets_da, int size, float scale) {
  const float** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs_da);
  int* offsets = GetGpuDeviceArrayOnDevice(&offsets_da);
  float** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptrs_da);

  // if using shared memory
  // Ref:
  // https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/core/kernels/split_lib_gpu.cu.cc#L124
  GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(int), unsigned char, smem);
  int* smem_offsets = reinterpret_cast(smem);
  // Every block copies the offsets table into its shared memory.
  for (int x = threadIdx.x; x < offsets_da.size; x += blockDim.x) {
    smem_offsets[x] = offsets[x];
  }
  __syncthreads();
  offsets = smem_offsets;

  // idx only grows within a thread, so i never has to move backwards.
  int i = 0;
  GPU_1D_KERNEL_LOOP(idx, size) {
    // safe offsets read: when idx == size - 1, i+1 == num_inputs
    while (offsets[i + 1] <= idx) ++i;
    int j = idx - offsets[i];
    output_ptrs[i][j] = ldg(input_ptrs[i] + j) * scale;
  }
}

}  // namespace

typedef Eigen::GpuDevice GPUDevice;

// GPU implementation: scales all N inputs with one kernel launch. The N
// outputs are slices of a single allocation made through
// FusedAlignedOutputAllocator rather than N separate allocate_output calls.
template <>
struct ClipByGlobalNormImpl {
  static void Compute(OpKernelContext* context, float scale) {
    const auto& gpu_device = context->eigen_gpu_device();
    // The last two op inputs are not tensors to scale (see
    // ClipByGlobalNorm::Compute in clip_by_global_norm.h).
    auto N_ = context->num_inputs() - 2;
    GpuDeviceArrayOnHost input_ptrs_da(context, N_);
    GpuDeviceArrayOnHost offsets(context, N_ + 1);
    OP_REQUIRES_OK(context, input_ptrs_da.Init());
    OP_REQUIRES_OK(context, offsets.Init());
    monolith_tf::FusedAlignedOutputAllocator fao_alloc(context);
    for (int i = 0; i < N_; ++i) {
      input_ptrs_da.Set(i, context->input(i).flat().data());
      // Kernel offsets are the unpadded running totals of the element counts.
      offsets.Set(i, fao_alloc.get_unaligned_total());
      fao_alloc.add_slice(context->input(i).NumElements());
    }
    int total = fao_alloc.get_unaligned_total();
    offsets.Set(N_, total);
    OP_REQUIRES_OK(context, input_ptrs_da.Finalize());
    OP_REQUIRES_OK(context, offsets.Finalize());

    GpuDeviceArrayOnHost output_ptrs_da(context, N_);
    OP_REQUIRES_OK(context, output_ptrs_da.Init());
    fao_alloc.allocate(DT_FLOAT);
    for (int i = 0; i < N_; ++i) {
      auto t = fao_alloc.get_slice(context->input(i).shape());
      output_ptrs_da.Set(i, t.flat().data());
      context->set_output(i, std::move(t));
    }
    OP_REQUIRES_OK(context, output_ptrs_da.Finalize());

    auto config = GetGpuLaunchConfig(total, gpu_device);
    // Dynamic shared memory holds the N_ + 1 int offsets.
    const int smem_usage = sizeof(int) * (N_ + 1);
    TF_CHECK_OK(GpuLaunchKernel(
        element_wise_mul, config.block_count, config.thread_per_block,
        smem_usage, gpu_device.stream(), input_ptrs_da.data(),
        output_ptrs_da.data(), offsets.data(), total, scale));
  }
};

// Both norms live in host memory because ClipByGlobalNorm::Compute reads them
// on the host.
REGISTER_KERNEL_BUILDER(Name("MonolithClipByGlobalNorm")
                            .Device(DEVICE_GPU)
                            .HostMemory("global_norm")
                            .HostMemory("clip_norm"),
                        ClipByGlobalNorm);

}  // namespace monolith
}  // namespace tensorflow
#endif  // GOOGLE_CUDA

================================================
FILE: monolith/native_training/runtime/ops/clip_by_global_norm.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_CLIP_BY_GLOBAL_NORM #define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_CLIP_BY_GLOBAL_NORM #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace monolith { template struct ClipByGlobalNormImpl { static void Compute(OpKernelContext* context, float scale); }; template class ClipByGlobalNorm : public OpKernel { public: explicit ClipByGlobalNorm(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { int num_inputs = context->num_inputs() - 2; float global_norm = context->input(num_inputs).scalar()(); float clip_norm = context->input(num_inputs + 1).scalar()(); if (global_norm > clip_norm) { ClipByGlobalNormImpl::Compute(context, clip_norm / global_norm); } else { // If no clip, output as input. for (int i = 0; i < num_inputs; ++i) { context->set_output(i, context->input(i)); } } } }; } // namespace monolith } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_CLIP_BY_GLOBAL_NORM ================================================ FILE: monolith/native_training/runtime/ops/clip_by_global_norm_fused.cu.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "monolith/native_training/runtime/ops/alloc_utils.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

// Adds the sum of squares of all elements of all inputs to *out. `out` must
// be zeroed by the caller, since the kernel only accumulates into it.
// `offsets_da` holds the running element counts of the inputs, with `size`
// as its last entry.
template __global__ void globalReduceSum(
    GpuDeviceArrayStruct input_ptrs_da,
    GpuDeviceArrayStruct offsets_da, float* out, int size) {
  const float** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs_da);
  int* offsets = GetGpuDeviceArrayOnDevice(&offsets_da);
  // Every block copies the offsets table into dynamic shared memory.
  extern __shared__ int smem[];
  for (int x = threadIdx.x; x < offsets_da.size; x += blockDim.x) {
    smem[x] = offsets[x];
  }
  __syncthreads();
  offsets = smem;

  float thread_sum = 0;
  int i = 0;
  GPU_1D_KERNEL_LOOP(idx, size) {
    // safe offsets read: when idx == size - 1, i+1 == N_
    while (offsets[i + 1] <= idx) ++i;
    int j = idx - offsets[i];
    float v = ldg(input_ptrs[i] + j);
    thread_sum += v * v;  // l2
  }

  // thread reduce sum to block reduce sum
  typedef gpuprim::BlockReduce BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float block_sum = BlockReduce(temp_storage).Sum(thread_sum);
  if (threadIdx.x == 0)  // block reduce sum to global reduce sum
    atomicAdd(out, block_sum);
}

// Multiplies every element of every input by `scale` and writes it to the
// output with the same index. The offsets table has the same meaning as in
// globalReduceSum.
__global__ void element_wise_mul(
    GpuDeviceArrayStruct input_ptrs_da,
    GpuDeviceArrayStruct output_ptrs_da,
    GpuDeviceArrayStruct offsets_da, int size, float scale) {
  const float** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs_da);
  int* offsets = GetGpuDeviceArrayOnDevice(&offsets_da);
  float** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptrs_da);
  extern __shared__ int smem[];
  for (int x = threadIdx.x; x < offsets_da.size; x += blockDim.x) {
    smem[x] = offsets[x];
  }
  __syncthreads();
  offsets = smem;

  int i = 0;
  GPU_1D_KERNEL_LOOP(idx, size) {
    // safe offsets read: when idx == size - 1, i+1 == num_inputs
    while (offsets[i + 1] <= idx) ++i;
    int j = idx - offsets[i];
    output_ptrs[i][j] = ldg(input_ptrs[i] + j) * scale;
  }
}

}  // namespace

// Computes the global L2 norm of N gradient tensors and clips them in the
// same op. Inputs 0..N-1 are the gradients and input N is clip_norm. Outputs
// 0..N-1 are the (possibly scaled) gradients and output N is the sum of
// squares of all gradient elements (see the note in Compute()).
class ClipByGlobalNormFused : public OpKernel {
 public:
  // Reads the number of gradient tensors from attr "N".
  explicit ClipByGlobalNormFused(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("N", &N_));
  }

  void Compute(OpKernelContext* context) override {
    const auto& gpu_device = context->eigen_gpu_device();
    GpuDeviceArrayOnHost input_ptrs_da(context, N_);
    GpuDeviceArrayOnHost offsets(context, N_ + 1);
    OP_REQUIRES_OK(context, input_ptrs_da.Init());
    OP_REQUIRES_OK(context, offsets.Init());
    FusedAlignedOutputAllocator fao_alloc(context);
    for (int i = 0; i < N_; ++i) {
      input_ptrs_da.Set(i, context->input(i).flat().data());
      // Kernel offsets are the unpadded running totals of the element counts.
      offsets.Set(i, fao_alloc.get_unaligned_total());
      fao_alloc.add_slice(context->input(i).NumElements());
    }
    int total = fao_alloc.get_unaligned_total();
    offsets.Set(N_, total);
    OP_REQUIRES_OK(context, input_ptrs_da.Finalize());
    OP_REQUIRES_OK(context, offsets.Finalize());

    // Device-side accumulator for the sum of squares; zeroed before launch.
    Tensor d_norm;
    OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, {}, &d_norm));
    gpu_device.memset(d_norm.data(), 0, sizeof(float));

    constexpr int block_sz = 1024;
    const int smem_usage = sizeof(int) * (N_ + 1);
    // NOTE(review): std::min(x, 1) caps the per-multiprocessor factor at 1
    // and evaluates to 0 blocks if maxGpuThreadsPerMultiProcessor() is below
    // block_sz — confirm whether std::max was intended.
    TF_CHECK_OK(GpuLaunchKernel(
        globalReduceSum,
        std::min(gpu_device.maxGpuThreadsPerMultiProcessor() / block_sz, 1) *
            gpu_device.getNumGpuMultiProcessors(),
        block_sz, smem_usage, gpu_device.stream(), input_ptrs_da.data(),
        offsets.data(), d_norm.flat().data(), total));

    // async kernel launch above can hide some latency of the code below until
    // synchronize
    Tensor* h_norm_out;
    OP_REQUIRES_OK(context, context->allocate_output(N_, {}, &h_norm_out));
    // Output N receives the raw accumulator; the sqrt below is applied only
    // to the local global_norm value.
    gpu_device.memcpyDeviceToHost(h_norm_out->data(), d_norm.data(),
                                  sizeof(float));

    GpuDeviceArrayOnHost output_ptrs_da(context, N_);
    OP_REQUIRES_OK(context, output_ptrs_da.Init());
    float clip_norm = context->input(N_).scalar()();
    fao_alloc.allocate(DT_FLOAT);
    for (int i = 0; i < N_; ++i) {
      auto t = fao_alloc.get_slice(context->input(i).shape());
      output_ptrs_da.Set(i, t.flat().data());
      // if this ends up unused, it will be overwritten
      context->set_output(i, std::move(t));
    }
    OP_REQUIRES_OK(context, output_ptrs_da.Finalize());
    auto config = GetGpuLaunchConfig(total, gpu_device);

    // The host needs the reduced value to decide whether to clip.
    gpu_device.synchronize();
    float global_norm = std::sqrt(h_norm_out->scalar()());
    if (global_norm > clip_norm) {
      TF_CHECK_OK(GpuLaunchKernel(element_wise_mul, config.block_count,
                                  config.thread_per_block, smem_usage,
                                  gpu_device.stream(), input_ptrs_da.data(),
                                  output_ptrs_da.data(), offsets.data(), total,
                                  clip_norm / global_norm));
    } else {
      // No clipping: replace the preallocated outputs with the inputs.
      for (int i = 0; i < N_; ++i) {
        *context->mutable_output(i) = context->input(i);
      }
    }
  }

 private:
  int N_;
};

// Output i takes the shape of input i for all N + 1 inputs, so the
// global_norm output takes the shape of the clip_norm input.
REGISTER_OP("MonolithClipByGlobalNormFused")
    .Input("grad_list: N * float")
    .Input("clip_norm: float")
    .Output("clipped: N * float")
    .Output("global_norm: float")
    .Attr("N: int")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      for (int i = 0; i < c->num_inputs(); ++i) {
        c->set_output(i, c->input(i));
      }
      return tensorflow::Status::OK();
    });

REGISTER_KERNEL_BUILDER(Name("MonolithClipByGlobalNormFused")
                            .Device(DEVICE_GPU)
                            .HostMemory("global_norm")
.HostMemory("clip_norm"), ClipByGlobalNormFused); } // namespace monolith_tf } // namespace tensorflow #endif ================================================ FILE: monolith/native_training/runtime/ops/clip_by_global_norm_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "clip_by_global_norm.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace monolith { typedef Eigen::ThreadPoolDevice CPUDevice; template <> struct ClipByGlobalNormImpl { static void Compute(OpKernelContext* context, float scale) { int num_inputs = context->num_inputs() - 2; bool user_parallel = num_inputs > 4; auto func = [context, scale](int64 start, int64 end) { for (int64 i = start; i < end; ++i) { Tensor* temp; context->allocate_output(i, context->input(i).shape(), &temp); temp->flat() = context->input(i).flat() * scale; } }; if (user_parallel) { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); Shard(worker_threads.num_threads, worker_threads.workers, num_inputs, (num_inputs + worker_threads.num_threads - 1) / worker_threads.num_threads, func); } else { func(0, num_inputs); } } }; 
// Registers the CPU kernel for the MonolithClipByGlobalNorm op.
REGISTER_KERNEL_BUILDER(Name("MonolithClipByGlobalNorm").Device(DEVICE_CPU),
                        ClipByGlobalNorm);
// End: Kernel Definition

// Op definition: N float gradients plus the global_norm and clip_norm inputs,
// producing N float outputs.
REGISTER_OP("MonolithClipByGlobalNorm")
    .Input("grad_list: N * float")
    .Input("global_norm: float")
    .Input("clip_norm: float")
    .Output("clip_grad_list: N * float")
    .Attr("N: int")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      // The last two inputs (global_norm, clip_norm) have no matching output;
      // every remaining output takes the shape of the input at the same index.
      int input_n = c->num_inputs() - 2;
      for (int i = 0; i < input_n; ++i) {
        c->set_output(i, c->input(i));
      }
      return tensorflow::Status::OK();
    });

}  // namespace monolith
}  // namespace tensorflow
================================================
FILE: monolith/native_training/runtime/ops/deep_insight_client_tf_bridge.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_DEEP_INSIGHT_CLIENT_TF_BRIDGE
#define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_DEEP_INSIGHT_CLIENT_TF_BRIDGE

#include
#include "absl/strings/str_format.h"
#include "monolith/native_training/runtime/deep_insight/deep_insight.h"
#include "monolith/native_training/runtime/ops/file_metric_writer.h"
#include "tensorflow/core/framework/resource_mgr.h"

using monolith::deep_insight::ExtraField;

namespace tensorflow {
namespace monolith_tf {

// TF resource that owns a deep insight client together with a file metric
// writer and forwards calls to them.
class DeepInsightClientTfBridge : public ResourceBase {
 public:
  // Takes ownership of both objects.
  explicit DeepInsightClientTfBridge(
      std::unique_ptr deep_insight_client,
      std::unique_ptr file_metric_writer)
      : deep_insight_client_(std::move(deep_insight_client)),
        file_metric_writer_(std::move(file_metric_writer)) {}

  // Forwards one sample to the client's SendV2 and passes the string it
  // returns to the file metric writer.
  //
  // The client is always called with `true` as its last argument, regardless
  // of `return_msgs`; `return_msgs` only decides whether this method returns
  // that string or an empty one.
  std::string SendV2(
      const std::string& model_name, const std::vector& targets,
      uint64_t uid, int64_t req_time, int64_t train_time,
      const std::vector& labels, const std::vector& preds,
      const std::vector& sample_rates, float sample_ratio,
      const std::vector>& extra_fields,
      bool return_msgs) {
    std::string msg = deep_insight_client_->SendV2(
        model_name, targets, uid, req_time, train_time, labels, preds,
        sample_rates, sample_ratio, extra_fields, true);
    file_metric_writer_->Write(msg);
    return return_msgs ? msg : "";
  }

  // Returns the client's GenerateTrainingTime() result.
  int64_t GenerateTrainingTime() {
    return deep_insight_client_->GenerateTrainingTime();
  }

  // Returns the client's GetTotalSendCounter() result.
  uint64_t GetTotalSendCounter() {
    return deep_insight_client_->GetTotalSendCounter();
  }

  // ResourceBase hook: reports the client's total send counter.
  std::string DebugString() const override {
    return absl::StrFormat("Total send counter: %d",
                           deep_insight_client_->GetTotalSendCounter());
  }

 private:
  std::unique_ptr deep_insight_client_;
  std::unique_ptr file_metric_writer_;
};

}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_DEEP_INSIGHT_CLIENT_TF_BRIDGE
================================================
FILE: monolith/native_training/runtime/ops/deep_insight_ops.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/ops/deep_insight_client_tf_bridge.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"

using monolith::deep_insight::ExtraField;
using monolith::deep_insight::FloatExtraField;
using monolith::deep_insight::Int64ExtraField;
using monolith::deep_insight::StringExtraField;

namespace tensorflow {
namespace monolith_tf {

// Resource op that creates a DeepInsightClientTfBridge from the
// `enable_metrics_counter`, `is_fake` and `filename` attributes.
class MonolithCreateDeepInsightClientOp
    : public ResourceOpKernel {
 public:
  explicit MonolithCreateDeepInsightClientOp(OpKernelConstruction* ctx)
      : ResourceOpKernel(ctx) {
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("enable_metrics_counter", &enable_metrics_counter_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("is_fake", &is_fake_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename_));
  }

  ~MonolithCreateDeepInsightClientOp() override {}

 private:
  // Builds the client and the file metric writer and hands both to a newly
  // allocated bridge, which is returned through the out-parameter.
  Status CreateResource(DeepInsightClientTfBridge** deep_insight_client_bridge)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    auto deep_insight_client = std::make_unique(
        enable_metrics_counter_, is_fake_);
    auto file_metric_writer = std::make_unique(filename_);
    *deep_insight_client_bridge = new DeepInsightClientTfBridge(
        std::move(deep_insight_client), std::move(file_metric_writer));
    return Status::OK();
  }

  bool enable_metrics_counter_;
  bool is_fake_;
  std::string filename_;
};

// Sends one message per sample for a single target.
//
// Inputs: 0 = client handle, 1 = uids, 2 = req_times, 3 = labels, 4 = preds,
// 5 = sample_rates. Inputs 1..5 are read as vectors and the batch size is the
// length of `labels`. Output 0 holds one string per sample.
class MonolithWriteDeepInsightOp : public OpKernel {
 public:
  explicit MonolithWriteDeepInsightOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("model_name", &model_name_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("target", &target_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_ratio", &sample_ratio_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("return_msgs", &return_msgs_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("use_zero_train_time", &use_zero_train_time_));
  }

  void Compute(OpKernelContext* ctx) override {
    DeepInsightClientTfBridge* deep_insight_client = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
                                       &deep_insight_client));
    core::ScopedUnref unref(deep_insight_client);

    // Reject inputs that the element accessors below cannot handle: a
    // non-vector, or a vector shorter than the batch, which would be read
    // out of bounds.
    const int64_t batch_size = ctx->input(3).NumElements();
    for (int idx = 1; idx <= 5; ++idx) {
      const Tensor& t = ctx->input(idx);
      OP_REQUIRES(ctx,
                  TensorShapeUtils::IsVector(t.shape()) &&
                      t.NumElements() >= batch_size,
                  errors::InvalidArgument(
                      "Input ", idx, " must be a vector with at least ",
                      batch_size, " elements, got shape ",
                      t.shape().DebugString()));
    }

    auto uids_vec = ctx->input(1).vec();
    auto req_times_vec = ctx->input(2).vec();
    auto labels_vec = ctx->input(3).vec();
    auto preds_vec = ctx->input(4).vec();
    auto sample_rates_vec = ctx->input(5).vec();
    // One training time is shared by the whole batch.
    int64_t train_time = use_zero_train_time_
                             ? 0
                             : deep_insight_client->GenerateTrainingTime();

    Tensor* msgs = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {batch_size}, &msgs));
    auto msgs_vec = msgs->vec();
    std::vector targets;
    targets.push_back(target_);
    for (int64_t i = 0; i < batch_size; i++) {
      std::vector labels, preds, sample_rates;
      // This op never supplies extra fields.
      std::vector> extra_fields;
      labels.push_back(labels_vec(i));
      preds.push_back(preds_vec(i));
      sample_rates.push_back(sample_rates_vec(i));
      std::string msg = deep_insight_client->SendV2(
          model_name_, targets, uids_vec(i), req_times_vec(i), train_time,
          labels, preds, sample_rates, sample_ratio_, extra_fields,
          return_msgs_);
      msgs_vec(i) = msg;
    }
  }

 private:
  std::string model_name_;
  std::string target_;
  float sample_ratio_;
  bool return_msgs_;
  bool use_zero_train_time_;
};

// Sends one message per sample for several targets, with extra fields.
//
// Inputs: 0 = client handle, 1 = req_times (batch), 2..4 = labels, preds and
// sample_rates (num_heads x batch), followed by the `extra_fields_values`
// list. The uid of each sample is taken from the int64 extra field whose key
// is "uid". Output 0 holds one string per sample.
class MonolithWriteDeepInsightV2 : public OpKernel {
 public:
  explicit MonolithWriteDeepInsightV2(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("model_name", &model_name_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("targets", &targets_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("sample_ratio", &sample_ratio_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("return_msgs", &return_msgs_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("use_zero_train_time", &use_zero_train_time_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("extra_fields_keys", &extra_fields_keys_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tfields", &extra_fields_dtypes_));
    // Compute() indexes both lists with the same index, so they have to have
    // the same length.
    OP_REQUIRES(ctx, extra_fields_keys_.size() == extra_fields_dtypes_.size(),
                errors::InvalidArgument(
                    "extra_fields_keys has ", extra_fields_keys_.size(),
                    " entries but Tfields has ", extra_fields_dtypes_.size()));
  }

  void Compute(OpKernelContext* ctx) override {
    DeepInsightClientTfBridge* deep_insight_client = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
                                       &deep_insight_client));
    core::ScopedUnref unref(deep_insight_client);

    // Validate ranks and sizes before any element access.
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->input(1).shape()),
                errors::InvalidArgument(
                    "req_times must be a vector, got shape ",
                    ctx->input(1).shape().DebugString()));
    for (int idx = 2; idx <= 4; ++idx) {
      OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(ctx->input(idx).shape()),
                  errors::InvalidArgument(
                      "Input ", idx, " must be a matrix, got shape ",
                      ctx->input(idx).shape().DebugString()));
    }
    const int64_t batch_size = ctx->input(2).dim_size(1);
    const int64_t num_targets = targets_.size();
    OP_REQUIRES(ctx, ctx->input(1).NumElements() >= batch_size,
                errors::InvalidArgument(
                    "req_times must have at least ", batch_size,
                    " elements, got ", ctx->input(1).NumElements()));
    for (int idx = 2; idx <= 4; ++idx) {
      const Tensor& t = ctx->input(idx);
      OP_REQUIRES(
          ctx, t.dim_size(0) >= num_targets && t.dim_size(1) >= batch_size,
          errors::InvalidArgument("Input ", idx, " must be at least ",
                                  num_targets, " x ", batch_size,
                                  ", got shape ", t.shape().DebugString()));
    }

    auto req_times_vec = ctx->input(1).vec();       // batch
    auto labels_mat = ctx->input(2).matrix();       // num_heads x batch
    auto preds_mat = ctx->input(3).matrix();        // num_heads x batch
    auto sample_rates_mat = ctx->input(4).matrix(); // num_heads x batch
    OpInputList extra_fields_values;
    OP_REQUIRES_OK(
        ctx, ctx->input_list("extra_fields_values", &extra_fields_values));

    // Only fields of the three dtypes handled below are read, so only those
    // are checked.
    for (size_t i = 0; i < extra_fields_dtypes_.size(); i++) {
      const auto dtype = extra_fields_dtypes_.at(i);
      if (dtype != tensorflow::DT_FLOAT && dtype != tensorflow::DT_INT64 &&
          dtype != tensorflow::DT_STRING) {
        continue;
      }
      const Tensor& t = extra_fields_values[i];
      OP_REQUIRES(ctx,
                  TensorShapeUtils::IsVector(t.shape()) &&
                      t.NumElements() >= batch_size,
                  errors::InvalidArgument(
                      "Extra field '", extra_fields_keys_.at(i),
                      "' must be a vector with at least ", batch_size,
                      " elements, got shape ", t.shape().DebugString()));
    }

    // One training time is shared by the whole batch.
    int64_t train_time = use_zero_train_time_
                             ? 0
                             : deep_insight_client->GenerateTrainingTime();

    std::vector>>
        batched_extra_fields;  // batch x num_keys
    std::vector uids_vec;
    for (int64_t b = 0; b < batch_size; b++) {
      batched_extra_fields.emplace_back();
      auto& extra_fields = batched_extra_fields.back();
      for (size_t i = 0; i < extra_fields_dtypes_.size(); i++) {
        if (extra_fields_dtypes_.at(i) == tensorflow::DT_FLOAT) {
          extra_fields.push_back(std::make_shared(
              extra_fields_keys_.at(i), extra_fields_values[i].vec()(b)));
        } else if (extra_fields_dtypes_.at(i) == tensorflow::DT_INT64) {
          // The "uid" field is passed to SendV2 as its own argument rather
          // than as an extra field.
          if (extra_fields_keys_.at(i) == "uid") {
            uids_vec.push_back(extra_fields_values[i].vec()(b));
          } else {
            extra_fields.push_back(std::make_shared(
                extra_fields_keys_.at(i), extra_fields_values[i].vec()(b)));
          }
        } else if (extra_fields_dtypes_.at(i) == tensorflow::DT_STRING) {
          extra_fields.push_back(std::make_shared(
              extra_fields_keys_.at(i), extra_fields_values[i].vec()(b)));
        }
      }
    }
    // Exactly one uid per sample is needed below. A missing (or repeated)
    // int64 "uid" field is reported as an error instead of letting
    // uids_vec.at() throw or silently pairing samples with wrong uids.
    OP_REQUIRES(ctx, static_cast<int64_t>(uids_vec.size()) == batch_size,
                errors::InvalidArgument(
                    "Expected exactly one int64 extra field named 'uid' "
                    "providing ",
                    batch_size, " values, got ", uids_vec.size()));

    Tensor* msgs = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {batch_size}, &msgs));
    auto msgs_vec = msgs->vec();
    for (int64_t i = 0; i < batch_size; i++) {
      std::vector labels, preds, sample_rates;
      for (size_t j = 0; j < targets_.size(); j++) {
        labels.push_back(labels_mat(j, i));
        preds.push_back(preds_mat(j, i));
        sample_rates.push_back(sample_rates_mat(j, i));
      }
      std::string msg = deep_insight_client->SendV2(
          model_name_, targets_, uids_vec.at(i), req_times_vec(i), train_time,
          labels, preds, sample_rates, sample_ratio_,
          batched_extra_fields.at(i), return_msgs_);
      msgs_vec(i) = msg;
    }
  }

 private:
  std::string model_name_;
  std::vector targets_;
  float sample_ratio_;
  bool return_msgs_;
  bool use_zero_train_time_;
  std::vector extra_fields_keys_;
  std::vector extra_fields_dtypes_;
};

REGISTER_OP("MonolithCreateDeepInsightClient")
    .Output("handle: resource")
    .Attr("enable_metrics_counter: bool = false")
    .Attr("is_fake: bool = false")
    .Attr("filename: string = ''")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(
    Name("MonolithCreateDeepInsightClient").Device(DEVICE_CPU),
    MonolithCreateDeepInsightClientOp);

REGISTER_OP("MonolithWriteDeepInsight")
    .Input("deep_insight_client_handle: resource")
    .Input("uids: int64")
    .Input("req_times: int64")
    .Input("labels: float")
    .Input("preds: float")
    .Input("sample_rates: float")
    .Output("msgs: string")
    .Attr("model_name: string")
    .Attr("target: string = 'ctr_head'")
    .Attr("sample_ratio: float = 0.01")
    .Attr("return_msgs: bool = false")
    .Attr("use_zero_train_time: bool = false")
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      // msgs takes the shape of `uids` (input 1).
      ctx->set_output(0, ctx->input(1));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(Name("MonolithWriteDeepInsight").Device(DEVICE_CPU),
                        MonolithWriteDeepInsightOp);

REGISTER_OP("MonolithWriteDeepInsightV2")
    .Input("deep_insight_client_handle: resource")
    .Input("req_times: int64")
    .Input("labels: float")
    .Input("preds: float")
    .Input("sample_rates: float")
    .Input("extra_fields_values: Tfields")
    .Output("msgs: string")
    .Attr("model_name: string")
    .Attr("extra_fields_keys: list(string)")
    .Attr("Tfields: list(type)")
    .Attr("targets: list(string)")
    .Attr("sample_ratio: float = 0.01")
    .Attr("return_msgs: bool = false")
    .Attr("use_zero_train_time: bool = false")
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      // msgs takes the shape of `req_times` (input 1).
      ctx->set_output(0, ctx->input(1));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(Name("MonolithWriteDeepInsightV2").Device(DEVICE_CPU),
                        MonolithWriteDeepInsightV2);

}  // namespace monolith_tf
}  //
namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" #include "monolith/native_training/runtime/common/metrics.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_factory.h" #include "monolith/native_training/runtime/hash_table/optimizer/avx_utils.h" namespace tensorflow { namespace monolith_tf { namespace { namespace hash_table = ::monolith::hash_table; using ::monolith::hash_table::EmbeddingHashTableConfig; constexpr int64_t kSecPerHour = 60 * 60; Status ValidateDim(const Tensor& t, int64 expected_dim) { if (TF_PREDICT_FALSE(t.NumElements() != expected_dim)) { return errors::InvalidArgument("The dim doesn't match expectation. 
", t.NumElements(), " vs ", expected_dim); } return Status::OK(); } } // namespace Status EmbeddingHashTableTfBridge::New( monolith::hash_table::EmbeddingHashTableConfig config, HashFilterTfBridge* hash_filter, EmbeddingHashTableTfBridge** new_bridge, const std::string& name, hash_table::GpuExtraArgs args) { auto bridge = core::RefCountPtr( new EmbeddingHashTableTfBridge(hash_filter)); bridge->config_ = config; bridge->name_ = name; bridge->dim_size_ = 0; for (const auto& segment : config.entry_config().segments()) { bridge->dim_size_ += segment.dim_size(); } try { bridge->table_ = hash_table::NewEmbeddingHashTableFromConfig(config, std::move(args)); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } bridge->max_update_ts_sec_ = std::make_unique>(0); bridge->last_evict_ts_sec_ = std::make_unique>(0); auto& bridge_ref = *bridge; bridge_ref.evict_finished_ = true; if (config.enable_feature_eviction()) { bridge_ref.evict_finished_ = false; const int evict_features_every_n_hours = config.feature_evict_every_n_hours(); auto evict_func = [&bridge_ref, evict_features_every_n_hours]() { while (!bridge_ref.evict_finished_) { const int64_t last_evict_ts_sec = bridge_ref.last_evict_ts_sec(); const int64_t current_ts_sec = std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()) .count(); if (last_evict_ts_sec == 0) { bridge_ref.set_last_evict_ts_sec(current_ts_sec); } if (last_evict_ts_sec != 0 && current_ts_sec - last_evict_ts_sec >= evict_features_every_n_hours * kSecPerHour) { LOG_EVERY_N_SEC(INFO, 60) << "embedding_hash_table_tf_bridge: started feature eviction, " "current_ts_sec : " << current_ts_sec << " last_evict_ts_sec : " << last_evict_ts_sec << " max_update_ts_sec: " << bridge_ref.max_update_ts_sec(); bridge_ref.table_->Evict(bridge_ref.max_update_ts_sec()); bridge_ref.set_last_evict_ts_sec(current_ts_sec); LOG_EVERY_N_SEC(INFO, 60) << "embedding_hash_table_tf_bridge: finished feature eviction"; } 
std::this_thread::sleep_for(std::chrono::seconds(10)); } }; bridge_ref.evict_thread_ = std::make_unique(evict_func); } *new_bridge = bridge.release(); return Status::OK(); } Status EmbeddingHashTableTfBridge::BatchLookup(OpKernelContext* ctx, const int num_ids, int64_t* ids, float* out_embedding, int64_t* hit_fid_count) const { try { std::vector> out_embeddings; out_embeddings.reserve(num_ids); for (int i = 0; i < num_ids; ++i) { out_embeddings.push_back( absl::MakeSpan(out_embedding + i * dim_size(), dim_size())); } *hit_fid_count = table_->BatchLookup(absl::MakeSpan(ids, num_ids), absl::MakeSpan(out_embeddings)); if (IsServingEntryType() && num_ids) { const std::string tagkv = absl::StrFormat("name=%s", name_); float hit_rate = *hit_fid_count / static_cast(num_ids); monolith::GetMetrics()->emit_timer("lookup_fid_hit_rate", hit_rate, tagkv); } return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::BatchLookupEntry( OpKernelContext* ctx, const int num_ids, int64_t* ids, EntryDump* out_entries) const { try { table_->BatchLookupEntry(absl::MakeSpan(ids, num_ids), absl::MakeSpan(out_entries, num_ids)); return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::Lookup(OpKernelContext* ctx, int64 id, float* out_embedding) const { try { table_->Lookup(id, absl::MakeSpan(out_embedding, dim_size())); return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::LookupEntry(OpKernelContext* ctx, int64 id, EntryDump* out_entry) const { try { table_->LookupEntry(id, absl::MakeSpan(out_entry, 1)); return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::Assign(OpKernelContext* ctx, int num_ids, const int64_t* ids, const float* embeddings, int64_t update_time) 
const { try { int64_t max_value = std::max(update_time, max_update_ts_sec_->load()); max_update_ts_sec_->store(max_value); std::vector ids_after_filter; std::vector> embeddings_after_filter; ids_after_filter.reserve(num_ids); embeddings_after_filter.reserve(num_ids); for (int i = 0; i < num_ids; ++i) { int64_t id = ids[i]; if (!table_->Contains(id) && hash_filter_->ShouldBeFiltered(id, table_.get())) { continue; } ids_after_filter.push_back(id); embeddings_after_filter.emplace_back( absl::MakeSpan(embeddings + i * dim_size(), dim_size())); } table_->Assign(absl::MakeSpan(ids_after_filter), absl::MakeSpan(embeddings_after_filter), update_time); return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::AssignAdd(OpKernelContext* ctx, int64 id, const Tensor& tensor, int64_t update_time) const { // Here max_update_ts_sec_ only need a fuzzy maximum value. // We don't need use strict locking or compare_and_change // to change this max_update_ts_sec_. And we can save some performance here. 
int64_t max_value = std::max(update_time, max_update_ts_sec_->load()); max_update_ts_sec_->store(max_value); if (!table_->Contains(id) && hash_filter_->ShouldBeFiltered(id, table_.get())) { return Status::OK(); } try { TF_RETURN_IF_ERROR(ValidateDim(tensor, dim_size_)); auto span = absl::MakeConstSpan(static_cast(tensor.data()), dim_size()); table_->AssignAdd(id, span, update_time); return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::AssignAdd2(int64 id, absl::Span value, int64_t update_time) { int64_t max_value = std::max(update_time, max_update_ts_sec_->load()); max_update_ts_sec_->store(max_value); if (hash_filter_->ShouldBeFiltered(id, table_.get())) { return Status::OK(); } try { table_->AssignAdd(id, value, update_time); return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::Reinitialize(const int64_t* ids, int64_t num_ids, int* status) { auto id_vec = absl::MakeConstSpan(ids, num_ids); try { table_->Reinitialize(id_vec, absl::MakeSpan(status, num_ids)); if (hash_set_ != nullptr) { for (int64_t id : id_vec) { hash_set_->insert(std::make_pair(id, this)); } } return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::BatchOptimize( OpKernelContext* ctx, size_t num_ids, const int64_t* ids, const float* tensor, absl::Span learning_rates, int64_t update_time, bool enable_dedup, const int64_t global_step) const { int max_value = std::max(update_time, max_update_ts_sec_->load()); max_update_ts_sec_->store(max_value); try { std::vector ids_after_filter; std::vector> grads_after_filter; // TODO(zouxuan): add theadlocal cache instead of allocating everytime. std::unique_ptr cache_grads; if (enable_dedup) { // To avoid repeated alloc, we do a conservative block allocation. 
cache_grads = std::make_unique(num_ids * dim_size()); // The first step we do a dedup, where all grads and occurences are // grouped by IDs. absl::flat_hash_map ids_to_grads; absl::flat_hash_map ids_to_counts; ids_to_grads.reserve(num_ids); ids_to_counts.reserve(num_ids); for (int i = 0; i < num_ids; ++i) { int64_t id = ids[i]; if (!ids_to_grads.count(id)) { ids_to_counts[id] = 1; const float* grad_src = tensor + i * dim_size(); float* grad_dest = cache_grads.get() + ids_to_grads.size() * dim_size(); std::memcpy(grad_dest, grad_src, dim_size() * sizeof(float)); ids_to_grads[id] = grad_dest; } else { const float* grad_src = tensor + i * dim_size(); float* grad_dest = ids_to_grads[id]; hash_table::ReduceSum(grad_dest, grad_src, grad_dest, dim_size()); ++(ids_to_counts[id]); } } // The second step is to perform a filtering, and creates the vect of IDs // and grads for update. ids_after_filter.reserve(num_ids); grads_after_filter.reserve(num_ids); for (const auto& entry : ids_to_counts) { int64_t id = entry.first; uint32_t filter_count = entry.second; if (!table_->Contains(id) && hash_filter_->ShouldBeFiltered(id, filter_count, table_.get())) { continue; } ids_after_filter.emplace_back(id); grads_after_filter.emplace_back( absl::MakeSpan(ids_to_grads[id], dim_size())); } } else { // We do simple increments (by 1) on the hash filters. ids_after_filter.reserve(num_ids); grads_after_filter.reserve(num_ids); for (int i = 0; i < num_ids; ++i) { int64_t id = ids[i]; if (!table_->Contains(id) && hash_filter_->ShouldBeFiltered(id, table_.get())) { continue; } ids_after_filter.emplace_back(id); grads_after_filter.emplace_back( absl::MakeSpan(tensor + i * dim_size(), dim_size())); } } // The final step is to perform an update based on the optimizer it uses. 
table_->BatchOptimize(absl::MakeSpan(ids_after_filter), absl::MakeSpan(grads_after_filter), learning_rates, update_time, global_step); if (hash_set_ != nullptr) { for (int64_t id : ids_after_filter) { hash_set_->insert(std::make_pair(id, this)); } } return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::Optimize( OpKernelContext* ctx, int64 id, absl::Span grads, absl::Span learning_rates, int64_t update_time, int64_t global_step) const { int max_value = std::max(update_time, max_update_ts_sec_->load()); max_update_ts_sec_->store(max_value); try { table_->Optimize(id, grads, learning_rates, update_time, global_step); if (hash_set_ != nullptr) { hash_set_->insert(std::make_pair(id, this)); } return Status::OK(); } catch (const std::exception& e) { return errors::InvalidArgument(e.what()); } } Status EmbeddingHashTableTfBridge::LockAll(std::unique_ptr* ctx) { try { *ctx = table_->LockAll(); return Status::OK(); } catch (const std::exception& e) { return errors::ResourceExhausted(e.what()); } } Status EmbeddingHashTableTfBridge::Save(OpKernelContext* ctx, DumpShard shard, WriteFn write_fn, DumpIterator* iter) const { try { table_->Save(shard, std::move(write_fn), iter); return Status::OK(); } catch (const std::exception& e) { return errors::ResourceExhausted(e.what()); } } Status EmbeddingHashTableTfBridge::Restore( OpKernelContext* ctx, DumpShard shard, std::function get_fn) const { try { int64_t update_time = table_->Restore(shard, std::move(get_fn)); // Here we make sure max value is updated correctly when there are // multiple threads to update this value simultaneously. // There is no overhead since this operation is called once for each shard. 
while (true) { int64_t old_value = max_update_ts_sec_->load(); int64_t new_value = std::max(old_value, update_time); bool ret = max_update_ts_sec_->compare_exchange_weak(old_value, new_value); if (ret == true) { break; } } return Status::OK(); } catch (const std::exception& e) { return errors::ResourceExhausted(e.what()); } } std::string EmbeddingHashTableTfBridge::DebugString() const { return config_.DebugString(); } std::string EmbeddingHashTableTfBridge::Summary() const { return table_->DebugString(); } int32 EmbeddingHashTableTfBridge::dim_size() const { return dim_size_; } int32 EmbeddingHashTableTfBridge::slice_size() const { return table_->SliceSize(); } int64 EmbeddingHashTableTfBridge::max_update_ts_sec() const { return max_update_ts_sec_->load(); } int64 EmbeddingHashTableTfBridge::last_evict_ts_sec() const { return last_evict_ts_sec_->load(); } void EmbeddingHashTableTfBridge::set_last_evict_ts_sec( const int64_t last_evict_ts_sec) { *last_evict_ts_sec_ = last_evict_ts_sec; } bool EmbeddingHashTableTfBridge::IsServingEntryType() const { return config_.entry_config().entry_type() == hash_table::EntryConfig_EntryType_SERVING; } std::vector> EmbeddingHashTableTfBridge::TouchedKeySet() const { if (hash_set_) { return hash_set_->GetAndClear(); } return {}; } void EmbeddingHashTableTfBridge::SetHopscotchHashSet( HopscotchHashSet>* hash_set) { CHECK(hash_set_ == nullptr); hash_set_ = hash_set; } const EmbeddingHashTableConfig& EmbeddingHashTableTfBridge::GetConfig() const { return config_; } EmbeddingHashTableTfBridge::~EmbeddingHashTableTfBridge() { // Let the eviction thread stop evict_finished_ = true; if (evict_thread_) { evict_thread_->join(); } hash_set_ = nullptr; } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_EMBEDDING_HASH_TABLE_TF_BRIDGE_H_
#define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_EMBEDDING_HASH_TABLE_TF_BRIDGE_H_

#include
#include
#include
#include

#include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h"
#include "monolith/native_training/runtime/hash_table/embedding_hash_table_interface.h"
#include "monolith/native_training/runtime/hopscotch/hopscotch_hash_set.h"
#include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace monolith_tf {

// Shorthand for the hopscotch hash set used to record touched keys.
template
using HopscotchHashSet = monolith::hopscotch::HopscotchHashSet;

// A hash table which can be used in TF runtime.
// It captures all potential exceptions and convert them into error.
class EmbeddingHashTableTfBridge : public ResourceBase {
 public:
  using EntryDump = monolith::hash_table::EntryDump;

  // Factory. The constructor is private, so this is the only way to obtain a
  // bridge; the result is handed back through |new_bridge|.
  static Status New(monolith::hash_table::EmbeddingHashTableConfig config,
                    HashFilterTfBridge* hash_filter,
                    EmbeddingHashTableTfBridge** new_bridge,
                    const std::string& name,
                    monolith::hash_table::GpuExtraArgs args = {});

  // Flags the eviction thread to stop, joins it when one exists, and drops
  // the hash_set_ pointer (the set itself is not deleted here).
  ~EmbeddingHashTableTfBridge();

  // BatchLookup |ids| and write it into |embeddings|
  Status BatchLookup(OpKernelContext* ctx, const int num_ids, int64_t* ids,
                     float* out_embedding, int64_t* hit_fid_count) const;

  Status BatchLookupEntry(OpKernelContext* ctx, const int num_ids,
                          int64_t* ids, EntryDump* out_entries) const;

  // Lookup |id| and write it into |embedding|
  Status Lookup(OpKernelContext* ctx, int64 id, float* out_embedding) const;

  Status LookupEntry(OpKernelContext* ctx, int64 id,
                     EntryDump* out_entry) const;

  // TODO(leqi.zou): Unify the API here.
  // 1. Remove all batch APIs.
  // 2. Replace int64_t to int64
  Status Assign(OpKernelContext* ctx, int num_ids, const int64_t* ids,
                const float* embeddings, int64_t update_time) const;

  // TODO(leqi.zou): Replace this API by AssignAdd2.
  Status AssignAdd(OpKernelContext* ctx, int64 id, const Tensor& tensor,
                   int64_t update_time) const;

  Status AssignAdd2(int64 id, absl::Span value, int64_t update_time);

  Status Reinitialize(const int64_t* ids, int64_t num_ids, int* status);

  Status BatchOptimize(OpKernelContext* ctx, size_t num_ids,
                       const int64_t* ids, const float* tensor,
                       absl::Span learning_rates, int64_t update_time,
                       bool enable_dedup, const int64_t global_step) const;

  Status Optimize(OpKernelContext* ctx, int64 id, absl::Span grads,
                  absl::Span learning_rates, int64_t update_time,
                  int64_t global_step) const;

  // Re-exported helper types of the underlying table interface.
  using DumpShard =
      monolith::hash_table::EmbeddingHashTableInterface::DumpShard;
  using DumpIterator =
      monolith::hash_table::EmbeddingHashTableInterface::DumpIterator;
  using WriteFn = monolith::hash_table::EmbeddingHashTableInterface::WriteFn;
  using LockCtx = monolith::hash_table::EmbeddingHashTableInterface::LockCtx;

  // For the functor injected, it is ok to throw exceptions.
  Status LockAll(std::unique_ptr* ctx);

  Status Save(OpKernelContext* ctx, DumpShard shard, WriteFn write_fn,
              DumpIterator* iter) const;

  Status Restore(OpKernelContext* ctx, DumpShard shard,
                 std::function get_fn) const;

  // Thin forwards to the wrapped table.
  void Clear() const { table_->Clear(); }

  int64_t Size() const { return table_->Size(); }

  // Returns dim_size_.
  int32 dim_size() const;
  // Forwards to table_->SliceSize().
  int32 slice_size() const;
  // Atomic loads of the respective timestamps.
  int64 max_update_ts_sec() const;
  int64 last_evict_ts_sec() const;
  // Stores |last_evict_ts_sec| into the atomic behind last_evict_ts_sec_.
  void set_last_evict_ts_sec(const int64_t last_evict_ts_sec);
  // True when config_.entry_config().entry_type() is the SERVING entry type.
  bool IsServingEntryType() const;

  // Returns config_.DebugString().
  std::string DebugString() const override;
  // Returns table_->DebugString().
  std::string Summary() const;

  // Attaches the touched-key set; CHECK-fails if one is already attached.
  void SetHopscotchHashSet(
      HopscotchHashSet>* hash_set);
  HopscotchHashSet>* GetHashSet() const { return hash_set_; }
  // Returns hash_set_->GetAndClear() when a set is attached, otherwise an
  // empty vector.
  std::vector> TouchedKeySet() const;

  const monolith::hash_table::EmbeddingHashTableConfig& GetConfig() const;

  // Non-owning access to the wrapped table.
  monolith::hash_table::EmbeddingHashTableInterface* GetTable() const {
    return table_.get();
  }

 private:
  explicit EmbeddingHashTableTfBridge(HashFilterTfBridge* hash_filter)
      : hash_filter_(hash_filter) {}

  std::string name_;
  std::unique_ptr table_;
  monolith::hash_table::EmbeddingHashTableConfig config_;
  int64 dim_size_ = 0;
  std::unique_ptr> max_update_ts_sec_;
  HashFilterTfBridge* hash_filter_;
  std::unique_ptr evict_thread_;
  std::unique_ptr> last_evict_ts_sec_;
  mutex evict_mu_;
  // NOTE(review): no in-class initializer, and the destructor in the .cc
  // assigns this flag without visibly taking evict_mu_ -- confirm that New()
  // initializes it and that the unguarded write is intended.
  bool evict_finished_ TF_GUARDED_BY(evict_mu_);
  // Not owned; null until SetHopscotchHashSet() is called.
  HopscotchHashSet>* hash_set_ = nullptr;
};

}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_EMBEDDING_HASH_TABLE_TF_BRIDGE_H_

================================================
FILE: monolith/native_training/runtime/ops/file_metric_writer.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/ops/file_metric_writer.h"

#include "absl/strings/str_format.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/path.h"

namespace monolith {
namespace deep_insight {

using tensorflow::Env;
using tensorflow::Status;
using tensorflow::io::RecordWriter;
using tensorflow::io::RecordWriterOptions;

// Opens |filename| for writing and starts a single background task that
// drains the queue into the file as TFRecord-framed records. An empty
// |filename| leaves the writer inert: no file, queue or thread pool is
// created and Write() becomes a no-op.
// Directory creation and file opening are wrapped in TF_CHECK_OK, so a
// failure there aborts the process rather than being reported to the caller.
FileMetricWriter::FileMetricWriter(std::string filename)
    : filename_(std::move(filename)),
      finished_(false),
      total_produce_(0),
      total_consume_(0),
      total_dump_(0) {
  LOG(INFO) << "deepinsight dump filename: " << filename_;
  if (!filename_.empty()) {
    Env* env = Env::Default();
    std::string dirname(tensorflow::io::Dirname(filename_));
    TF_CHECK_OK(env->RecursivelyCreateDir(dirname));
    TF_CHECK_OK(env->NewWritableFile(filename_, &fp_));
    // Queue constructed with 8192 (presumably its capacity -- confirm against
    // queue.h); the pool is constructed with 1 (presumably one worker).
    queue_ = std::make_unique>(8192);
    thread_pool_ = std::make_unique(1);
    thread_pool_->Schedule([this]() {
      RecordWriterOptions options;
      RecordWriter writer(fp_.get(), options);
      // Keep draining until shutdown was requested AND the queue is empty, so
      // records queued before destruction are still written.
      while (!finished_ || !queue_->empty()) {
        std::string msg;
        // Timed pop so the loop re-checks finished_ at least every 10 ms.
        bool ok = queue_->try_pop(msg, std::chrono::milliseconds(10));
        if (ok) {
          ++total_consume_;
          LOG_EVERY_N_SEC(INFO, 300)
              << absl::StrFormat("Consume %ld records", total_consume_);
          Status s = writer.WriteRecord(msg);
          if (s.ok()) {
            ++total_dump_;
            LOG_EVERY_N_SEC(INFO, 300)
                << absl::StrFormat("Dump %ld records", total_dump_);
          } else {
            // A failed record is logged and dropped; it is not retried.
            LOG(ERROR) << absl::StrFormat(
                "Failed to write record: %s, status=%s", msg,
                s.error_message());
          }
        } else {
          LOG_EVERY_N_SEC(INFO, 300) << "Failed to try pop, queue maybe empty!";
        }
      }
      // NOTE(review): total_produce_ is a plain int64_t in the header; it is
      // incremented by Write() callers and read here on the worker -- confirm
      // that callers never race with each other or with shutdown.
      LOG(INFO) << absl::StrFormat("Totally produce %ld records",
                                   total_produce_);
      LOG(INFO) << absl::StrFormat("Totally consume %ld records",
                                   total_consume_);
      LOG(INFO) << absl::StrFormat("Totally dump %ld records", total_dump_);
      TF_CHECK_OK(writer.Close());
      TF_CHECK_OK(fp_->Close());
    });
  }
}

// Only raises the shutdown flag. Presumably the thread pool member's
// destructor then waits for the drain task to finish -- confirm against
// thread_pool.h.
FileMetricWriter::~FileMetricWriter() { finished_ = true; }

// Enqueues |msg| for the background task. Retries the 10 ms timed push in a
// loop until it succeeds, so the caller is held up for as long as the queue
// keeps rejecting the push. Does nothing when no file was opened.
void FileMetricWriter::Write(const std::string& msg) {
  if (fp_) {
    while (true) {
      bool ok = queue_->try_push(msg, std::chrono::milliseconds(10));
      if (ok) {
        ++total_produce_;
        LOG_EVERY_N_SEC(INFO, 300)
            << absl::StrFormat("Produce %ld records", total_produce_);
        break;
      } else {
        LOG_EVERY_N_SEC(INFO, 60) << "Failed to try push, queue maybe full!";
      }
    }
  }
}

}  // namespace deep_insight
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/ops/file_metric_writer.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_DEEP_INSIGHT_FILE_METRIC_WRITER_H_
#define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_DEEP_INSIGHT_FILE_METRIC_WRITER_H_

#include
#include
#include

#include "tensorflow/core/platform/file_system.h"

#include "monolith/native_training/runtime/concurrency/queue.h"
#include "monolith/native_training/runtime/concurrency/thread_pool.h"

namespace monolith {
namespace deep_insight {

// Writes string records to a file through a queue that is drained by a
// background task, so Write() callers do not perform file I/O themselves.
class FileMetricWriter {
 public:
  // Opens |filename| and starts the drain task; an empty |filename| yields a
  // writer whose Write() does nothing.
  explicit FileMetricWriter(std::string filename);

  // Signals the drain task to finish.
  ~FileMetricWriter();

  // Queues |msg| for writing; retries until the queue accepts it.
  void Write(const std::string& msg);

 private:
  std::string filename_;
  // Set by the destructor; polled by the drain task.
  std::atomic_bool finished_;
  // Counters used for progress logging: records queued, records popped, and
  // records written successfully.
  int64_t total_produce_;
  int64_t total_consume_;
  int64_t total_dump_;
  // Output file; null when the writer was built with an empty filename.
  std::unique_ptr fp_;
  std::unique_ptr> queue_;
  // Declared last, so it is destroyed before queue_ and fp_.
  std::unique_ptr thread_pool_;
};

}  // namespace deep_insight
}  // namespace monolith

#endif  // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_DEEP_INSIGHT_FILE_METRIC_WRITER_H_

================================================
FILE: monolith/native_training/runtime/ops/file_metric_writer_test.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/ops/file_metric_writer.h"

#include
#include

#include "glog/logging.h"
#include "gtest/gtest.h"
#include "tensorflow/core/lib/io/record_writer.h"

namespace monolith {
namespace deep_insight {

using tensorflow::Env;
using tensorflow::Status;
using tensorflow::io::RecordWriter;
using tensorflow::io::RecordWriterOptions;

// Smoke test: builds a writer on a temporary file name and queues two
// records. It makes no assertion on the file contents; it only checks that
// construction, Write() and destruction complete.
TEST(FileMetricWriterTest, Basic) {
  Env* env = Env::Default();
  std::string filename;
  CHECK(env->LocalTempFilename(&filename));
  FileMetricWriter writer(filename);
  writer.Write("hello");
  writer.Write("world");
}

}  // namespace deep_insight
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/ops/file_ops.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/path.h"

#include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

// It is a thin wrapper of GFile. Make it compatible with ResourceKernelOp
// and thread safe.
class FileResource : public ResourceBase { public: explicit FileResource(std::unique_ptr f, absl::string_view debugging_info) : f_(std::move(f)), debugging_info_(debugging_info), closed_(false) {} std::string DebugString() const override { return debugging_info_; } Status Close() { absl::MutexLock l(&mu_); closed_ = true; return f_->Close(); } Status Append(StringPiece data) { absl::MutexLock l(&mu_); return f_->Append(data); } Status AppendRecord(const string& serialized) { absl::MutexLock l(&mu_); if (!record_writer_) { record_writer_.reset(new tensorflow::io::RecordWriter( f_.get(), tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions(""))); } return record_writer_->WriteRecord(serialized); } ~FileResource() { absl::MutexLock l(&mu_); if (!closed_) { auto s = f_->Close(); if (!s.ok()) { LOG(ERROR) << "Unable to close file " << debugging_info_ << " :" << s.ToString(); } } } private: absl::Mutex mu_; std::unique_ptr f_ ABSL_GUARDED_BY(mu_); std::unique_ptr record_writer_ ABSL_GUARDED_BY(mu_); const std::string debugging_info_; bool closed_ ABSL_GUARDED_BY(mu_); }; } // namespace class MonolithWritableFileOp : public ResourceOpKernel { public: explicit MonolithWritableFileOp(OpKernelConstruction* c) : ResourceOpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("filename", &filename_)); env_ = c->env(); } ~MonolithWritableFileOp() override {} private: Status CreateResource(FileResource** file_wrapper) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { const std::string dir = std::string(io::Dirname(filename_)); if (!env_->FileExists(dir).ok()) { TF_RETURN_IF_ERROR(env_->RecursivelyCreateDir(dir)); } std::unique_ptr f; TF_RETURN_IF_ERROR(env_->NewWritableFile(filename_, &f)); *file_wrapper = new FileResource(std::move(f), filename_); return Status::OK(); } std::string filename_; Env* env_; }; REGISTER_OP("MonolithWritableFile") .Output("handle: resource") .Attr("filename: string") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() 
.SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithWritableFile").Device(DEVICE_CPU), MonolithWritableFileOp); class MonolithWritableFileCloseOp : public OpKernel { public: explicit MonolithWritableFileCloseOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { FileResource* f; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &f)); core::ScopedUnref unref(f); OP_REQUIRES_OK(c, f->Close()); } }; REGISTER_OP("MonolithWritableFileClose") .Input("handle: resource") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape); REGISTER_KERNEL_BUILDER(Name("MonolithWritableFileClose").Device(DEVICE_CPU), MonolithWritableFileCloseOp); class MonolithWritableFileAppendOp : public OpKernel { public: explicit MonolithWritableFileAppendOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { FileResource* f; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &f)); core::ScopedUnref unref(f); const auto& content = c->input(1).scalar()(); OP_REQUIRES_OK(c, f->Append(content)); } }; REGISTER_OP("MonolithWritableFileAppend") .Input("handle: resource") .Input("content: string") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape); REGISTER_KERNEL_BUILDER(Name("MonolithWritableFileAppend").Device(DEVICE_CPU), MonolithWritableFileAppendOp); class MonolithEntryDumpFileAppendOp : public OpKernel { public: explicit MonolithEntryDumpFileAppendOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { FileResource* f; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &f)); core::ScopedUnref unref(f); const auto& item_id = c->input(1).flat(); const auto& bias = c->input(2).flat(); const auto& embedding = c->input(3).flat(); size_t batch_size = item_id.size(); CHECK_GT(batch_size, 0); CHECK_EQ(embedding.size() % batch_size, 0); size_t embedding_len = embedding.size() / batch_size; for (size_t batch_id = 0; batch_id < batch_size; 
batch_id++) { monolith::hash_table::EntryDump d; d.set_id(item_id(batch_id)); d.add_num(bias(batch_id)); for (size_t i = 0; i < embedding_len; i++) { d.add_num(embedding(batch_id * embedding_len + i)); } OP_REQUIRES_OK(c, f->AppendRecord(d.SerializeAsString())); } } }; REGISTER_OP("MonolithEntryDumpFileAppend") .Input("handle: resource") .Input("item_id: int64") .Input("bias: float") .Input("embedding: float") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape); REGISTER_KERNEL_BUILDER(Name("MonolithEntryDumpFileAppend").Device(DEVICE_CPU), MonolithEntryDumpFileAppendOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/file_utils.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/ops/file_utils.h"

#include "absl/strings/str_format.h"
#include "re2/re2.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

// printf-style pattern of a sharded file name: "<basename>-SSSSS-of-NNNNN".
const char* const kShardedFileFormat = "%s-%05d-of-%05d";

}  // namespace

// Builds the name of shard |shard| out of |nshards| for |basename|, with both
// numbers zero-padded to five digits.
std::string GetShardedFileName(absl::string_view basename, int shard,
                               int nshards) {
  return absl::StrFormat(kShardedFileFormat, basename, shard, nshards);
}

// Checks that |filenames| contains every shard of |basename| exactly as a
// complete "-SSSSS-of-NNNNN" set.
//
// Every file name must start with |basename|. Names whose remainder is not of
// the exact sharded form (for example temporary files carrying an extra
// suffix) are ignored. The first valid name fixes the shard count; every
// other valid name must agree with it, and each shard index in
// [0, nshards) must show up at least once.
//
// On success, and when |spec| is non-null, |*spec| is set to the matching
// sharded FileSpec. Any violation is reported as InvalidArgument.
Status ValidateShardedFiles(absl::string_view basename,
                            absl::Span<const std::string> filenames,
                            FileSpec* spec) {
  // show[i] becomes true once shard i has been seen. It is sized by the first
  // valid file name.
  std::vector<bool> show;
  for (absl::string_view filename : filenames) {
    if (filename.substr(0, basename.size()) != basename) {
      return errors::InvalidArgument("Filename ", filename,
                                     " doesn't belong to ", basename);
    }
    const absl::string_view suffix = filename.substr(basename.size());
    int shard = 0;
    int nshards = 0;
    // Ignore invalid files.
    if (!RE2::FullMatch(suffix, R"raw(-(\d{5})?-of-(\d{5})?)raw", &shard,
                        &nshards)) {
      continue;
    }
    if (show.empty()) {
      show.resize(nshards);
    }
    if (nshards != static_cast<int>(show.size())) {
      return errors::InvalidArgument("Filename ", filename,
                                     " doesn't match nshards. ", show.size());
    }
    if (shard >= nshards) {
      // Also reached for a "-of-00000" name, where show stays empty.
      return errors::InvalidArgument("Shard ", shard, " exceeds ", nshards,
                                     " for ", filename);
    }
    show[shard] = true;
  }
  if (show.empty()) {
    return errors::InvalidArgument("There is no valid sharded files for ",
                                   basename);
  }
  const int total_shards = static_cast<int>(show.size());
  for (int i = 0; i < total_shards; ++i) {
    if (!show[i]) {
      return errors::InvalidArgument("Shard ", i, " doesn't show up for ",
                                     basename);
    }
  }
  if (spec != nullptr) {
    *spec = FileSpec::ShardedFileSpec(basename, total_shards);
  }
  return Status::OK();
}

// Returns a spec describing |nshards| sharded files that share |prefix|.
FileSpec FileSpec::ShardedFileSpec(absl::string_view prefix, int nshards) {
  FileSpec spec;
  spec.type_ = FileSpec::SHARDED_FILES;
  spec.prefix_ = std::string(prefix);
  spec.nshards_ = nshards;
  return spec;
}

// Expands the spec into concrete file names, in shard order. A spec of
// unknown type expands to an empty list.
std::vector<std::string> FileSpec::GetFilenames() const {
  std::vector<std::string> filenames;
  switch (type_) {
    case FileSpec::SHARDED_FILES:
      if (nshards_ > 0) {
        filenames.reserve(nshards_);
      }
      for (int i = 0; i < nshards_; ++i) {
        filenames.push_back(GetShardedFileName(prefix_, i, nshards_));
      }
      break;
    default:
      break;
  }
  return filenames;
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/file_utils.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_FILE_UTILS #define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_FILE_UTILS #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { namespace monolith_tf { // Returns sharded file name. std::string GetShardedFileName(absl::string_view basename, int shard, int nshards); // A spec reprsents a set of files. class FileSpec final { public: FileSpec() {} static FileSpec ShardedFileSpec(absl::string_view prefix, int nshards); std::vector GetFilenames() const; int nshards() const { return nshards_; } private: enum Type { UNKNOWN, SHARDED_FILES }; Type type_ = UNKNOWN; std::string prefix_; int nshards_ = 0; }; // Validates if filenames construct a valid file spec for base name. Status ValidateShardedFiles(absl::string_view basename, absl::Span filenames, FileSpec* spec = nullptr); } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_FILE_UTILS ================================================ FILE: monolith/native_training/runtime/ops/file_utils_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/ops/file_utils.h"

#include

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

using ::testing::ElementsAre;

// Covers the accept and reject paths of ValidateShardedFiles.
TEST(ValidateShardedFilesTest, Basic) {
  FileSpec spec;
  // Complete shard sets are accepted; the spec reports the shard count.
  TF_EXPECT_OK(ValidateShardedFiles("a/b", {"a/b-00000-of-00001"}));
  TF_EXPECT_OK(ValidateShardedFiles(
      "a/b", {"a/b-00000-of-00002", "a/b-00001-of-00002"}, &spec));
  EXPECT_THAT(spec.nshards(), 2);
  // Names with a trailing "-tmp-..." suffix are ignored, even when their
  // shard count disagrees with the valid files.
  TF_EXPECT_OK(ValidateShardedFiles(
      "a", {"a-00000-of-00001", "a-00000-of-00001-tmp-1234"}));
  TF_EXPECT_OK(ValidateShardedFiles(
      "a", {"a-00000-of-00001", "a-00000-of-00002-tmp-1234"}));
  // A missing shard is rejected.
  EXPECT_FALSE(ValidateShardedFiles("a/b", {"a/b-00000-of-00002"}).ok());
  // A shard index that is not below the shard count is rejected.
  EXPECT_FALSE(
      ValidateShardedFiles("a/b", {"a/b-00000-of-00001", "a/b-00001-of-00001"})
          .ok());
  // Conflicting shard counts are rejected.
  EXPECT_FALSE(
      ValidateShardedFiles("a/b", {"a/b-00000-of-00001", "a/b-00000-of-00002",
                                   "a/b-00001-of-00002"})
          .ok());
  // A name outside the basename, or no valid sharded name at all, is rejected.
  EXPECT_FALSE(ValidateShardedFiles("a/b", {"random-string"}).ok());
  EXPECT_FALSE(ValidateShardedFiles("a/b", {"a/b-random-string"}).ok());
}

// A sharded spec expands to its file names in shard order.
TEST(ValidateShardedFilesTest, FileSpecTest) {
  auto spec = FileSpec::ShardedFileSpec("a/b", 2);
  EXPECT_THAT(spec.GetFilenames(),
              ElementsAre("a/b-00000-of-00002", "a/b-00001-of-00002"));
}

// Names generated by GetShardedFileName for 100 shards validate as a set.
TEST(ValidateShardedFilesTest, LargeFileSet) {
  std::vector filenames;
  for (int i = 0; i < 100; ++i) {
    filenames.push_back(GetShardedFileName("/a", i, 100));
  }
  TF_EXPECT_OK(ValidateShardedFiles("/a", filenames));
}

}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/fused_embedding_to_layout.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/ops/fused_embedding_to_layout.h" namespace tensorflow { namespace monolith_tf { namespace fused_layout { void *MemCopy(float *dest, const float *src, std::size_t count) { return std::memcpy(dest, src, count * sizeof(float)); } template void OptimizedSumpooling(const float *src, const int dim_num, void *init_ptr, float *dst, void *one_mutex_ptr = nullptr, int mean_pool_fid_num = 0) { std::mutex *one_mutex = static_cast(one_mutex_ptr); TInit *init = static_cast(init_ptr); if (one_mutex) { one_mutex->lock(); } if (init && *init) { if (mean_pool_fid_num) { for (size_t i = 0; i < dim_num; ++i) { dst[i] = (src[i] / mean_pool_fid_num); } } else { MemCopy(dst, src, dim_num); } *init = false; } else { // ::monolith::hash_table::ReduceSum(src, dst, dst, dim_num); if (mean_pool_fid_num) { for (size_t i = 0; i < dim_num; ++i) { dst[i] += (src[i] / mean_pool_fid_num); } } else { for (size_t i = 0; i < dim_num; ++i) { dst[i] += src[i]; } } } if (one_mutex) { one_mutex->unlock(); } } NoneLayout::NoneLayout(const std::string &name, const OutConfig &out_conf, OpInputList &tensor_list, int &start_idx) : Layout(name, out_conf) { int offset = 0; CHECK(out_conf.slice_configs_size() == out_conf.shape_size()); for (const SliceConfig &slice_conf : out_conf.slice_configs()) { slice_to_tensor_.insert( {GetKey(slice_conf), {&tensor_list[start_idx++], offset++}}); } } // op output NoneLayout::NoneLayout(const std::string &name, const OutConfig &out_conf, OpOutputList &tensor_list, int &start_idx) : Layout(name, out_conf) { int offset = 0; 
CHECK(out_conf.slice_configs_size() == out_conf.shape_size()); for (const SliceConfig &slice_conf : out_conf.slice_configs()) { slice_to_tensor_.insert( {GetKey(slice_conf), {tensor_list[start_idx++], offset++}}); } } PtrWrapper NoneLayout::GetSlice(int row_id, const SliceConfig &slice_conf) { auto key = GetKey(slice_conf); auto it = slice_to_tensor_.find(key); if (it != slice_to_tensor_.end()) { auto &layout_info = it->second; const LayoutShape &shape = out_config_.shape(layout_info.second); if (slice_conf.pooling_type() == PoolingType::FIRSTN) { CHECK_EQ(shape.dims_size(), 3); // none seq [batch_size, max_seq_len, num_dim] const auto tensor = layout_info.first->tensor(); return PtrWrapper{&tensor(row_id, 0, 0), shape.dims(1) * shape.dims(2), layout_info.first->NumElements()}; } else { CHECK_EQ(shape.dims_size(), 2); // none [batch_size, num_dim] const auto mat = layout_info.first->matrix(); return PtrWrapper{&mat(row_id, 0), shape.dims(1), layout_info.first->NumElements()}; } } } DefaultLayout::DefaultLayout(const std::string &name, const OutConfig &out_conf, OpInputList &tensor_list, int &start_idx) : Layout(name, out_conf) { int offset = 0; CHECK_EQ(out_conf.shape_size(), 1); CHECK_NE(out_conf.out_type(), OutType::NONE); for (const SliceConfig &slice_conf : out_conf.slice_configs()) { slice_to_tensor_.insert( {GetKey(slice_conf), {&tensor_list[start_idx], offset}}); if (out_conf.out_type() == OutType::STACK) { offset += 1; } else if (out_conf.out_type() == OutType::CONCAT) { offset += slice_conf.end() - slice_conf.start(); } else { CHECK(out_conf.out_type() == OutType::ADDN); } } start_idx++; } DefaultLayout::DefaultLayout(const std::string &name, const OutConfig &out_conf, OpOutputList &tensor_list, int &start_idx) : Layout(name, out_conf) { int offset = 0; CHECK_EQ(out_conf.shape_size(), 1); CHECK_NE(out_conf.out_type(), OutType::NONE); for (const SliceConfig &slice_conf : out_conf.slice_configs()) { slice_to_tensor_.insert( {GetKey(slice_conf), 
{tensor_list[start_idx], offset}}); if (out_conf.out_type() == OutType::STACK) { offset += 1; } else if (out_conf.out_type() == OutType::CONCAT) { offset += slice_conf.end() - slice_conf.start(); } else { CHECK(out_conf.out_type() == OutType::ADDN); } } start_idx++; } PtrWrapper DefaultLayout::GetSlice(int row_id, const SliceConfig &slice_conf) { auto key = GetKey(slice_conf); auto it = slice_to_tensor_.find(key); if (it != slice_to_tensor_.end()) { auto &layout_info = it->second; CHECK_EQ(out_config_.shape_size(), 1); const LayoutShape &shape = out_config_.shape(0); // TODO(zhangru): support concat/stack seq if (slice_conf.pooling_type() == PoolingType::FIRSTN) { CHECK(shape.dims_size() > 2 && shape.dims_size() < 5); if (shape.dims_size() == 3) { // concat [batch_size, max_seq_len, num_dims]; // add_n [batch_size, , num_dim]; const auto tensor = layout_info.first->tensor(); return PtrWrapper{&tensor(row_id, 0, layout_info.second), shape.dims(1) * shape.dims(2), layout_info.first->NumElements()}; } else { // if (shape.dims_size() == 4) { // stack [batch_size, features_size, max_seq_len , num_dim]; const auto tensor = layout_info.first->tensor(); return PtrWrapper{&tensor(row_id, 0, 0, layout_info.second), shape.dims(1) * shape.dims(2) * shape.dims(3), layout_info.first->NumElements()}; } } else { CHECK(shape.dims_size() > 1 && shape.dims_size() < 4); if (shape.dims_size() == 2) { // concat [batch_size, num_dims]; // add_n [batch_size, num_dim]; const auto mat = layout_info.first->matrix(); return PtrWrapper{&mat(row_id, layout_info.second), shape.dims(1), layout_info.first->NumElements()}; } else { // if (shape.dims_size() == 3) { // stack [batch_size, features_size , num_dim]; const auto tensor = layout_info.first->tensor(); return PtrWrapper{&tensor(row_id, layout_info.second, 0), shape.dims(1) * shape.dims(2), layout_info.first->NumElements()}; } } } } MonolithEmbeddingToLayoutBase::MonolithEmbeddingToLayoutBase( OpKernelConstruction *ctx, int version) : 
OpKernel(ctx), version_(version) { std::string serialized; OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_cfgs", &serialized)); OP_REQUIRES( ctx, feature_cfgs_.ParseFromArray(serialized.data(), serialized.size()), errors::FailedPrecondition("Failed to parse the feature_cfgs_.")); OP_REQUIRES_OK(ctx, ctx->GetAttr("variant_type", &variant_type_)); if (version_ >= 2) { OP_REQUIRES_OK(ctx, ctx->GetAttr("ps_num", &ps_num_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("parallel_flag", ¶llel_flag_)); } // set max_sequence_length/pooling_type/slice_idx/feature_idx here: // use the index in the sorted feature_names_used as feature_idx. const auto &feature_names_used = feature_cfgs_.feature_configs(); std::vector feature_names; std::map> table_feature_dim_map; for (const auto &feature_conf_pair : feature_names_used) { feature_names.push_back(feature_conf_pair.first); int dims_sum = 0; for (size_t slice_idx = 0; slice_idx < feature_conf_pair.second.slice_dims_size(); slice_idx++) { dims_sum += feature_conf_pair.second.slice_dims(slice_idx); } table_feature_dim_map[feature_conf_pair.second.table()] [feature_conf_pair.first] = dims_sum; } std::sort(feature_names.begin(), feature_names.end()); { table_feature_dim_.resize(table_feature_dim_map.size()); int i = 0; // table_feature_dim_map is map, already sort for (auto &iter : table_feature_dim_map) { auto &table_name = iter.first; auto &record_dims = table_feature_dim_[i]; std::vector feature_name_tmp; auto &feature_dim_map = iter.second; for (auto &sub_iter : feature_dim_map) { feature_name_tmp.push_back(sub_iter.first); } std::sort(feature_name_tmp.begin(), feature_name_tmp.end()); record_dims.resize(feature_name_tmp.size()); for (size_t j = 0; j < feature_name_tmp.size(); ++j) { record_dims[j] = feature_dim_map[feature_name_tmp[j]]; } ++i; } } std::vector> slice_idx_per_feature; // feature_index: {start: slice_idx} slice_idx_per_feature.resize(feature_names_used.size()); for (size_t feature_idx = 0; feature_idx < feature_names.size(); 
feature_idx++) { std::unordered_map start2slice_idx; const auto &feature_name = feature_names[feature_idx]; const auto &feat_conf = feature_names_used.at(feature_name); int slice_prefix_sum_ = 0; for (size_t slice_idx = 0; slice_idx < feat_conf.slice_dims_size(); slice_idx++) { start2slice_idx[slice_prefix_sum_] = slice_idx; slice_prefix_sum_ += feat_conf.slice_dims(slice_idx); } max_slice_num_ = std::max(max_slice_num_, feat_conf.slice_dims_size()); slice_idx_per_feature[feature_idx] = start2slice_idx; } auto *out_configs = feature_cfgs_.mutable_out_configs(); for (auto &pair : *out_configs) { layout_names_.push_back(pair.first); for (auto &slice_config : *pair.second.mutable_slice_configs()) { const auto &feature_name = slice_config.feature_name(); const auto &feat_conf = feature_names_used.at(feature_name); slice_config.set_max_sequence_length(feat_conf.max_sequence_length()); slice_config.set_pooling_type(feat_conf.pooling_type()); CHECK(!(pair.second.out_type() == OutType::ADDN && slice_config.pooling_type() == PoolingType::FIRSTN)); auto it = std::find(feature_names.begin(), feature_names.end(), feature_name); if (it != feature_names.end()) { int feature_idx = it - feature_names.begin(); slice_config.set_feature_idx(feature_idx); slice_config.set_slice_idx( slice_idx_per_feature[feature_idx][slice_config.start()]); } } } std::sort(layout_names_.begin(), layout_names_.end()); } MonolithEmbeddingToLayoutOp::MonolithEmbeddingToLayoutOp( OpKernelConstruction *ctx, int version /* = 1*/) : MonolithEmbeddingToLayoutBase(ctx, version) {} void MonolithEmbeddingToLayoutOp::Compute(OpKernelContext *ctx) { // Grab the input tensor OpInputList embeddings_list; OP_REQUIRES_OK(ctx, ctx->input_list("embeddings_list", &embeddings_list)); const Tensor *fids_offset_input; OP_REQUIRES_OK(ctx, ctx->input("fid_offset", &fids_offset_input)); const Tensor *feature_offset_input; OP_REQUIRES_OK(ctx, ctx->input("feature_offset", &feature_offset_input)); const Tensor *nfl_offset_input; 
OP_REQUIRES_OK(ctx, ctx->input("nfl_offset", &nfl_offset_input)); const Tensor *batch_size_tensor; OP_REQUIRES_OK(ctx, ctx->input("batch_size", &batch_size_tensor)); const Tensor *nfl_size_tensor; const Tensor *feature_size_tensor; const Tensor *fid_size_tensor; const Tensor *emb_size_tensor; if (GetVersion() == 5) { OP_REQUIRES_OK(ctx, ctx->input("nfl_size", &nfl_size_tensor)); OP_REQUIRES_OK(ctx, ctx->input("feature_size", &feature_size_tensor)); OP_REQUIRES_OK(ctx, ctx->input("fid_size", &fid_size_tensor)); OP_REQUIRES_OK(ctx, ctx->input("emb_size", &emb_size_tensor)); } const auto fids_offset_vec = fids_offset_input->flat(); int total_fid_num = fids_offset_input->dim_size(0); const auto feature_offset_vec = feature_offset_input->flat(); int total_feature_num = feature_offset_input->dim_size(0); const auto nfl_offset_vec = nfl_offset_input->flat(); int total_nfl_num = nfl_offset_input->dim_size(0); int req_num = 1; int32 max_batch_size = 0; std::vector each_req_batch_size_offset(1, 0); std::vector each_req_nfl_offset(1, 0); std::vector each_req_feature_offset(1, 0); std::vector each_req_fid_offset(1, 0); if (GetVersion() == 5) { const auto batch_size_vec = batch_size_tensor->flat(); req_num = batch_size_tensor->dim_size(0); for (size_t i = 0; i < req_num; ++i) { each_req_batch_size_offset.push_back( each_req_batch_size_offset[i] + batch_size_vec(i)); max_batch_size = std::max(batch_size_vec(i), max_batch_size); } const auto nfl_size_vec = nfl_size_tensor->flat(); for (size_t i = 0; i < req_num; ++i) { each_req_nfl_offset.push_back(each_req_nfl_offset[i] + nfl_size_vec(i)); } CHECK_EQ(each_req_nfl_offset.back(), total_nfl_num); const auto feature_size_vec = feature_size_tensor->flat(); for (size_t i = 0; i < req_num; ++i) { each_req_feature_offset.push_back( each_req_feature_offset[i] + feature_size_vec(i)); } CHECK_EQ(each_req_feature_offset.back(), total_feature_num); const auto fid_size_vec = fid_size_tensor->flat(); for (size_t i = 0; i < req_num; ++i) { 
each_req_fid_offset.push_back(each_req_fid_offset[i] + fid_size_vec(i)); } CHECK_EQ(each_req_fid_offset.back(), total_fid_num); } else { max_batch_size = batch_size_tensor->scalar()(); each_req_batch_size_offset.push_back(max_batch_size); each_req_nfl_offset.push_back(total_nfl_num); each_req_feature_offset.push_back(total_feature_num); each_req_fid_offset.push_back(total_fid_num); } req_sum_ += req_num; process_num_++; LOG_EVERY_N_SEC(INFO, 60) << "input avg req num: " << req_sum_ * 1.0 / process_num_; OpOutputList layout_tensor_list; OP_REQUIRES_OK(ctx, ctx->output_list("tensors", &layout_tensor_list)); std::vector embeddings_data; if (GetVersion() == 2) { CHECK_EQ(req_num, 1); OpInputList fid_list_row_split; OP_REQUIRES_OK(ctx, ctx->input_list("fid_list_row_split", &fid_list_row_split)); int ps_num = GetPsNum(); const std::vector> &table_feature_dim = GetFeatureInTableDim(); embeddings_data.reserve(GetFeatureCfgs().feature_configs_size() * ps_num); CHECK_EQ(embeddings_list.size(), ps_num * table_feature_dim.size()); CHECK_EQ(embeddings_list.size(), fid_list_row_split.size()); for (size_t table_i = 0; table_i < table_feature_dim.size(); ++table_i) { auto &feature_dims = table_feature_dim[table_i]; for (size_t ps_i = 0; ps_i < ps_num; ++ps_i) { int emb_index = table_i * ps_num + ps_i; auto embeddings_flat = embeddings_list[emb_index].flat(); auto embeddings_size = embeddings_flat.size(); auto embeddings_ptr = embeddings_flat.data(); auto fid_list_row_split_flat = fid_list_row_split[emb_index].flat(); CHECK_EQ(static_cast(feature_dims.size() + 1), fid_list_row_split_flat.size()); int pre_offset = 0; int pre_emb_offset = 0; for (size_t feature_i = 0; feature_i < feature_dims.size(); ++feature_i) { int dim = feature_dims[feature_i]; int offset = fid_list_row_split_flat(feature_i + 1); int fid_count = (offset - pre_offset); embeddings_data.push_back(PtrWrapper{embeddings_ptr + pre_emb_offset, dim, fid_count * dim}); pre_offset = offset; pre_emb_offset += fid_count * 
dim; CHECK(pre_emb_offset <= embeddings_size); } } } } else if (GetVersion() == 3) { embeddings_data.reserve(embeddings_list.size()); for (size_t i = 0; i < embeddings_list.size(); ++i) { const auto &embeddings_mat_ptr_ = embeddings_list[i].flat().data(); embeddings_data.push_back(PtrWrapper( {embeddings_mat_ptr_, 1, embeddings_list[i].flat().size()})); } } else if (GetVersion() == 4) { int ps_num = GetPsNum(); const std::vector> &table_feature_dim = GetFeatureInTableDim(); CHECK_EQ(embeddings_list.size(), 1); const auto embeddings_list_flat = embeddings_list[0].flat(); const Tensor *fid_list_emb_row_lenth_tensor; OP_REQUIRES_OK(ctx, ctx->input("fid_list_emb_row_lenth", &fid_list_emb_row_lenth_tensor)); const auto fid_list_emb_row_lenth_flat = fid_list_emb_row_lenth_tensor->flat(); CHECK_EQ(fid_list_emb_row_lenth_flat.size(), table_feature_dim.size() * ps_num); embeddings_data.resize(req_num * fid_list_emb_row_lenth_flat.size()); int pre_count = 0; for (size_t i = 0; i < req_num * fid_list_emb_row_lenth_flat.size(); ++i) { int req_i = i / fid_list_emb_row_lenth_flat.size(); int table_idx = (i % fid_list_emb_row_lenth_flat.size()) % table_feature_dim.size(); int ps_index = (i % fid_list_emb_row_lenth_flat.size()) / table_feature_dim.size(); int index = req_i * fid_list_emb_row_lenth_flat.size() + table_idx * ps_num + ps_index; embeddings_data[index].ptr = embeddings_list_flat.data() + pre_count; embeddings_data[index].offset = 1; embeddings_data[index].count = fid_list_emb_row_lenth_flat(i); pre_count += fid_list_emb_row_lenth_flat(i); } CHECK_EQ(pre_count, req_num * embeddings_list_flat.size()); } else if (GetVersion() == 5) { const auto emb_size_vec = emb_size_tensor->flat(); std::vector> each_req_emb_offset( embeddings_list.size(), std::vector(req_num + 1, 0)); for (size_t i = 0; i < embeddings_list.size(); ++i) { for (size_t req_i = 0; req_i < req_num; ++req_i) { each_req_emb_offset[i][req_i + 1] = each_req_emb_offset[i][req_i] + emb_size_vec(i + req_i * 
embeddings_list.size()); } CHECK_EQ(each_req_emb_offset[i].back(), embeddings_list[i].flat().size()); } embeddings_data.reserve(req_num * embeddings_list.size()); for (size_t req_i = 0; req_i < req_num; req_i++) { for (size_t i = 0; i < embeddings_list.size(); ++i) { const auto &embeddings_mat_ptr_ = embeddings_list[i].flat().data(); embeddings_data.push_back( PtrWrapper({embeddings_mat_ptr_ + each_req_emb_offset[i][req_i], 1, each_req_emb_offset[i][req_i + 1] - each_req_emb_offset[i][req_i]})); } } } else { CHECK_EQ(req_num, 1); embeddings_data.reserve(embeddings_list.size()); for (size_t i = 0; i < embeddings_list.size(); ++i) { const auto &embeddings_mat_ptr_ = embeddings_list[i].flat().data(); embeddings_data.push_back(PtrWrapper( {embeddings_mat_ptr_, embeddings_list[i].dim_size(1), embeddings_list[i].dim_size(0) * embeddings_list[i].dim_size(1)})); } } { auto activity = std::make_unique([]() { return "AllocateTensors"; }); int offset = 0; const auto &out_configs = GetFeatureCfgs().out_configs(); for (const auto &layout_name : GetLayoutNames()) { const OutConfig &out_conf = out_configs.at(layout_name); for (const auto shape : out_conf.shape()) { Tensor *tensor; TensorShape tensor_shape; for (size_t i = 0; i < shape.dims_size(); ++i) { if (i == 0) { tensor_shape.AddDim(shape.dims(i) == -1 ? 
// Fills one slice of one output layout for every row of the batch (CPU path).
//
// Row `index` of the destination lives at
// `ptr_info.ptr + index * ptr_info.offset`; `ptr_info.count` bounds the whole
// destination and is CHECKed before each row is touched.
//
// Parameters, as they are used below:
//   slice_conf_i   Position of this slice among the layout's slice configs.
//                  It is compared with 0 to build the init flag passed to
//                  OptimizedSumpooling when the scratch buffer `tmp` is in use.
//   dim_num        Number of floats in the slice.
//   nfl_idx        Feature index handed to GetFeatureInfo.
//   out_type, pooling_type, max_sequence_length, start
//                  Forwarded to GatherEmb; FIRSTN rows are
//                  `max_sequence_length * dim_num` floats wide, other pooling
//                  types `dim_num` floats wide.
//   fids_offset_vec / feature_offset_vec / nfl_offset_vec and their totals
//                  Offset tables forwarded unchanged to GetFeatureInfo and
//                  GatherEmb.
//   batch_size     Number of destination rows to fill.
//   embeddings_data, embeddings_data_size
//                  Source embedding blocks, forwarded to GatherEmb.
//   ptr_info_ptr   Destination slice (base pointer, per-row stride, extent).
//
// NOTE(review): this listing appears to have lost its template arguments
// (`std::unique_ptr tmp`, `const_cast(...)`); the tokens are left exactly as
// found and should be confirmed against the repository.
void ForwardTaskRunImpl(int slice_conf_i, int dim_num, int64 nfl_idx,
                        ::monolith::io::proto::OutType out_type,
                        ::monolith::io::proto::PoolingType pooling_type,
                        int max_sequence_length, int start,
                        const uint64 *fids_offset_vec, int total_fid_num,
                        const int32 *feature_offset_vec, int total_feature_num,
                        const uint32 *nfl_offset_vec, int total_nfl_num,
                        int batch_size, const PtrWrapper *embeddings_data,
                        int embeddings_data_size, PtrWrapper *ptr_info_ptr) {
  PtrWrapper &ptr_info = *ptr_info_ptr;
  bool is_shared;
  int nfl_offset, feature_num;
  GetFeatureInfo(nfl_idx, nfl_offset_vec, total_nfl_num, total_feature_num,
                 &is_shared, &nfl_offset, &feature_num);
  // GetFeatureInfo reported zero entries for this feature: nothing to write.
  // (Original note here read "nfl exits".)
  if (!feature_num) return;
  // Scratch buffer, only allocated for a shared feature feeding an ADDN
  // layout. It holds the value gathered for row 0 so it can be merged into
  // every row's destination.
  std::unique_ptr tmp;
  if (is_shared && (out_type == OutType::ADDN)) {
    tmp.reset(new float[dim_num]());
  }
  int feature_idx = nfl_offset + 0;
  for (size_t index = 0; index < batch_size; ++index) {
    int temp_offset = index * ptr_info.offset;
    // Guard against writing past the destination extent.
    if (pooling_type == PoolingType::FIRSTN) {
      CHECK(temp_offset + max_sequence_length * dim_num <= ptr_info.count);
    } else {
      CHECK(temp_offset + dim_num <= ptr_info.count);
    }
    if (!is_shared || index == 0) {
      // Gather path: runs for every row of a non-shared feature, and only for
      // row 0 of a shared one. The result goes to `tmp` when it exists,
      // otherwise straight into the destination row.
      bool init = (out_type != OutType::ADDN) || tmp;
      GatherEmb(
          feature_idx, max_sequence_length, pooling_type, dim_num, start,
          embeddings_data, embeddings_data_size, fids_offset_vec, total_fid_num,
          feature_offset_vec, total_feature_num,
          const_cast(tmp ? tmp.get() : ptr_info.ptr + temp_offset),
          OptimizedSumpooling, MemCopy, nullptr, nullptr, DefaultGetInitFunc,
          &init);
      if (tmp) {
        // `tmp` is only set when out_type == ADDN, so this flag is effectively
        // `slice_conf_i == 0`.
        bool init_tmp = (out_type != OutType::ADDN) || (slice_conf_i == 0);
        OptimizedSumpooling(
            tmp.get(), dim_num, &init_tmp,
            const_cast(ptr_info.ptr + temp_offset));
      }
      feature_idx++;
    } else {
      // Shared feature, rows after the first: reuse what row 0 produced
      // instead of gathering again.
      if (tmp) {
        bool init_tmp = (slice_conf_i == 0);  // && index == 0
        OptimizedSumpooling(
            tmp.get(), dim_num, &init_tmp,
            const_cast(ptr_info.ptr + temp_offset));
      } else {
        // Without the scratch buffer, replicate row 0 of the destination
        // (`ptr_info.ptr`) into the current row.
        switch (pooling_type) {
          case PoolingType::SUM:
          case PoolingType::MEAN:
            MemCopy(const_cast(ptr_info.ptr + temp_offset), ptr_info.ptr,
                    dim_num);
            break;
          case PoolingType::FIRSTN:
            MemCopy(const_cast(ptr_info.ptr + temp_offset), ptr_info.ptr,
                    dim_num * max_sequence_length);
            break;
          default:
            break;
        }
      }
    }
  }
}
// Versioned forward kernels. Each subclass only pins the `version` argument of
// MonolithEmbeddingToLayoutOp; MonolithEmbeddingToLayoutOp::Compute branches
// on GetVersion() to decide how `embeddings_list` is interpreted.

// Version 2: Compute additionally reads the `fid_list_row_split` input list
// and slices each embeddings tensor per feature.
class MonolithEmbeddingToLayoutOpV2 : public MonolithEmbeddingToLayoutOp {
 public:
  explicit MonolithEmbeddingToLayoutOpV2(OpKernelConstruction *ctx)
      : MonolithEmbeddingToLayoutOp(ctx, 2) {}
};

// Version 3: Compute wraps each embeddings tensor as one flat block.
class MonolithEmbeddingToLayoutOpV3 : public MonolithEmbeddingToLayoutOp {
 public:
  explicit MonolithEmbeddingToLayoutOpV3(OpKernelConstruction *ctx)
      : MonolithEmbeddingToLayoutOp(ctx, 3) {}
};

// Version 4: Compute expects exactly one embeddings tensor and partitions it
// using the `fid_list_emb_row_lenth` input.
class MonolithEmbeddingToLayoutOpV4 : public MonolithEmbeddingToLayoutOp {
 public:
  explicit MonolithEmbeddingToLayoutOpV4(OpKernelConstruction *ctx)
      : MonolithEmbeddingToLayoutOp(ctx, 4) {}
};

// Version 5: Compute additionally reads the per-request `nfl_size`,
// `feature_size`, `fid_size` and `emb_size` inputs.
class MonolithEmbeddingToLayoutOpV5 : public MonolithEmbeddingToLayoutOp {
 public:
  explicit MonolithEmbeddingToLayoutOpV5(OpKernelConstruction *ctx)
      : MonolithEmbeddingToLayoutOp(ctx, 5) {}
};
table_i = 0; table_i < table_feature_dim.size(); ++table_i) { auto &feature_dims = table_feature_dim[table_i]; for (size_t ps_i = 0; ps_i < ps_num; ++ps_i) { int emb_index = table_i * ps_num + ps_i; Tensor *tensor; OP_REQUIRES_OK( ctx, embeddings_grad_list.allocate( emb_index, embeddings_list[emb_index].shape(), &tensor)); auto embeddings_grad_flat = embeddings_grad_list[emb_index]->flat(); auto embeddings_grad_size = embeddings_grad_flat.size(); auto embeddings_grad_ptr = embeddings_grad_flat.data(); auto fid_list_row_split_flat = fid_list_row_split[emb_index].flat(); CHECK_EQ(static_cast(feature_dims.size() + 1), fid_list_row_split_flat.size()); int pre_offset = 0; int pre_emb_offset = 0; for (size_t feature_i = 0; feature_i < feature_dims.size(); ++feature_i) { int dim = feature_dims[feature_i]; int offset = fid_list_row_split_flat(feature_i + 1); int fid_count = (offset - pre_offset); embeddings_grads_data.push_back(PtrWrapper{ embeddings_grad_ptr + pre_emb_offset, dim, fid_count * dim}); ufid_grads_info.emplace_back(std::make_pair(init_counter, fid_count)); pre_offset = offset; pre_emb_offset += fid_count * dim; CHECK(pre_emb_offset <= embeddings_grad_size); init_counter += fid_count; } } } } else if (GetVersion() == 3 || GetVersion() == 5) { embeddings_grads_data.reserve(embeddings_list.size()); for (size_t i = 0; i < embeddings_list.size(); ++i) { Tensor *tensor; OP_REQUIRES_OK(ctx, embeddings_grad_list.allocate( i, embeddings_list[i].shape(), &tensor)); embeddings_grads_data.push_back( PtrWrapper({embeddings_grad_list[i]->flat().data(), 1, embeddings_grad_list[i]->flat().size()})); } } else if (GetVersion() == 4) { int ps_num = GetPsNum(); const std::vector> &table_feature_dim = GetFeatureInTableDim(); CHECK_EQ(embeddings_list.size(), 1); Tensor *embeddings_grad_list_tensor; OP_REQUIRES_OK( ctx, embeddings_grad_list.allocate(0, {embeddings_list[0].shape()}, &embeddings_grad_list_tensor)); const auto embeddings_grad_list_flat = 
embeddings_grad_list_tensor->flat(); const Tensor *fid_list_emb_row_lenth_tensor; OP_REQUIRES_OK(ctx, ctx->input("fid_list_emb_row_lenth", &fid_list_emb_row_lenth_tensor)); const auto fid_list_emb_row_lenth_flat = fid_list_emb_row_lenth_tensor->flat(); CHECK_EQ(fid_list_emb_row_lenth_flat.size(), table_feature_dim.size() * ps_num); embeddings_grads_data.resize(fid_list_emb_row_lenth_flat.size()); int pre_count = 0; for (size_t i = 0; i < fid_list_emb_row_lenth_flat.size(); ++i) { int table_idx = i % table_feature_dim.size(); int ps_index = i / table_feature_dim.size(); int index = table_idx * ps_num + ps_index; embeddings_grads_data[index].ptr = embeddings_grad_list_flat.data() + pre_count; embeddings_grads_data[index].offset = 1; embeddings_grads_data[index].count = fid_list_emb_row_lenth_flat(i); pre_count += fid_list_emb_row_lenth_flat(i); } CHECK_EQ(pre_count, embeddings_grad_list_flat.size()); } else { embeddings_grads_data.reserve(embeddings_list.size()); for (size_t i = 0; i < embeddings_list.size(); ++i) { Tensor *tensor; OP_REQUIRES_OK(ctx, embeddings_grad_list.allocate( i, embeddings_list[i].shape(), &tensor)); int dim = embeddings_list[i].dim_size(1); int fid_count = embeddings_list[i].dim_size(0); embeddings_grads_data.push_back(PtrWrapper{ embeddings_grad_list[i]->flat().data(), dim, fid_count * dim}); ufid_grads_info.emplace_back(std::make_pair(init_counter, fid_count)); init_counter += fid_count; } } // wrapper of bool for avoid : // invalid initialization of non-const reference of type 'bool&' from an // rvalue of type 'bool' GroupA init(init_counter, GetMaxSliceNum()); int offset = 0; std::vector> layouts; for (const auto &layout_name : GetLayoutNames()) { const OutConfig &out_conf = GetFeatureCfgs().out_configs().at(layout_name); switch (out_conf.out_type()) { case OutType::NONE: layouts.push_back(std::make_shared(layout_name, out_conf, tensors_grad, offset)); break; default: layouts.push_back(std::make_shared( layout_name, out_conf, tensors_grad, 
// Size of the mutex table allocated per gradient computation when the
// scatter runs in parallel.
static constexpr int NUM_LOCKS = 512;

// Maps an (index1, index2) pair to one mutex of a shared lock table.
//
// main_params: must point to the first element of an array of at least
//              NUM_LOCKS std::mutex objects.
// index1, index2: pair identifying the gradient destination being updated
//              (presumably embedding block and row within it; confirm against
//              ScatterGrad, which calls this).
// Returns: pointer to the selected std::mutex, always inside the table.
//
// The same pair always yields the same mutex, which is all mutual exclusion
// needs. The pair is mixed in unsigned 64-bit arithmetic rather than as a
// plain `int` product for two reasons:
//   * an `int` product can overflow or be negative, and a negative operand to
//     `%` gives a negative index, i.e. a pointer before the table;
//   * a product is 0 whenever either index is 0, which funnels all of that
//     work onto lock 0.
void *ScatterGradGetMutexFuncFunc(void *main_params, int32 index1,
                                  int32 index2) {
  std::mutex *mutex_list = static_cast<std::mutex *>(main_params);
  const unsigned long long key =
      static_cast<unsigned long long>(static_cast<unsigned int>(index1)) *
          1000003ULL +
      static_cast<unsigned int>(index2);
  const int mutex_idx = static_cast<int>(key % NUM_LOCKS);
  return mutex_list + mutex_idx;
}
std::unique_ptr mutex_list; if (parallel_flag != 0) { mutex_list = std::make_unique(NUM_LOCKS); } auto scatter_grad_fn = [&, this](int start, int end) { for (int64 para_i = start; para_i < end; ++para_i) { auto &layout = layouts.at(para_i); // CHECK(end - start == 1); const ::google::protobuf::RepeatedPtrField &layout_slice_configs = layout->GetSliceConfig(); for (const SliceConfig &slice_conf : layout_slice_configs) { int dim_num = slice_conf.end() - slice_conf.start(); PtrWrapper ptr_info = layout->GetSlice(0, slice_conf); const int64 &nfl_idx = slice_conf.feature_idx(); bool is_shared; int nfl_offset, feature_num; GetFeatureInfo(nfl_idx, nfl_offset_vec, total_nfl_num, total_feature_num, &is_shared, &nfl_offset, &feature_num); if (!feature_num) continue; // nfl exits int feature_idx = nfl_offset + 0; for (size_t index = 0; index < batch_size; ++index) { int temp_offset = index * ptr_info.offset; if (slice_conf.pooling_type() == PoolingType::FIRSTN) { CHECK(temp_offset + slice_conf.max_sequence_length() * dim_num <= ptr_info.count); } else { CHECK(temp_offset + dim_num <= ptr_info.count); } ScatterGradGetInitFuncParams init_params( {slice_conf.slice_idx(), ufid_grads_info, init}); ScatterGrad(feature_idx, slice_conf.max_sequence_length(), slice_conf.pooling_type(), ptr_info.ptr + temp_offset, dim_num, slice_conf.start(), fids_offset_vec, total_fid_num, feature_offset_vec, total_feature_num, embeddings_grads_data->size(), embeddings_grads_data->data(), OptimizedSumpooling, (mutex_list ? ScatterGradGetMutexFuncFunc : nullptr), (mutex_list ? mutex_list.get() : nullptr), (init ? 
ScatterGradGetInitFunc : nullptr), &init_params); if (!is_shared) { // train don't have shared feature feature_idx++; } } } } }; if (parallel_flag == 0) { for (int i = 0; i < layouts.size(); ++i) { scatter_grad_fn(i, i + 1); } } else { auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); worker_threads->workers->ParallelFor( layouts.size(), thread::ThreadPool::SchedulingParams( thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, absl::nullopt, 1), // block_size scatter_grad_fn); } } class MonolithEmbeddingToLayoutGradOpV2 : public MonolithEmbeddingToLayoutGradOp { public: explicit MonolithEmbeddingToLayoutGradOpV2(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutGradOp(ctx, 2) {} }; class MonolithEmbeddingToLayoutGradOpV3 : public MonolithEmbeddingToLayoutGradOp { public: explicit MonolithEmbeddingToLayoutGradOpV3(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutGradOp(ctx, 3) {} }; class MonolithEmbeddingToLayoutGradOpV4 : public MonolithEmbeddingToLayoutGradOp { public: explicit MonolithEmbeddingToLayoutGradOpV4(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutGradOp(ctx, 4) {} }; class MonolithEmbeddingToLayoutGradOpV5 : public MonolithEmbeddingToLayoutGradOp { public: explicit MonolithEmbeddingToLayoutGradOpV5(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutGradOp(ctx, 5) {} }; auto forward_shape_inference_fn = [](shape_inference::InferenceContext *ctx) { std::string serialized; TF_RETURN_IF_ERROR(ctx->GetAttr("feature_cfgs", &serialized)); FeatureConfigs feature_cfgs; CHECK(feature_cfgs.ParseFromArray(serialized.data(), serialized.size())); std::vector layout_names; const auto &out_configs = feature_cfgs.out_configs(); for (const auto &pair : out_configs) { layout_names.push_back(pair.first); } std::sort(layout_names.begin(), layout_names.end()); std::vector tensors_shape; for (const auto &layout_name : layout_names) { const OutConfig &out_conf = out_configs.at(layout_name); for (const auto shape : out_conf.shape()) { 
std::vector dims; for (size_t i = 0; i < shape.dims_size(); ++i) { if (i == 0) { dims.push_back(ctx->UnknownDim()); } else { CHECK_GT(shape.dims(i), 0); dims.push_back(ctx->MakeDim(shape.dims(i))); } } tensors_shape.push_back(ctx->MakeShape(dims)); } } TF_RETURN_IF_ERROR(ctx->set_output("tensors", tensors_shape)); return Status::OK(); }; REGISTER_OP("MonolithEmbeddingToLayout") .Input("embeddings_list: M * float") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Output("tensors: num_out * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_out: int") .Attr("variant_type: string") .Attr("feature_cfgs: string") .SetDoNotOptimize() .SetShapeFn(forward_shape_inference_fn); auto backward_shape_inference_fn = [](shape_inference::InferenceContext *ctx) { std::vector embeddings_list_shape; TF_RETURN_IF_ERROR(ctx->input("embeddings_list", &embeddings_list_shape)); TF_RETURN_IF_ERROR( ctx->set_output("embeddings_grad_list", embeddings_list_shape)); return Status::OK(); }; REGISTER_OP("MonolithEmbeddingToLayoutGrad") .Input("embeddings_list: M * float") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Input("tensors_grad: num_input * float") .Output("embeddings_grad_list: M * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_input: int") // num of tensors_grad input .Attr("variant_type: string") .Attr("feature_cfgs: string") .SetDoNotOptimize() .SetShapeFn(backward_shape_inference_fn); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayout").Device(DEVICE_CPU), MonolithEmbeddingToLayoutOp); REGISTER_KERNEL_BUILDER( Name("MonolithEmbeddingToLayoutGrad").Device(DEVICE_CPU), MonolithEmbeddingToLayoutGradOp); REGISTER_OP("MonolithEmbeddingToLayoutV2") .Input("embeddings_list: M * float") .Input("fid_list_row_split: M * int64") .Input("fid_offset: uint64") .Input("feature_offset: int32") 
.Input("nfl_offset: uint32") .Input("batch_size: int32") .Output("tensors: num_out * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_out: int") .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(forward_shape_inference_fn); REGISTER_OP("MonolithEmbeddingToLayoutGradV2") .Input("embeddings_list: M * float") .Input("fid_list_row_split: M * int64") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Input("tensors_grad: num_input * float") .Output("embeddings_grad_list: M * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_input: int") // num of tensors_grad input .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(backward_shape_inference_fn); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutV2").Device(DEVICE_CPU), MonolithEmbeddingToLayoutOpV2); REGISTER_KERNEL_BUILDER( Name("MonolithEmbeddingToLayoutGradV2").Device(DEVICE_CPU), MonolithEmbeddingToLayoutGradOpV2); REGISTER_OP("MonolithEmbeddingToLayoutV3") .Input("embeddings_list: M * float") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Output("tensors: num_out * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_out: int") .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(forward_shape_inference_fn); REGISTER_OP("MonolithEmbeddingToLayoutGradV3") .Input("embeddings_list: M * float") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Input("tensors_grad: num_input * float") .Output("embeddings_grad_list: M * float") .Attr("M: int") // num of fids_list 
(shard x subtable) .Attr("num_input: int") // num of tensors_grad input .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(backward_shape_inference_fn); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutV3").Device(DEVICE_CPU), MonolithEmbeddingToLayoutOpV3); REGISTER_KERNEL_BUILDER( Name("MonolithEmbeddingToLayoutGradV3").Device(DEVICE_CPU), MonolithEmbeddingToLayoutGradOpV3); REGISTER_OP("MonolithEmbeddingToLayoutV4") .Input("embeddings_list: M * float") .Input("fid_list_emb_row_lenth: int32") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Output("tensors: num_out * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_out: int") .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(forward_shape_inference_fn); REGISTER_OP("MonolithEmbeddingToLayoutGradV4") .Input("embeddings_list: M * float") .Input("fid_list_emb_row_lenth: int32") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Input("tensors_grad: num_input * float") .Output("embeddings_grad_list: M * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_input: int") // num of tensors_grad input .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(backward_shape_inference_fn); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutV4").Device(DEVICE_CPU), MonolithEmbeddingToLayoutOpV4); REGISTER_KERNEL_BUILDER( Name("MonolithEmbeddingToLayoutGradV4").Device(DEVICE_CPU), MonolithEmbeddingToLayoutGradOpV4); REGISTER_OP("MonolithEmbeddingToLayoutV5") .Input("embeddings_list: M * float") .Input("fid_offset: uint64") .Input("feature_offset: int32") 
.Input("nfl_offset: uint32") .Input("batch_size: int32") .Input("nfl_size: int32") .Input("feature_size: int32") .Input("fid_size: int32") .Input("emb_size: int32") .Output("tensors: num_out * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_out: int") .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(forward_shape_inference_fn); REGISTER_OP("MonolithEmbeddingToLayoutGradV5") .Input("embeddings_list: M * float") .Input("fid_offset: uint64") .Input("feature_offset: int32") .Input("nfl_offset: uint32") .Input("batch_size: int32") .Input("tensors_grad: num_input * float") .Output("embeddings_grad_list: M * float") .Attr("M: int") // num of fids_list (shard x subtable) .Attr("num_input: int") // num of tensors_grad input .Attr("variant_type: string") .Attr("feature_cfgs: string") .Attr("ps_num: int") .Attr("parallel_flag: int = 0") .SetDoNotOptimize() .SetShapeFn(backward_shape_inference_fn); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutV5").Device(DEVICE_CPU), MonolithEmbeddingToLayoutOpV5); REGISTER_KERNEL_BUILDER( Name("MonolithEmbeddingToLayoutGradV5").Device(DEVICE_CPU), MonolithEmbeddingToLayoutGradOpV5); } // namespace fused_layout } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/fused_embedding_to_layout.cu.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #if GOOGLE_CUDA #define EIGEN_USE_GPU #include "monolith/native_training/runtime/ops/fused_embedding_to_layout.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/gpu_device_array.h" #include "tensorflow/core/kernels/gpu_device_array_gpu.h" #include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace monolith_tf { namespace fused_layout { typedef Eigen::GpuDevice GPUDevice; struct ForwardTaskInfo { int dim_num; PtrWrapper ptr_info; int64 nfl_idx; ::monolith::io::proto::OutType out_type; ::monolith::io::proto::PoolingType pooling_type; int max_sequence_length; int start; int req_i; }; __device__ void *MemCopyGPU(float *dest, const float *src, std::size_t count) { for (int32 idx = 0; idx < count; ++idx) { *(dest + idx) = *(src + idx); } return dest; } __device__ void OptimizedSumpoolingGPU(const float *src, const int dim_num, void *init_ptr, float *dst, void *one_mutex = nullptr, int mean_pool_fid_num = 0) { bool *init = static_cast(init_ptr); if (init && *init) { if (mean_pool_fid_num) { for (size_t i = 0; i < dim_num; ++i) { dst[i] = (src[i] / mean_pool_fid_num); } } else { MemCopyGPU(dst, src, dim_num); } *init = false; } else { if (mean_pool_fid_num) { for (size_t i = 0; i < dim_num; ++i) { dst[i] += (src[i] / mean_pool_fid_num); } } else { for (size_t i = 0; i < dim_num; ++i) { dst[i] += src[i]; } } } } __device__ void OptimizedSumpoolingGPUWithLock(const float *src, const int dim_num, void 
*init_ptr, float *dst, void *one_mutex = nullptr, int mean_pool_fid_num = 0) { if (mean_pool_fid_num) { for (int32 idx = 0; idx < dim_num; ++idx) { GpuAtomicAdd(dst + idx, (*(src + idx)) / mean_pool_fid_num); } } else { for (int32 idx = 0; idx < dim_num; ++idx) { GpuAtomicAdd(dst + idx, *(src + idx)); } } } __global__ void ForwardBatchKernel( const Gpu2DLaunchConfig config, GpuDeviceArrayStruct embeddings_data_list, const uint64 *fids_offset_vec, int total_fid_num, const int32 *feature_offset_vec, int total_feature_num, const uint32 *nfl_offset_vec, int total_nfl_num, GpuDeviceArrayStruct task_info_list, GpuDeviceArrayStruct each_req_batch_size_list, GpuDeviceArrayStruct each_req_nfl_list, GpuDeviceArrayStruct each_req_feature_list, GpuDeviceArrayStruct each_req_fid_list, GpuDeviceArrayStruct each_req_emb_list) { ForwardTaskInfo *task_info_list_ptr = GetGpuDeviceArrayOnDevice(&task_info_list); int *each_req_batch_size_offset = GetGpuDeviceArrayOnDevice(&each_req_batch_size_list); int *each_req_nfl_offset = GetGpuDeviceArrayOnDevice(&each_req_nfl_list); int *each_req_feature_offset = GetGpuDeviceArrayOnDevice(&each_req_feature_list); int *each_req_fid_offset = GetGpuDeviceArrayOnDevice(&each_req_fid_list); int *each_req_emb_offset = GetGpuDeviceArrayOnDevice(&each_req_emb_list); const PtrWrapper *embeddings_data_list_ptr = GetGpuDeviceArrayOnDevice(&embeddings_data_list); bool is_shared; int nfl_offset; int feature_num; ForwardTaskInfo *task_info = nullptr; int feature_idx; int temp_offset; bool init; GPU_AXIS_KERNEL_LOOP(task_idx, config.virtual_thread_count.y, Y) { task_info = task_info_list_ptr + task_idx; GetFeatureInfo(task_info->nfl_idx, nfl_offset_vec + *(each_req_nfl_offset + task_info->req_i), *(each_req_nfl_offset + task_info->req_i + 1) - *(each_req_nfl_offset + task_info->req_i), *(each_req_feature_offset + task_info->req_i + 1) - *(each_req_feature_offset + task_info->req_i), &is_shared, &nfl_offset, &feature_num); if (!feature_num) return; // nfl exits 
GPU_AXIS_KERNEL_LOOP(batch_idx, config.virtual_thread_count.x, X) { if (batch_idx >= *(each_req_batch_size_offset + task_info->req_i + 1) - *(each_req_batch_size_offset + task_info->req_i)) return; // out of range feature_idx = nfl_offset; // in single req scope if (!is_shared) { feature_idx += batch_idx; } temp_offset = (batch_idx + *(each_req_batch_size_offset + task_info->req_i)) * task_info->ptr_info.offset; if (task_info->out_type == OutType::ADDN) { if (task_info->pooling_type == PoolingType::FIRSTN) { // not support } else { GatherEmb(feature_idx, task_info->max_sequence_length, task_info->pooling_type, task_info->dim_num, task_info->start, embeddings_data_list_ptr + *(each_req_emb_offset + task_info->req_i), *(each_req_emb_offset + task_info->req_i + 1) - *(each_req_emb_offset + task_info->req_i), fids_offset_vec + *(each_req_fid_offset + task_info->req_i), *(each_req_fid_offset + task_info->req_i + 1) - *(each_req_fid_offset + task_info->req_i), feature_offset_vec + *(each_req_feature_offset + task_info->req_i), *(each_req_feature_offset + task_info->req_i + 1) - *(each_req_feature_offset + task_info->req_i), const_cast(task_info->ptr_info.ptr + temp_offset), OptimizedSumpoolingGPUWithLock, MemCopyGPU, nullptr, nullptr, nullptr, nullptr); } } else { init = true; GatherEmb( feature_idx, task_info->max_sequence_length, task_info->pooling_type, task_info->dim_num, task_info->start, embeddings_data_list_ptr + *(each_req_emb_offset + task_info->req_i), *(each_req_emb_offset + task_info->req_i + 1) - *(each_req_emb_offset + task_info->req_i), fids_offset_vec + *(each_req_fid_offset + task_info->req_i), *(each_req_fid_offset + task_info->req_i + 1) - *(each_req_fid_offset + task_info->req_i), feature_offset_vec + *(each_req_feature_offset + task_info->req_i), *(each_req_feature_offset + task_info->req_i + 1) - *(each_req_feature_offset + task_info->req_i), const_cast(task_info->ptr_info.ptr + temp_offset), OptimizedSumpoolingGPU, MemCopyGPU, nullptr, nullptr, 
DefaultGetInitFunc, &init); } } } } template struct SetZeroFunctor { void operator()(const GPUDevice &d, typename TTypes::Flat out) { To32Bit(out).device(d) = To32Bit(out).constant(T(0)); } }; class MonolithEmbeddingToLayoutOpV3GPU : public MonolithEmbeddingToLayoutOp { public: explicit MonolithEmbeddingToLayoutOpV3GPU(OpKernelConstruction *ctx, int verison = 3) : MonolithEmbeddingToLayoutOp(ctx, verison) {} virtual void TaskRun(const std::vector> &layouts, const std::vector &embeddings_data, const uint64 *fids_offset_vec, int total_fid_num, const int32 *feature_offset_vec, int total_feature_num, const uint32 *nfl_offset_vec, int total_nfl_num, int batch_size, const std::vector &each_req_batch_size_offset, const std::vector &each_req_nfl_offset, const std::vector &each_req_feature_offset, const std::vector &each_req_fid_offset, int req_num, OpKernelContext *ctx, OpOutputList *layout_tensor_list) { GPUDevice gpu_device = ctx->eigen_device(); SetZeroFunctor zero_functor; for (int32 idx = 0; idx < layout_tensor_list->size(); ++idx) { zero_functor(gpu_device, (*layout_tensor_list)[idx]->flat()); } int each_req_emb_num = embeddings_data.size() / req_num; std::vector task_info_vec; { auto activity = std::make_unique([]() { return "BuildGPUTask"; }); for (int req_i = 0; req_i < req_num; req_i++) { for (int para_i = 0; para_i < layouts.size(); ++para_i) { auto &layout = layouts.at(para_i); // CHECK(end - start == 1); const ::google::protobuf::RepeatedPtrField &layout_slice_configs = layout->GetSliceConfig(); for (uint slice_conf_i = 0; slice_conf_i < layout_slice_configs.size(); ++slice_conf_i) { const SliceConfig &slice_conf = layout_slice_configs[slice_conf_i]; int dim_num = slice_conf.end() - slice_conf.start(); PtrWrapper ptr_info = layout->GetSlice(0, slice_conf); const int64 nfl_idx = slice_conf.feature_idx(); task_info_vec.push_back(ForwardTaskInfo( {dim_num, ptr_info, nfl_idx, layout->out_type(), slice_conf.pooling_type(), slice_conf.max_sequence_length(), 
slice_conf.start(), req_i})); } } } } GpuDeviceArrayOnHost task_info_list(ctx, task_info_vec.size()); GpuDeviceArrayOnHost embeddings_data_list( ctx, embeddings_data.size()); GpuDeviceArrayOnHost each_req_batch_size_list( ctx, each_req_batch_size_offset.size()); GpuDeviceArrayOnHost each_req_nfl_list(ctx, each_req_nfl_offset.size()); GpuDeviceArrayOnHost each_req_feature_list( ctx, each_req_feature_offset.size()); GpuDeviceArrayOnHost each_req_fid_list(ctx, each_req_fid_offset.size()); GpuDeviceArrayOnHost each_req_emb_list(ctx, req_num + 1); { auto activity = std::make_unique( []() { return "CopyHostValueToDevice"; }); OP_REQUIRES_OK(ctx, task_info_list.Init()); for (int i = 0; i < task_info_vec.size(); ++i) { task_info_list.Set(i, task_info_vec[i]); } OP_REQUIRES_OK(ctx, task_info_list.Finalize()); OP_REQUIRES_OK(ctx, embeddings_data_list.Init()); for (int i = 0; i < embeddings_data.size(); ++i) { embeddings_data_list.Set(i, embeddings_data[i]); } OP_REQUIRES_OK(ctx, embeddings_data_list.Finalize()); OP_REQUIRES_OK(ctx, each_req_batch_size_list.Init()); for (int i = 0; i < each_req_batch_size_offset.size(); ++i) { each_req_batch_size_list.Set(i, each_req_batch_size_offset[i]); } OP_REQUIRES_OK(ctx, each_req_batch_size_list.Finalize()); OP_REQUIRES_OK(ctx, each_req_nfl_list.Init()); for (int i = 0; i < each_req_nfl_offset.size(); ++i) { each_req_nfl_list.Set(i, each_req_nfl_offset[i]); } OP_REQUIRES_OK(ctx, each_req_nfl_list.Finalize()); OP_REQUIRES_OK(ctx, each_req_feature_list.Init()); for (int i = 0; i < each_req_feature_offset.size(); ++i) { each_req_feature_list.Set(i, each_req_feature_offset[i]); } OP_REQUIRES_OK(ctx, each_req_feature_list.Finalize()); OP_REQUIRES_OK(ctx, each_req_fid_list.Init()); for (int i = 0; i < each_req_fid_offset.size(); ++i) { each_req_fid_list.Set(i, each_req_fid_offset[i]); } OP_REQUIRES_OK(ctx, each_req_fid_list.Finalize()); OP_REQUIRES_OK(ctx, each_req_emb_list.Init()); for (int i = 0; i < req_num + 1; ++i) { 
each_req_emb_list.Set(i, i * each_req_emb_num); } OP_REQUIRES_OK(ctx, each_req_emb_list.Finalize()); } auto config = GetGpu2DLaunchConfig(batch_size, task_info_vec.size(), gpu_device); GpuLaunchKernel(ForwardBatchKernel, config.block_count, config.thread_per_block, 0, gpu_device.stream(), config, embeddings_data_list.data(), fids_offset_vec, total_fid_num, feature_offset_vec, total_feature_num, nfl_offset_vec, total_nfl_num, task_info_list.data(), each_req_batch_size_list.data(), each_req_nfl_list.data(), each_req_feature_list.data(), each_req_fid_list.data(), each_req_emb_list.data()); } }; class MonolithEmbeddingToLayoutOpV4GPU : public MonolithEmbeddingToLayoutOpV3GPU { public: explicit MonolithEmbeddingToLayoutOpV4GPU(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutOpV3GPU(ctx, 4) {} }; class MonolithEmbeddingToLayoutOpV5GPU : public MonolithEmbeddingToLayoutOpV3GPU { public: explicit MonolithEmbeddingToLayoutOpV5GPU(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutOpV3GPU(ctx, 5) {} }; __global__ void BackwardBatchKernel( const Gpu2DLaunchConfig config, GpuDeviceArrayStruct task_info_list, const uint64 *fids_offset_vec, int total_fid_num, const int32 *feature_offset_vec, int total_feature_num, const uint32 *nfl_offset_vec, int total_nfl_num, int batch_size, GpuDeviceArrayStruct embeddings_grads_data) { const ForwardTaskInfo *task_info_list_ptr = GetGpuDeviceArrayOnDevice(&task_info_list); PtrWrapper *embeddings_grads_data_list_ptr = GetGpuDeviceArrayOnDevice(&embeddings_grads_data); bool is_shared; int nfl_offset; int feature_num; const ForwardTaskInfo *task_info = nullptr; int feature_idx; int temp_offset; GPU_AXIS_KERNEL_LOOP(task_idx, config.virtual_thread_count.y, Y) { task_info = task_info_list_ptr + task_idx; GetFeatureInfo(task_info->nfl_idx, nfl_offset_vec, total_nfl_num, total_feature_num, &is_shared, &nfl_offset, &feature_num); if (!feature_num) return; // nfl exits GPU_AXIS_KERNEL_LOOP(batch_idx, config.virtual_thread_count.x, X) { 
feature_idx = nfl_offset; if (!is_shared) { feature_idx += batch_idx; } temp_offset = batch_idx * task_info->ptr_info.offset; ScatterGrad( feature_idx, task_info->max_sequence_length, task_info->pooling_type, task_info->ptr_info.ptr + temp_offset, task_info->dim_num, task_info->start, fids_offset_vec, total_fid_num, feature_offset_vec, total_feature_num, embeddings_grads_data.size, embeddings_grads_data_list_ptr, OptimizedSumpoolingGPUWithLock, nullptr, nullptr, nullptr, nullptr); } } } class MonolithEmbeddingToLayoutGradOpV3GPU : public MonolithEmbeddingToLayoutGradOp { public: explicit MonolithEmbeddingToLayoutGradOpV3GPU(OpKernelConstruction *ctx, int verison = 3) : MonolithEmbeddingToLayoutGradOp(ctx, verison) {} void TaskRun(const std::vector> &layouts, const std::vector> *ufid_grads_info, const uint64 *fids_offset_vec, int total_fid_num, const int32 *feature_offset_vec, int total_feature_num, const uint32 *nfl_offset_vec, int total_nfl_num, int batch_size, OpKernelContext *ctx, OpOutputList *embeddings_grad_list, std::vector *embeddings_grads_data, GroupA *init) { GPUDevice gpu_device = ctx->eigen_device(); SetZeroFunctor zero_functor; for (int32 idx = 0; idx < embeddings_grad_list->size(); ++idx) { zero_functor(gpu_device, (*embeddings_grad_list)[idx]->flat()); } std::vector task_info_vec; for (int64 para_i = 0; para_i < layouts.size(); ++para_i) { auto &layout = layouts.at(para_i); // CHECK(end - start == 1); const ::google::protobuf::RepeatedPtrField &layout_slice_configs = layout->GetSliceConfig(); for (const SliceConfig &slice_conf : layout_slice_configs) { int dim_num = slice_conf.end() - slice_conf.start(); PtrWrapper ptr_info = layout->GetSlice(0, slice_conf); const int64 nfl_idx = slice_conf.feature_idx(); task_info_vec.push_back(ForwardTaskInfo( {dim_num, ptr_info, nfl_idx, layout->out_type(), slice_conf.pooling_type(), slice_conf.max_sequence_length(), slice_conf.start()})); } } GpuDeviceArrayOnHost task_info_list(ctx, task_info_vec.size()); 
OP_REQUIRES_OK(ctx, task_info_list.Init()); for (int i = 0; i < task_info_vec.size(); ++i) { task_info_list.Set(i, task_info_vec[i]); } OP_REQUIRES_OK(ctx, task_info_list.Finalize()); GpuDeviceArrayOnHost embeddings_grads_data_list( ctx, embeddings_grads_data->size()); OP_REQUIRES_OK(ctx, embeddings_grads_data_list.Init()); for (int i = 0; i < embeddings_grads_data->size(); ++i) { embeddings_grads_data_list.Set(i, (*embeddings_grads_data)[i]); } OP_REQUIRES_OK(ctx, embeddings_grads_data_list.Finalize()); auto config = GetGpu2DLaunchConfig(batch_size, task_info_vec.size(), gpu_device); GpuLaunchKernel( BackwardBatchKernel, config.block_count, config.thread_per_block, 0, gpu_device.stream(), config, task_info_list.data(), fids_offset_vec, total_fid_num, feature_offset_vec, total_feature_num, nfl_offset_vec, total_nfl_num, batch_size, embeddings_grads_data_list.data()); } }; class MonolithEmbeddingToLayoutGradOpV4GPU : public MonolithEmbeddingToLayoutGradOpV3GPU { public: explicit MonolithEmbeddingToLayoutGradOpV4GPU(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutGradOpV3GPU(ctx, 4) {} }; class MonolithEmbeddingToLayoutGradOpV5GPU : public MonolithEmbeddingToLayoutGradOpV3GPU { public: explicit MonolithEmbeddingToLayoutGradOpV5GPU(OpKernelConstruction *ctx) : MonolithEmbeddingToLayoutGradOpV3GPU(ctx, 5) {} }; REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutV3") .Device(DEVICE_GPU) .HostMemory("batch_size"), MonolithEmbeddingToLayoutOpV3GPU); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutGradV3") .Device(DEVICE_GPU) .HostMemory("batch_size"), MonolithEmbeddingToLayoutGradOpV3GPU); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutV4") .Device(DEVICE_GPU) .HostMemory("batch_size") .HostMemory("fid_list_emb_row_lenth"), MonolithEmbeddingToLayoutOpV4GPU); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutGradV4") .Device(DEVICE_GPU) .HostMemory("batch_size") .HostMemory("fid_list_emb_row_lenth"), MonolithEmbeddingToLayoutGradOpV4GPU); 
/* GPU kernel registrations for the V5 forward and gradient ops. The forward kernel keeps batch_size, nfl_size, feature_size, fid_size and emb_size in host memory; the gradient kernel keeps only batch_size there. These are followed by the closing of the three namespaces and of the GOOGLE_CUDA guard, then the banner and license header of the next file in the dump. */ REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutV5") .Device(DEVICE_GPU) .HostMemory("batch_size") .HostMemory("nfl_size") .HostMemory("feature_size") .HostMemory("fid_size") .HostMemory("emb_size"), MonolithEmbeddingToLayoutOpV5GPU); REGISTER_KERNEL_BUILDER(Name("MonolithEmbeddingToLayoutGradV5") .Device(DEVICE_GPU) .HostMemory("batch_size"), MonolithEmbeddingToLayoutGradOpV5GPU); } // namespace fused_layout } // namespace monolith_tf } // namespace tensorflow #endif // GOOGLE_CUDA ================================================ FILE: monolith/native_training/runtime/ops/fused_embedding_to_layout.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
#include #include // for tuple #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/util/work_sharder.h" #include "idl/matrix/proto/example.pb.h" #include "monolith/native_training/data/training_instance/cc/pb_variant.h" #include "monolith/native_training/runtime/hash_table/optimizer/avx_utils.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { namespace monolith_tf { namespace fused_layout { using FidList = ::monolith::io::proto::FidList; using Example = ::monolith::io::proto::Example; using ExampleBatch = ::monolith::io::proto::ExampleBatch; using FeatureConfigs = ::monolith::io::proto::FeatureConfigs; using PoolingType = ::monolith::io::proto::PoolingType; using OutType = ::monolith::io::proto::OutType; using SliceConfig = ::monolith::io::proto::SliceConfig; using LayoutShape = ::monolith::io::proto::TensorShape; using OutConfig = ::monolith::io::proto::OutConfig; using Feature = ::monolith::io::proto::Feature; using FeatureListType = ::monolith::io::proto::FeatureListType; using NamedFeatureList = ::monolith::io::proto::NamedFeatureList; using MiniBatch = std::unordered_map>; using Fid2EmbIdxMap = std::unordered_map>; using Fid2EmbMap = std::unordered_map; EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC void ParseFidOffset( const uint64 &fids_offset, int32 *index1, int32 *index2) { *index1 = fids_offset >> 32; *index2 = fids_offset << 32 >> 32; } EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC void ParseNflOffset( const uint32 nfl_offset_encode, bool *is_shared, int *nfl_offset) { *is_shared = nfl_offset_encode >> 31; *nfl_offset = nfl_offset_encode & 0x7fffffff; } EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC void GetFeatureInfo( const int64 nfl_idx, const uint32 *nfl_offset_vec, const int total_nfl_num, const int total_feature_num, bool *is_shared, int 
*nfl_offset, int *feature_num) { ParseNflOffset(*(nfl_offset_vec + nfl_idx), is_shared, nfl_offset); if (nfl_idx < total_nfl_num - 1) { bool is_shared_later; int nfl_offset_later; ParseNflOffset(*(nfl_offset_vec + nfl_idx + 1), &is_shared_later, &nfl_offset_later); *feature_num = nfl_offset_later - *nfl_offset; } else { *feature_num = total_feature_num - *nfl_offset; } } struct PtrWrapper { const float *ptr; uint offset; uint count; }; struct GroupA { GroupA(int dim1, int dim2) : b(dim1 * dim2, true), dim_1(dim1), dim_2(dim2) {} char *Get(int dim1, int dim2) { return &(b.at(dim1 * dim_2 + dim2)); } std::vector b; int dim_1; int dim_2; }; class Layout { public: Layout(const std::string &name, const OutConfig &out_conf) : name_(name), out_config_(out_conf) {} virtual ~Layout() {} virtual PtrWrapper GetSlice(int row_id, const SliceConfig &slice_conf) = 0; const SliceConfig *GetKey(const SliceConfig &slice_conf) { return &slice_conf; // return absl::StrCat(name_, "_", slice_conf.feature_name(), "_", // slice_conf.start(), "_", slice_conf.end()); } const ::google::protobuf::RepeatedPtrField &GetSliceConfig() { return out_config_.slice_configs(); } const OutType out_type() { return out_config_.out_type(); } protected: const std::string &name_; const OutConfig &out_config_; }; class NoneLayout : public Layout { public: // op input TODO NoneLayout(const std::string &name, const OutConfig &out_conf, OpInputList &tensor_list, int &start_idx); // op output NoneLayout(const std::string &name, const OutConfig &out_conf, OpOutputList &tensor_list, int &start_idx); virtual ~NoneLayout() {} PtrWrapper GetSlice(int row_id, const SliceConfig &slice_conf) override; private: absl::flat_hash_map> slice_to_tensor_; }; class DefaultLayout : public Layout { public: DefaultLayout(const std::string &name, const OutConfig &out_conf, OpInputList &tensor_list, int &start_idx); DefaultLayout(const std::string &name, const OutConfig &out_conf, OpOutputList &tensor_list, int &start_idx); virtual 
~DefaultLayout() {} PtrWrapper GetSlice(int row_id, const SliceConfig &slice_conf) override; private: absl::flat_hash_map> slice_to_tensor_; }; class MonolithEmbeddingToLayoutBase : public OpKernel { public: explicit MonolithEmbeddingToLayoutBase(OpKernelConstruction *ctx, int version); private: std::string variant_type_; FeatureConfigs feature_cfgs_; std::vector layout_names_; int max_slice_num_ = 0; std::vector> table_feature_dim_; int ps_num_ = 0; int parallel_flag_ = 0; int version_ = 1; protected: int GetMaxSliceNum() { return max_slice_num_; } const std::string &GetVariantType() { return variant_type_; } const std::vector &GetLayoutNames() { return layout_names_; } const FeatureConfigs &GetFeatureCfgs() { return feature_cfgs_; } int GetPsNum() { return ps_num_; } int GetParallelFlag() { return parallel_flag_; } int GetVersion() { return version_; } const std::vector> &GetFeatureInTableDim() { return table_feature_dim_; } }; #ifdef EIGEN_USE_GPU #define CUSTOM_CHECK(cond) \ if (!(cond)) { \ printf("ERROR %s %s:%d CHECK Fail\n", __FILE__, __func__, __LINE__); \ return; \ } #else #define CUSTOM_CHECK(cond) CHECK(cond) #endif typedef void (*OptimizedSumpoolingFunc)(const float *src, const int dim_num, void *init, float *dst, void *one_mutex, int mean_pool_fid_num); typedef void *(MemCopyFunc)(float *dest, const float *src, std::size_t count); typedef void *(GetMutexFunc)(void *main_params, int32 index1, int32 index2); typedef void *(GetInitFunc)(void *main_params, int32 index1, int32 index2); EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC void *DefaultGetInitFunc( void *main_params, int32 index1, int32 index2) { return main_params; } EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC void GatherEmb( const int feature_idx, const int max_sequence_length, const PoolingType pooling_type, int dims, int slice_conf_start, const PtrWrapper *embeddings_data, int embeddings_data_size, const uint64 *fids_offset_vec, const int total_fid_num, const int32 *feature_offset_vec, const int 
total_feature_num, float *out_ptr, OptimizedSumpoolingFunc opt_sumpool_fn, MemCopyFunc mem_copy_fn, GetMutexFunc get_mutex_func, void *get_mutex_func_main_params, GetInitFunc get_init_func, void *get_init_func_main_params) { CUSTOM_CHECK(feature_idx < total_feature_num); int fid_num = (feature_idx < total_feature_num - 1) ? *(feature_offset_vec + feature_idx + 1) - *(feature_offset_vec + feature_idx) : total_fid_num - *(feature_offset_vec + feature_idx); CUSTOM_CHECK(fid_num >= 0); if (fid_num == 0) return; const auto start_fid_offset_idx = *(feature_offset_vec + feature_idx); int seq_idx = 0; for (int fid_idx = 0; fid_idx < fid_num; fid_idx++) { auto fid_offset_idx = start_fid_offset_idx + fid_idx; int32 index1, index2; CUSTOM_CHECK(fid_offset_idx < total_fid_num); const uint64 fids_offset = *(fids_offset_vec + fid_offset_idx); ParseFidOffset(fids_offset, &index1, &index2); CUSTOM_CHECK(index1 < embeddings_data_size); const auto &ptr_info = *(embeddings_data + index1); int tmp_offset = index2 * ptr_info.offset + slice_conf_start; CUSTOM_CHECK(tmp_offset + dims <= ptr_info.count); const float *src = ptr_info.ptr + tmp_offset; void *one_mutex = nullptr; if (get_mutex_func) { one_mutex = get_mutex_func(get_mutex_func_main_params, index1, index2); } void *init = nullptr; if (get_init_func) { init = get_init_func(get_init_func_main_params, index1, index2); } switch (pooling_type) { case PoolingType::SUM: opt_sumpool_fn(src, dims, init, out_ptr, one_mutex, 0); break; case PoolingType::MEAN: opt_sumpool_fn(src, dims, init, out_ptr, one_mutex, fid_num); break; case PoolingType::FIRSTN: if (seq_idx < max_sequence_length) { mem_copy_fn(out_ptr + seq_idx * dims, src, dims); } seq_idx++; break; default: break; } } } class MonolithEmbeddingToLayoutOp : public MonolithEmbeddingToLayoutBase { public: explicit MonolithEmbeddingToLayoutOp(OpKernelConstruction *ctx, int version = 1); void Compute(OpKernelContext *ctx) override; virtual void TaskRun(const std::vector> &layouts, 
const std::vector &embeddings_data, const uint64 *fids_offset_vec, int total_fid_num, const int32 *feature_offset_vec, int total_feature_num, const uint32 *nfl_offset_vec, int total_nfl_num, int batch_size, const std::vector &each_req_batch_size_offset, const std::vector &each_req_nfl_offset, const std::vector &each_req_feature_offset, const std::vector &each_req_fid_offset, int req_num, OpKernelContext *ctx, OpOutputList *layout_tensor_list); private: int req_sum_ = 0; int process_num_ = 0; }; EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC void ScatterGrad( const int feature_idx, const int max_sequence_length, const PoolingType pooling_type, const float *grad_ptr, int dims, int slice_conf_start, const uint64 *fids_offset_vec, const int total_fid_num, const int32 *feature_offset_vec, const int total_feature_num, const int embeddings_grads_data_num, PtrWrapper *embeddings_grads_data, OptimizedSumpoolingFunc opt_sumpool_fn, GetMutexFunc get_mutex_func, void *get_mutex_func_main_params, GetInitFunc get_init_func, void *get_init_func_main_params) { CUSTOM_CHECK(feature_idx < total_feature_num); int fid_num = (feature_idx < total_feature_num - 1) ? 
*(feature_offset_vec + feature_idx + 1) - *(feature_offset_vec + feature_idx) : total_fid_num - *(feature_offset_vec + feature_idx); CUSTOM_CHECK(fid_num >= 0); if (fid_num == 0) return; const auto start_fid_offset_idx = *(feature_offset_vec + feature_idx); int seq_idx = 0; for (int fid_idx = 0; fid_idx < fid_num; fid_idx++) { int32 fid_offset_idx = start_fid_offset_idx + fid_idx; int32 index1, index2; CUSTOM_CHECK(fid_offset_idx < total_fid_num); const uint64 fids_offset = *(fids_offset_vec + fid_offset_idx); ParseFidOffset(fids_offset, &index1, &index2); CUSTOM_CHECK(index1 < embeddings_grads_data_num); // embeddings_grads_data: fid grad data const auto &ptr_info = *(embeddings_grads_data + index1); int tmp_offset = index2 * ptr_info.offset + slice_conf_start; CUSTOM_CHECK(tmp_offset + dims <= ptr_info.count); void *one_mutex = nullptr; if (get_mutex_func) { one_mutex = get_mutex_func(get_mutex_func_main_params, index1, index2); } void *init_p = nullptr; if (get_init_func) { init_p = get_init_func(get_init_func_main_params, index1, index2); } float *dst = const_cast(ptr_info.ptr) + tmp_offset; switch (pooling_type) { case PoolingType::SUM: { opt_sumpool_fn(grad_ptr, dims, init_p, dst, one_mutex, 0); break; } case PoolingType::MEAN: { opt_sumpool_fn(grad_ptr, dims, init_p, dst, one_mutex, fid_num); break; } case PoolingType::FIRSTN: { if (seq_idx < max_sequence_length) { opt_sumpool_fn(grad_ptr + seq_idx * dims, dims, init_p, dst, one_mutex, false); } seq_idx++; break; } default: break; } } } class MonolithEmbeddingToLayoutGradOp : public MonolithEmbeddingToLayoutBase { public: explicit MonolithEmbeddingToLayoutGradOp(OpKernelConstruction *ctx, int version = 1); void Compute(OpKernelContext *ctx) override; virtual void TaskRun(const std::vector> &layouts, const std::vector> *ufid_grads_info, const uint64 *fids_offset_vec, int total_fid_num, const int32 *feature_offset_vec, int total_feature_num, const uint32 *nfl_offset_vec, int total_nfl_num, int batch_size, 
OpKernelContext *ctx, OpOutputList *embeddings_grad_list, std::vector *embeddings_grads_data, GroupA *init); }; } // namespace fused_layout } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/fused_reorder_by_indices.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { namespace monolith_tf { template class FusedReorderByIndicesOp : public OpKernel { public: explicit FusedReorderByIndicesOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("rank0_empty", &rank0_empty_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("num_of_shards", &num_shards_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("M", &num_tables_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("slot_embedding_dims", &slot_embedding_dims_)); } void Compute(OpKernelContext* ctx) override { // auto start = std::chrono::steady_clock::now(); std::vector> ids_sets(num_tables_); std::vector> ids_for_splits(num_tables_ * num_shards_); int total_fids = 0; for (int m = 0; m < num_tables_; ++m) { auto data = ctx->input(m).vec().data(); auto sz = 
ctx->input(m).NumElements(); total_fids += sz; // Performance critical: reserve enough space so ids_sets won't rehash ids_sets[m].reserve(sz); for (int n = 0; n < num_shards_; n++) // reserve so ids_for_splits will **most likely** not reallocate ids_for_splits[n * num_tables_ + m].reserve((sz + sz / 4) / num_shards_); int dim = slot_embedding_dims_[m]; for (int i = 0; i < sz; ++i) { auto val = data[i]; auto& vec = ids_for_splits[shard_func(val) * num_tables_ + m]; if (ids_sets[m].insert({val, vec.size() * dim}).second) vec.push_back(val); } } Tensor *output, *shard_sizes, *sharded_slot_sizes; OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {num_shards_}, &shard_sizes)); OP_REQUIRES_OK(ctx, ctx->allocate_output(2, {num_shards_ * num_tables_}, &sharded_slot_sizes)); auto shard_sizes_vec = shard_sizes->vec(); auto sharded_slot_sizes_vec = sharded_slot_sizes->vec(); shard_sizes_vec.setZero(); int uniq_id_size = 0; int emb_offset = 0; // compute a column major order emb_offsets for better cache std::vector emb_offsets_cm(num_tables_ * num_shards_); for (int n = 0; n < num_shards_; n++) { for (int m = 0; m < num_tables_; ++m) { auto idx = n * num_tables_ + m; auto sz = ids_for_splits[idx].size(); sharded_slot_sizes_vec(idx) = sz; shard_sizes_vec(n) += sz; uniq_id_size += sz; emb_offsets_cm[m * num_shards_ + n] = emb_offset; emb_offset += sz * slot_embedding_dims_[m]; } } OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {uniq_id_size}, &output)); auto output_ptr = output->vec().data(); for (const auto& vec : ids_for_splits) { std::memcpy(output_ptr, vec.data(), sizeof(T) * vec.size()); output_ptr += vec.size(); } Tensor *emb_offset_sz, *fused_emb_offset; OP_REQUIRES_OK(ctx, ctx->allocate_output(3, {num_tables_}, &emb_offset_sz)); OP_REQUIRES_OK(ctx, ctx->allocate_output(4, {total_fids}, &fused_emb_offset)); auto emb_offset_sz_vec = emb_offset_sz->vec().data(); auto fused_emb_offset_vec = fused_emb_offset->vec().data(); total_fids = 0; for (int m = 0; m < num_tables_; ++m) { auto 
sz = ctx->input(m).NumElements(); auto data = ctx->input(m).vec().data(); emb_offset_sz_vec[m] = sz; for (int i = 0; i < sz; ++i) { auto val = data[i]; fused_emb_offset_vec[total_fids + i] = ids_sets[m][val] + emb_offsets_cm[shard_func(val) + m * num_shards_]; } total_fids += sz; } // std::cout << "fused reorder took " // << (std::chrono::steady_clock::now() - start).count() * 1e-9 // << std::endl; } // TODO(hanzhizhou): consider precompute this or add specialization for // rank0_empty=false and num_shards=power of 2. Currently 10% of the total // time is spent on this function during the computation of this OP inline int shard_func(int64 val) { return val % (num_shards_ - rank0_empty_) + rank0_empty_; } private: bool rank0_empty_; int num_shards_; int num_tables_; std::vector slot_embedding_dims_; }; REGISTER_OP("FusedReorderByIndices") .Input("input: M * T") .Output("output: T") .Output("shard_sizes: int32") .Output("sharded_slot_sizes: int32") .Output("emb_offset_sz: int32") .Output("fused_emb_offset: int32") .Attr("num_of_shards: int") .Attr("slot_embedding_dims: list(int)") .Attr("rank0_empty: bool") .Attr("M: int") .Attr("T: type") .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle input; std::vector dim_handles; dim_handles.push_back(c->UnknownDim()); // The WithRank call validates that the input shape c->input(0) // has a shape with exactly one dimension. TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input)); // The first output is the tensor in shape (?,) // It contains deduped reordered ids. c->set_output(0, c->MakeShape(dim_handles)); // The second output is the tensor in shape (num_of_shards,) // It contains shard sizes. int num_of_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_of_shards", &num_of_shards)); c->set_output(1, c->MakeShape({num_of_shards})); // The third output is the tensor in shape (num_of_shards*M,) // where M is the number of type T input. // It contains sharded (merged) slot sizes. 
int M; TF_RETURN_IF_ERROR(c->GetAttr("M", &M)); c->set_output(2, c->MakeShape({num_of_shards * M})); // The fourth output is an array of offsets to the fifth output c->set_output(3, c->Vector(M)); c->set_output(4, c->Vector(c->UnknownDim())); return Status::OK(); }); #define REGISTER_KERNEL_FUSED_REORDER_BY_INDICES(type) \ REGISTER_KERNEL_BUILDER(Name("FusedReorderByIndices") \ .Device(DEVICE_CPU) \ .TypeConstraint("T"), \ FusedReorderByIndicesOp) REGISTER_KERNEL_FUSED_REORDER_BY_INDICES(int64); #undef REGISTER_KERNEL_FUSED_REORDER_BY_INDICES } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/gen_monolith_ops.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf from monolith.native_training.runtime.ops.gen_monolith_ops_base import * from monolith import utils tf.load_library( utils.get_libops_path( "monolith/native_training/runtime/ops/libtfkernel_monolith_ops_for_load.so" )) ================================================ FILE: monolith/native_training/runtime/ops/gen_seq_mask.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { namespace monolith_tf { // The difference between this reduce sum op and tf.sparse.reduce_sum is that // this supports sparse values which are vectors. template class GenSeqMaskOp : public OpKernel { public: explicit GenSeqMaskOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("max_seq_length", &max_seq_length_)); } void Compute(OpKernelContext* ctx) override { const Tensor& splits = ctx->input(0); int64 batch_size = splits.dim_size(0) - 1; Tensor* mask = nullptr; OP_REQUIRES_OK( ctx, ctx->allocate_output(0, {batch_size, max_seq_length_}, &mask)); std::memset(mask->data(), 0, mask->AllocatedBytes()); auto splits_flat = splits.flat(); auto mask_mat = mask->matrix(); for (int64 i = 0; i < batch_size; ++i) { T size = splits_flat(i + 1) - splits_flat(i); size = size > max_seq_length_ ? 
max_seq_length_ : size; for (size_t j = 0; j < size; ++j) mask_mat(i, j) = 1; } } private: int max_seq_length_; }; REGISTER_OP("GenSeqMask") .Input("splits: T") .Output("mask: T") .Attr("max_seq_length: int") .Attr("T: {int32, int64}") .SetShapeFn([](shape_inference::InferenceContext* ctx) { int max_seq_length; TF_RETURN_IF_ERROR(ctx->GetAttr("max_seq_length", &max_seq_length)); if (ctx->FullyDefined(ctx->input(0))) { tensorflow::shape_inference::DimensionHandle batch_size; tensorflow::shape_inference::DimensionHandle input_dim = ctx->Dim(ctx->input(0), 0); TF_RETURN_IF_ERROR( ctx->Subtract(input_dim, ctx->MakeDim(1), &batch_size)); ctx->set_output( 0, ctx->MakeShape({batch_size, ctx->MakeDim(max_seq_length)})); } else { ctx->set_output(0, ctx->MakeShape({ctx->UnknownDim(), ctx->MakeDim(max_seq_length)})); } return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("GenSeqMask").Device(DEVICE_CPU).TypeConstraint("T"), GenSeqMaskOp); REGISTER_KERNEL_BUILDER( Name("GenSeqMask").Device(DEVICE_CPU).TypeConstraint("T"), GenSeqMaskOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/global_norm.cu.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#if GOOGLE_CUDA #define EIGEN_USE_GPU #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/gpu_device_array.h" #include "tensorflow/core/kernels/gpu_device_array_gpu.h" #include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace monolith { namespace { // Seperate for CUDA Kernel Def template __global__ void globalReduceSum( GpuDeviceArrayStruct input_ptrs_da, GpuDeviceArrayStruct offsets_da, float* out, int size) { const float** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs_da); int* offsets = GetGpuDeviceArrayOnDevice(&offsets_da); // if using shared memory // Ref: // https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc#L73 GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(int), unsigned char, smem); int* smem_offsets = reinterpret_cast(smem); for (int x = threadIdx.x; x < offsets_da.size; x += blockDim.x) { smem_offsets[x] = offsets[x]; } __syncthreads(); offsets = smem_offsets; float thread_sum = 0; int i = 0; GPU_1D_KERNEL_LOOP(idx, size) { // safe offsets read: when idx == size - 1, i+1 == num_inputs while (offsets[i + 1] <= idx) ++i; int j = idx - offsets[i]; float v = ldg(input_ptrs[i] + j); thread_sum += v * v; // l2 } // thread reduce sum to block reduce sum typedef gpuprim::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; float block_sum = BlockReduce(temp_storage).Sum(thread_sum); __syncthreads(); if (threadIdx.x == 0) // block reduce sum to global reduce sum atomicAdd(out, block_sum); } } // namespace template struct GlobalReduceImpl { static void Compute(OpKernelContext* context, const std::vector& input_ptrs, const 
std::vector& input_lens, const std::vector& output_ptrs, float global_norm, float clip_norm); }; typedef Eigen::GpuDevice GPUDevice; template struct SetZeroFunctor { void operator()(const GPUDevice& d, typename TTypes::Scalar out) { To32Bit(out).device(d) = To32Bit(out).constant(T(0)); } }; template <> struct GlobalReduceImpl { static void Compute(OpKernelContext* context, const std::vector& input_ptrs, const std::vector& input_lens, TTypes::Scalar output) { GPUDevice gpu_device = context->eigen_device(); int num_inputs = input_ptrs.size(); GpuDeviceArrayOnHost input_ptrs_da(context, num_inputs); OP_REQUIRES_OK(context, input_ptrs_da.Init()); for (int i = 0; i < num_inputs; ++i) { input_ptrs_da.Set(i, input_ptrs[i]); } OP_REQUIRES_OK(context, input_ptrs_da.Finalize()); int offset = 0; GpuDeviceArrayOnHost offsets(context, num_inputs + 1); int smem_usage = sizeof(int) * (num_inputs + 1); OP_REQUIRES_OK(context, offsets.Init()); for (int i = 0; i < num_inputs; ++i) { offsets.Set(i, offset); offset += input_lens[i]; } offsets.Set(num_inputs, offset); // offset val here is total workload OP_REQUIRES_OK(context, offsets.Finalize()); SetZeroFunctor zero_functor; zero_functor(gpu_device, output); const int thread_per_block = 1024; // const int for globalReduceSum const int physical_thread_count = std::min(gpu_device.getNumGpuMultiProcessors() * gpu_device.maxGpuThreadsPerMultiProcessor(), offset); const int block_count = std::min(DivUp(physical_thread_count, thread_per_block), gpu_device.getNumGpuMultiProcessors()); TF_CHECK_OK(GpuLaunchKernel(globalReduceSum, block_count, thread_per_block, smem_usage, gpu_device.stream(), input_ptrs_da.data(), offsets.data(), output.data(), offset)); } }; template class GlobalReduce : public OpKernel { public: explicit GlobalReduce(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("N", &num_inputs_)); } void Compute(OpKernelContext* context) override { VLOG(1) << "In GlobalReduce Computation"; 
auto num_inputs = context->num_inputs(); std::vector input_ptrs(num_inputs); std::vector input_lens(num_inputs); for (int i = 0; i < num_inputs; ++i) { input_ptrs[i] = context->input(i).flat().data(); input_lens[i] = context->input(i).NumElements(); } Tensor* out; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &out)); GlobalReduceImpl::Compute(context, input_ptrs, input_lens, out->scalar().data()); } private: int num_inputs_; }; REGISTER_OP("GlobalL2Reduce") .Input("grad_list: N * float") .Output("global_norm: float") .Attr("N: int") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("GlobalL2Reduce").Device(DEVICE_GPU), GlobalReduce); } // namespace monolith } // namespace tensorflow #endif // GOOGLE_CUDA ================================================ FILE: monolith/native_training/runtime/ops/gpu_multi_hash_table.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_GPU_MULTI_HASH_TABLE
#define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_GPU_MULTI_HASH_TABLE

// The whole declaration is compiled only when GOOGLE_CUDA is defined.
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU

#include "monolith/native_training/runtime/hash_table/GPUcucohash/cuco_multi_table_ops.cuh.h"
#include "monolith/native_training/runtime/ops/multi_hash_table.h"

namespace tensorflow {
namespace monolith_tf {

// A MultiHashTable that additionally owns a CucoMultiHashTableOp.
class GpuMultiHashTable : public MultiHashTable {
 public:
  // The wrapped GPU table operation object; public, so callers access it
  // directly.
  ::monolith::hash_table::CucoMultiHashTableOp op;

  // Forwards `shared_name` to the MultiHashTable base and constructs `op`
  // from `slot_occ`, `config` (both moved) and `stream`.
  // NOTE(review): the element type of `slot_occ` is not visible in this
  // listing (template arguments appear to have been stripped) - confirm
  // against the original header.
  explicit GpuMultiHashTable(
      absl::string_view shared_name, std::vector slot_occ = {},
      ::monolith::hash_table::GpucucoEmbeddingHashTableConfig config = {},
      cudaStream_t stream = 0)
      : MultiHashTable(shared_name),
        op(std::move(slot_occ), std::move(config), stream) {}
};

}  // namespace monolith_tf
}  // namespace tensorflow

#endif
#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_GPU_MULTI_HASH_TABLE

================================================
FILE: monolith/native_training/runtime/ops/hash_filter_intercept_gradient_op.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/hash_filter/sliding_hash_filter.h"
#include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace monolith_tf {

// Forward pass: identity on the embeddings (input 2). The filter handle and
// ids inputs are not read here.
class HashFilterInterceptGradientOp : public OpKernel {
 public:
  explicit HashFilterInterceptGradientOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    ctx->set_output(0, ctx->input(2));
  }

  // Not read anywhere in this file; kept for compatibility and
  // value-initialized so it is never indeterminate.
  int threshold_ = 0;
};

// Backward pass: zeroes the gradient rows whose id the hash filter reports
// as filtered and copies the remaining rows through unchanged.
//
// Inputs: 0 = filter resource handle, 1 = ids (int64), 2 = grad (float).
// The gradient is viewed as a [num_ids, grad_elements / num_ids] matrix.
class HashFilterInterceptGradientGradientOp : public OpKernel {
 public:
  explicit HashFilterInterceptGradientGradientOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    HashFilterTfBridge* filter = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &filter));
    core::ScopedUnref unref(filter);

    const Tensor& ids = ctx->input(1);
    const Tensor& grad = ctx->input(2);
    const int64 num_ids = ids.NumElements();
    const int64 num_grad = grad.NumElements();

    if (num_ids == 0) {
      // No ids means nothing to filter; avoid dividing by zero below.
      OP_REQUIRES(ctx, num_grad == 0,
                  errors::InvalidArgument(
                      "ids is empty but grad has ", num_grad, " elements."));
      Tensor* filtered_grad = nullptr;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_output(0, grad.shape(), &filtered_grad));
      return;
    }

    // Report a Status instead of CHECK-failing inside shaped<>().
    OP_REQUIRES(ctx, num_grad % num_ids == 0,
                errors::InvalidArgument(
                    "grad has ", num_grad,
                    " elements, which is not a multiple of the number of ids (",
                    num_ids, ")."));

    const TensorShape grad_shape({num_ids, num_grad / num_ids});
    auto ids_flat = ids.flat<int64>();
    auto grad_mat = grad.shaped<float, 2>(grad_shape.dim_sizes());

    Tensor* filtered_grad = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_shape, &filtered_grad));
    auto filtered_grad_mat =
        filtered_grad->shaped<float, 2>(grad_shape.dim_sizes());

    for (int64 i = 0; i < num_ids; ++i) {
      if (filter->ShouldBeFiltered(ids_flat(i))) {
        filtered_grad_mat.chip<0>(i).setZero();
      } else {
        filtered_grad_mat.chip<0>(i) = grad_mat.chip<0>(i);
      }
    }
  }
};

REGISTER_OP("MonolithHashFilterInterceptGradient")
    .Input("filter_handle: resource")
    .Input("ids: int64")
    .Input("embeddings: float")
    .Output("same_embeddings: float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(2));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(
    Name("MonolithHashFilterInterceptGradient").Device(DEVICE_CPU),
    HashFilterInterceptGradientOp);

REGISTER_OP("MonolithHashFilterInterceptGradientGradient")
    .Input("filter_handle: resource")
    .Input("ids: int64")
    .Input("grad: float")
    .Output("filted_grad: float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(2));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(
    Name("MonolithHashFilterInterceptGradientGradient").Device(DEVICE_CPU),
    HashFilterInterceptGradientGradientOp);

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/hash_filter_op.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "monolith/native_training/runtime/hash_filter/dummy_hash_filter.h" #include "monolith/native_training/runtime/hash_filter/probabilistic_filter.h" #include "monolith/native_training/runtime/hash_filter/sliding_hash_filter.h" #include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" namespace tensorflow { namespace monolith_tf { using ::monolith::hash_filter::DummyHashFilter; using ::monolith::hash_filter::ProbabilisticFilter; using ::monolith::hash_filter::SlidingHashFilter; class DummyFilterOp : public ResourceOpKernel { public: explicit DummyFilterOp(OpKernelConstruction* ctx) : ResourceOpKernel(ctx) {} ~DummyFilterOp() override = default; private: Status CreateResource(HashFilterTfBridge** filter_bridge) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { auto filter = std::make_unique(); *filter_bridge = new HashFilterTfBridge(std::move(filter), config_); return Status::OK(); }; monolith::hash_table::SlotOccurrenceThresholdConfig config_; }; class HashFilterOp : public ResourceOpKernel { public: explicit HashFilterOp(OpKernelConstruction* ctx) : ResourceOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("capacity", &capacity_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("split_num", &split_num_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_serialized_)); if (!config_serialized_.empty()) { OP_REQUIRES( ctx, config_.ParseFromString(config_serialized_), errors::InvalidArgument("Unable to parse config. Make sure it " "is serialized version of " "SlotOccurrenceThresholdConfig.")); } } ~HashFilterOp() override {} private: Status CreateResource(HashFilterTfBridge** filter_bridge) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { auto filter = std::make_unique(capacity_, split_num_); // TODO(leqi.zou): We know this is NOT thread safe. But let's keep it as it // is because we may remove HashFilter in the future. 
*filter_bridge = new HashFilterTfBridge(std::move(filter), config_); return Status::OK(); }; int64 capacity_; int split_num_; std::string config_serialized_; monolith::hash_table::SlotOccurrenceThresholdConfig config_; }; class ProbabilisticFilterOp : public ResourceOpKernel { public: explicit ProbabilisticFilterOp(OpKernelConstruction* ctx) : ResourceOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("equal_probability", &equal_probability_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_serialized_)); if (!config_serialized_.empty()) { OP_REQUIRES( ctx, config_.ParseFromString(config_serialized_), errors::InvalidArgument("Unable to parse config. Make sure it " "is serialized version of " "SlotOccurrenceThresholdConfig.")); } } ~ProbabilisticFilterOp() override = default; private: Status CreateResource(HashFilterTfBridge** filter_bridge) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { auto filter = std::make_unique(equal_probability_); *filter_bridge = new HashFilterTfBridge(std::move(filter), config_, true); return Status::OK(); }; bool equal_probability_; std::string config_serialized_; monolith::hash_table::SlotOccurrenceThresholdConfig config_; }; REGISTER_OP("MonolithHashFilter") .Output("handle: resource") .Attr("capacity: int = 300000000") .Attr("split_num: int = 7") // Config contains a string of pb message SlotOccurrenceThresholdConfig. .Attr("config: string = ''") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashFilter").Device(DEVICE_CPU), HashFilterOp); REGISTER_OP("MonolithProbabilisticFilter") .Output("handle: resource") .Attr("equal_probability: bool = false") // Config contains a string of pb message SlotOccurrenceThresholdConfig. 
.Attr("config: string = ''") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("MonolithDummyHashFilter") .Output("handle: resource") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithDummyHashFilter").Device(DEVICE_CPU), DummyFilterOp); REGISTER_KERNEL_BUILDER(Name("MonolithProbabilisticFilter").Device(DEVICE_CPU), ProbabilisticFilterOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/hash_filter_restore_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/threadpool.h" #include "monolith/native_training/runtime/hash_filter/dummy_hash_filter.h" #include "monolith/native_training/runtime/hash_filter/sliding_hash_filter.h" #include "monolith/native_training/runtime/ops/file_utils.h" #include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h" #include "tensorflow/core/lib/io/record_reader.h" namespace tensorflow { namespace monolith_tf { using ::monolith::hash_table::SlidingHashFilterMetaDump; using ::monolith::hash_table::HashFilterSplitMetaDump; using ::monolith::hash_table::HashFilterSplitDataDump; class HashFilterRestoreOp : public AsyncOpKernel { public: explicit HashFilterRestoreOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { HashFilterTfBridge* hash_filter = nullptr; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_filter), done); core::ScopedUnref unref(hash_filter); const Tensor& basename_tensor = ctx->input(1); const std::string basename = basename_tensor.scalar()(); std::vector files; OP_REQUIRES_OK_ASYNC( ctx, ctx->env()->GetMatchingPaths(absl::StrCat(basename, "-*"), &files), done); FileSpec file_spec; OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files, &file_spec), done); OP_REQUIRES_ASYNC(ctx, file_spec.nshards() > 0, errors::NotFound("Unable to find the dump files for: ", name(), " in ", basename), done); ctx->set_output(0, ctx->input(0)); int nsplits = file_spec.nshards(); auto pack = new HashFilterAsyncPack(ctx, hash_filter, basename, std::move(done), nsplits); for (int i = 0; i < nsplits; ++i) { ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, pack, i, nsplits] { WorkerThread(i, nsplits, pack); }); } } private: 
void WorkerThread(int split_idx, int nsplits, HashFilterAsyncPack* p) { p->status[split_idx] = RestoreOneSplit(split_idx, nsplits, p); if (p->finish_num.fetch_add(1) == p->thread_num - 1) { Cleanup(p); } } Status RestoreOneSplit(int split_idx, int nsplits, HashFilterAsyncPack* p) { std::string filename = GetShardedFileName(p->basename, split_idx, nsplits); std::unique_ptr f; TF_RETURN_IF_ERROR(p->ctx->env()->NewRandomAccessFile(filename, &f)); io::RecordReaderOptions opts; opts.buffer_size = 10 * 1024 * 1024; io::SequentialRecordReader reader(f.get(), opts); Status restore_status; auto get_meta_fn = [&reader, &restore_status](HashFilterSplitMetaDump* dump) { Status s = GetMetaRecord(&reader, dump); if (TF_PREDICT_FALSE(!s.ok())) { if (!errors::IsOutOfRange(s)) { restore_status = s; } return false; } return true; }; auto get_data_fn = [&reader, &restore_status](HashFilterSplitDataDump* dump) { Status s = GetDataRecord(&reader, dump); if (TF_PREDICT_FALSE(!s.ok())) { if (!errors::IsOutOfRange(s)) { restore_status = s; } return false; } return true; }; TF_RETURN_IF_ERROR( p->hash_filter->Restore(split_idx, get_meta_fn, get_data_fn)); TF_RETURN_IF_ERROR(restore_status); return Status::OK(); } static Status GetMetaRecord(io::SequentialRecordReader* reader, HashFilterSplitMetaDump* dump) { tstring s; TF_RETURN_IF_ERROR(reader->ReadRecord(&s)); if (!dump->ParseFromArray(s.data(), s.size())) { return errors::FailedPrecondition( "Unable to parse data. Data might be corrupted"); } return Status::OK(); } static Status GetDataRecord(io::SequentialRecordReader* reader, HashFilterSplitDataDump* dump) { tstring s; TF_RETURN_IF_ERROR(reader->ReadRecord(&s)); if (!dump->ParseFromArray(s.data(), s.size())) { return errors::FailedPrecondition( "Unable to parse data. Data might be corrupted"); } return Status::OK(); } // Clean up when all shards are done. void Cleanup(HashFilterAsyncPack* p) { auto done = [p]() { // We want to delete p first and then call done. 
auto done = std::move(p->done); delete p; done(); }; for (int i = 0; i < p->thread_num; ++i) { OP_REQUIRES_OK_ASYNC(p->ctx, p->status[i], done); } done(); } }; REGISTER_OP("MonolithHashFilterRestore") .Input("handle: resource") .Input("basename: string") .Output("output_handle: resource") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashFilterRestore").Device(DEVICE_CPU), HashFilterRestoreOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/hash_filter_save_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "monolith/native_training/runtime/hash_filter/dummy_hash_filter.h" #include "monolith/native_training/runtime/hash_filter/sliding_hash_filter.h" #include "monolith/native_training/runtime/ops/file_utils.h" #include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/threadpool.h" namespace tensorflow { namespace monolith_tf { using ::monolith::hash_table::SlidingHashFilterMetaDump; using ::monolith::hash_table::HashFilterSplitMetaDump; using ::monolith::hash_table::HashFilterSplitDataDump; class HashFilterSaveOp : public AsyncOpKernel { public: explicit HashFilterSaveOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { HashFilterTfBridge* hash_filter = nullptr; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_filter), done); core::ScopedUnref unref(hash_filter); const Tensor& basename_tensor = ctx->input(1); const std::string basename = basename_tensor.scalar()(); const std::string dirname = std::string(io::Dirname(basename)); OP_REQUIRES_OK_ASYNC(ctx, ctx->env()->RecursivelyCreateDir(dirname), done); ctx->set_output(0, ctx->input(0)); int nsplits = hash_filter->GetSplitNum(); if (nsplits == 0) { done(); return; } auto pack = new HashFilterAsyncPack(ctx, hash_filter, basename, std::move(done), nsplits); for (int i = 0; i < nsplits; ++i) { ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, i, nsplits, pack] { WorkerThread(i, nsplits, pack); }); } } private: void WorkerThread(int split_idx, int nsplits, HashFilterAsyncPack* p) { p->status[split_idx] = SaveOneSplit(split_idx, nsplits, p); if (p->finish_num.fetch_add(1) == p->thread_num - 1) { 
Cleanup(p); } } Status SaveOneSplit(int split_idx, int nsplits, HashFilterAsyncPack* p) { std::string filename = GetShardedFileName(p->basename, split_idx, nsplits); std::string tmp_filename = absl::StrCat(filename, "-tmp-", random::New64()); std::unique_ptr f; TF_RETURN_IF_ERROR(p->ctx->env()->NewWritableFile(tmp_filename, &f)); io::RecordWriter writer(f.get()); Status write_status; // In theory, we should stop writing once write failed. // But this requires a lot of refactoring and currently we only do 2 writes. // So we keep it as it is here. auto write_data_fn = [this, &writer, &write_status](HashFilterSplitDataDump dump) { Status s = writer.WriteRecord(dump.SerializeAsString()); if (TF_PREDICT_FALSE(!s.ok())) { write_status.Update(s); } }; auto write_meta_fn = [this, &writer, &write_status](HashFilterSplitMetaDump dump) { Status s = writer.WriteRecord(dump.SerializeAsString()); if (TF_PREDICT_FALSE(!s.ok())) { write_status.Update(s); } }; TF_RETURN_IF_ERROR( p->hash_filter->Save(split_idx, write_meta_fn, write_data_fn)); TF_RETURN_IF_ERROR(write_status); TF_RETURN_IF_ERROR(writer.Close()); TF_RETURN_IF_ERROR(f->Close()); TF_RETURN_IF_ERROR(p->ctx->env()->RenameFile(tmp_filename, filename)); return Status::OK(); } // Clean up when all shards are done. void Cleanup(HashFilterAsyncPack* p) { auto done = [p]() { // We want to delete p first and then call done. 
auto done = std::move(p->done); delete p; done(); }; for (int i = 0; i < p->thread_num; ++i) { OP_REQUIRES_OK_ASYNC(p->ctx, p->status[i], done); } done(); } }; REGISTER_OP("MonolithHashFilterSave") .Input("handle: resource") .Input("basename: string") .Output("output_handle: resource") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashFilterSave").Device(DEVICE_CPU), HashFilterSaveOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/hash_filter_tf_bridge.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h"

#include "monolith/native_training/data/training_instance/cc/reader_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace monolith_tf {

using ::monolith::hash_filter::Filter;
using ::monolith::hash_table::HashFilterSplitDataDump;
using ::monolith::hash_table::HashFilterSplitMetaDump;
using ::monolith::hash_table::SlotOccurrenceThresholdConfig;

// Takes ownership of `filter` and builds the per-slot threshold table: every
// slot in [0, get_max_slot_number()) starts at the config's default
// threshold and is then overridden by the config's per-slot entries.
HashFilterTfBridge::HashFilterTfBridge(
    std::unique_ptr<Filter> filter, const SlotOccurrenceThresholdConfig& config,
    bool is_probabilistic)
    : filter_(std::move(filter)), is_probabilistic_(is_probabilistic) {
  slot_to_occurrence_threshold_.resize(get_max_slot_number(),
                                       config.default_occurrence_threshold());
  const int64 num_slots =
      static_cast<int64>(slot_to_occurrence_threshold_.size());
  for (const auto& slot_occurrence_threshold :
       config.slot_occurrence_thresholds()) {
    const int64 slot = slot_occurrence_threshold.slot();
    // The slot comes from user-supplied config; writing it unchecked would
    // index outside the table.
    if (slot < 0 || slot >= num_slots) {
      LOG(ERROR) << "Ignoring occurrence threshold for out-of-range slot "
                 << slot << "; valid slots are [0, " << num_slots << ").";
      continue;
    }
    slot_to_occurrence_threshold_[slot] =
        slot_occurrence_threshold.occurrence_threshold();
  }
}

// Returns the occurrence threshold of the slot that `fid` belongs to.
// NOTE(review): relies on slot_id_v2() returning a value below
// get_max_slot_number() - confirm in reader_util.h.
int HashFilterTfBridge::GetSlotOccurrenceThreshold(int64_t fid) const {
  return slot_to_occurrence_threshold_[slot_id_v2(fid)];
}

// Dumps split `split_idx` through the given writer callbacks. Any
// std::exception thrown while saving is converted to ResourceExhausted.
Status HashFilterTfBridge::Save(
    int split_idx, std::function<void(HashFilterSplitMetaDump)> write_meta_fn,
    std::function<void(HashFilterSplitDataDump)> write_data_fn) const {
  try {
    filter_->Save(split_idx, std::move(write_meta_fn),
                  std::move(write_data_fn));
    return Status::OK();
  } catch (const std::exception& e) {
    return errors::ResourceExhausted(e.what());
  }
}

// Loads split `split_idx` through the given reader callbacks. Any
// std::exception thrown while restoring is converted to ResourceExhausted.
Status HashFilterTfBridge::Restore(
    int split_idx, std::function<bool(HashFilterSplitMetaDump*)> get_meta_fn,
    std::function<bool(HashFilterSplitDataDump*)> get_data_fn) const {
  try {
    filter_->Restore(split_idx, std::move(get_meta_fn),
                     std::move(get_data_fn));
    return Status::OK();
  } catch (const std::exception& e) {
    return errors::ResourceExhausted(e.what());
  }
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/hash_filter_tf_bridge.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_HASH_FILTER_TF_BRIDGE_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_HASH_FILTER_TF_BRIDGE_H_ #include #include #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/hash_filter/filter.h" #include "monolith/native_training/runtime/ops/file_utils.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_op_kernel.h" namespace tensorflow { namespace monolith_tf { class HashFilterTfBridge : public ResourceBase { public: explicit HashFilterTfBridge( std::unique_ptr filter, const monolith::hash_table::SlotOccurrenceThresholdConfig& config, bool is_probabilistic = false); bool ShouldBeFiltered( int64_t id, int64_t count, monolith::hash_table::EmbeddingHashTableInterface* table) { return filter_->ShouldBeFiltered(id, count, GetSlotOccurrenceThreshold(id), table); } bool ShouldBeFiltered( int64_t id, monolith::hash_table::EmbeddingHashTableInterface* table = nullptr) { return ShouldBeFiltered(id, 1, table); } int GetSplitNum() { return filter_->split_num(); } std::string DebugString() const override { return absl::StrFormat("Filter with capacity: %d", filter_->capacity()); } // For the functor injected, it is ok to throw exceptions. 
Status Save( int split_idx, std::function write_meta_fn, std::function write_data_fn) const; Status Restore( int split_idx, std::function get_meta_fn, std::function get_data_fn) const; const std::vector& GetOccuranceThresholdArray() const { return slot_to_occurrence_threshold_; } bool IsProbabilistic() const { return is_probabilistic_; } private: int GetSlotOccurrenceThreshold(int64_t fid) const; std::unique_ptr filter_; std::vector slot_to_occurrence_threshold_; const bool is_probabilistic_; }; // Carries the data through async process. // It will ref and unref |p_hash_filter| struct HashFilterAsyncPack { HashFilterAsyncPack(OpKernelContext* p_ctx, HashFilterTfBridge* p_hash_filter, std::string p_basename, std::function p_done, int p_thread_num) : ctx(p_ctx), hash_filter(p_hash_filter), basename(p_basename), done(std::move(p_done)), thread_num(p_thread_num), finish_num(0), status(thread_num) { hash_filter->Ref(); } ~HashFilterAsyncPack() { hash_filter->Unref(); } OpKernelContext* ctx; HashFilterTfBridge* hash_filter; std::string basename; std::function done; const int thread_num; std::atomic_int finish_num; std::vector status; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_HASH_FILTER_TF_BRIDGE_H_ ================================================ FILE: monolith/native_training/runtime/ops/hash_table/misc_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table_interface.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace monolith_tf { namespace { class HashTableSizeOp : public OpKernel { public: explicit HashTableSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { EmbeddingHashTableTfBridge* table = nullptr; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &table)); Tensor* size_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &size_tensor)); size_tensor->scalar()() = table->Size(); } }; REGISTER_OP("MonolithHashTableSize") .Input("table_handle: resource") .Output("size: int64") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashTableSize").Device(DEVICE_CPU), HashTableSizeOp); class SaveAsTensorOp : public OpKernel { public: explicit SaveAsTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* c) override { EmbeddingHashTableTfBridge* table = nullptr; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &table)); auto shard = GetShard(c); auto iter = GetIter(c); std::vector dumps; dumps.reserve(shard.limit); auto write_fn = [&dumps](EmbeddingHashTableTfBridge::EntryDump dump) { dumps.push_back(std::move(dump)); return true; }; OP_REQUIRES_OK(c, table->Save(c, shard, write_fn, &iter)); Tensor* new_offset; OP_REQUIRES_OK(c, c->allocate_output(0, {}, &new_offset)); new_offset->scalar()() = iter.offset; Tensor* out_dump; OP_REQUIRES_OK(c, c->allocate_output(1, {(int64)dumps.size()}, &out_dump)); auto out_dump_vec = out_dump->vec(); for (int i = 0; i < dumps.size(); ++i) { out_dump_vec(i) = dumps[i].SerializeAsString(); } } 
private: EmbeddingHashTableTfBridge::DumpShard GetShard(OpKernelContext* c) { EmbeddingHashTableTfBridge::DumpShard shard; shard.idx = c->input(1).scalar()(); shard.total = c->input(2).scalar()(); shard.limit = c->input(3).scalar()(); return shard; } EmbeddingHashTableTfBridge::DumpIterator GetIter(OpKernelContext* c) { EmbeddingHashTableTfBridge::DumpIterator iter; iter.offset = c->input(4).scalar()(); return iter; } private: int num_shard_; int shard_id_; }; REGISTER_OP("MonolithHashTableSaveAsTensor") .Input("table_handle: resource") .Input("shard_idx: int32") .Input("num_shards: int32") .Input("limit: int64") .Input("offset: int64") .Output("new_offset: int64") .Output("entry: string") .SetShapeFn([](shape_inference::InferenceContext* ctx) { ctx->set_output(0, ctx->Scalar()); ctx->set_output(1, ctx->Vector({ctx->UnknownDim()})); return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("MonolithHashTableSaveAsTensor").Device(DEVICE_CPU), SaveAsTensorOp); // In the future, probably we want to convert EntryDump to a set of tensors // before we extract useful information. 
class ExtractSlotFromEntryOp : public OpKernel { public: explicit ExtractSlotFromEntryOp(OpKernelConstruction* c) : OpKernel(c) { c->GetAttr("fid_v2", &fid_v2_); } void Compute(OpKernelContext* c) { const Tensor& dump = c->input(0); auto dump_vec = dump.vec(); const int len = dump.NumElements(); Tensor* slot; OP_REQUIRES_OK(c, c->allocate_output(0, {len}, &slot)); auto slot_vec = slot->vec(); for (int i = 0; i < len; ++i) { const tstring& s = dump_vec(i); monolith::hash_table::EntryDump d; if (!d.ParseFromArray(s.data(), s.size())) { LOG_EVERY_N_SEC(WARNING, 10) << "Fail to parse entry dump."; } if (fid_v2_) { slot_vec(i) = slot_id_v2(d.id()); } else { slot_vec(i) = slot_id_v1(d.id()); } } } private: bool fid_v2_; }; REGISTER_OP("MonolithExtractSlotFromEntry") .Input("entry: string") .Output("slot: int32") .Attr("fid_v2: bool") .SetShapeFn([](shape_inference::InferenceContext* ctx) { ctx->set_output(0, ctx->input(0)); return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithExtractSlotFromEntry").Device(DEVICE_CPU), ExtractSlotFromEntryOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/hash_table_lookup_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "absl/strings/str_format.h"
#include "monolith/native_training/runtime/hash_table/utils.h"
#include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/default/integral_types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace monolith_tf {

using CPUDevice = Eigen::ThreadPoolDevice;

// Looks up the embeddings of |ids| in a single hash table and outputs them as
// a [len(ids), dim_size] matrix. With use_multi_threads the ids are sharded
// over the device's CPU worker threads.
class HashTableLookupOp : public OpKernel {
 public:
  explicit HashTableLookupOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_size", &dim_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_multi_threads", &use_multi_threads_));
  }

  void Compute(OpKernelContext* ctx) override {
    EmbeddingHashTableTfBridge* hash_table = nullptr;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &hash_table));
    core::ScopedUnref unref(hash_table);
    const Tensor& ids = ctx->input(1);
    const int64 len_ids = ids.NumElements();
    OP_REQUIRES(
        ctx, dim_size_ == hash_table->dim_size(),
        errors::InvalidArgument(absl::StrFormat(
            "dim_size should match hash table size. %d vs %d. Node name: %s",
            dim_size_, hash_table->dim_size(), def().name())));
    Tensor* embeddings;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, {len_ids, dim_size_}, &embeddings));
    auto embeddings_mat = embeddings->matrix();
    auto ids_flat = ids.flat();
    if (use_multi_threads_) {
      // Each shard handles ids [begin, end) and writes the matching rows of
      // the output matrix.
      auto lookup = [&](const int64 begin, const int64 end) {
        int64_t hit_fid_count = 0;
        hash_table->BatchLookup(
            ctx, (end - begin), const_cast(ids_flat.data() + begin),
            embeddings_mat.data() + begin * dim_size_, &hit_fid_count);
      };
      // TODO(zhangbiao.david, youlong.cheng): tweak this number for
      // optimization.
      const int64 kCostPerUnit = 8 * dim_size_;
      const DeviceBase::CpuWorkerThreads& worker_threads =
          *ctx->device()->tensorflow_cpu_worker_threads();
      Shard(worker_threads.num_threads, worker_threads.workers, len_ids,
            kCostPerUnit, lookup);
    } else {
      int64_t hit_fid_count = 0;
      hash_table->BatchLookup(ctx, len_ids, const_cast(ids_flat.data()),
                              embeddings_mat.data(), &hit_fid_count);
    }
  }

  int64 dim_size_;
  bool use_multi_threads_;
};

// Looks up the entries of |ids| in a hash table and outputs each entry
// serialized as a string.
class HashTableLookupEntryOp : public OpKernel {
 public:
  explicit HashTableLookupEntryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    EmbeddingHashTableTfBridge* hash_table = nullptr;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &hash_table));
    core::ScopedUnref unref(hash_table);
    const Tensor& ids = ctx->input(1);
    const int64 len_ids = ids.NumElements();
    auto ids_flat = ids.flat();
    std::vector entries(len_ids);
    // &entries[0] is only valid for a non-empty vector.
    if (entries.size() > 0) {
      hash_table->BatchLookupEntry(
          ctx, len_ids, const_cast(ids_flat.data()), &entries[0]);
    }
    Tensor* entry_strs;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {len_ids}, &entry_strs));
    auto entry_str_vec = entry_strs->vec();
    for (int64 i = 0; i < len_ids; ++i) {
      entry_str_vec(i) = entries[i].SerializeAsString();
    }
  }
};

// For every (id_indices[i], id_values[i]) pair, outputs the id and the row of
// |input_grads| selected by id_indices[i, 0].
class HashTableLookupGradientOp : public OpKernel {
 public:
  explicit HashTableLookupGradientOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& id_indices = ctx->input(0);
    const Tensor& id_values = ctx->input(1);
    const Tensor& input_grads = ctx->input(2);
    OP_REQUIRES(
        ctx, id_indices.dim_size(0) == id_values.dim_size(0),
        errors::InvalidArgument(
            "id_indices's first dim and id_values dim should be same. Got ",
            id_indices.dim_size(0), " v.s. ", id_values.dim_size(0)));
    const int64 batch_size = id_indices.dim_size(0);
    const int64 embedding_dim = input_grads.dim_size(1);
    const int64 num_grad_rows = input_grads.dim_size(0);

    // The allocation status must be checked: on failure the output pointers
    // are not usable.
    Tensor* output_ids = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {batch_size}, &output_ids));
    Tensor* output_grads = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {batch_size, embedding_dim},
                                             &output_grads));

    auto id_indices_mat = id_indices.matrix();
    auto id_values_vec = id_values.vec();
    auto input_grads_mat = input_grads.matrix();
    auto output_ids_vec = output_ids->vec();
    auto output_grads_mat = output_grads->matrix();
    for (int64 i = 0; i < batch_size; ++i) {
      const int64 batch = id_indices_mat(i, 0);
      // |batch| comes from an input tensor; reject rows that do not exist in
      // |input_grads| instead of reading out of bounds.
      OP_REQUIRES(ctx, batch >= 0 && batch < num_grad_rows,
                  errors::InvalidArgument("id_indices[", i, ", 0] = ", batch,
                                          " is out of range [0, ",
                                          num_grad_rows, ")"));
      const int64 id = id_values_vec(i);
      output_ids_vec(i) = id;
      for (int64 j = 0; j < embedding_dim; ++j) {
        output_grads_mat(i, j) = input_grads_mat(batch, j);
      }
    }
  }
};

// Looks up ids in N hash tables at once. The ids are laid out per
// (shard, table) pair, with the number of ids of each pair given by the
// fused_slot_size input.
template class HashTableFusedLookupOp : public OpKernel {
 public:
  explicit HashTableFusedLookupOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &num_tables_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_of_shards", &num_shards_));
  }

  void ComputeH(OpKernelContext* ctx);

  void Compute(OpKernelContext* ctx) override { ComputeH(ctx); }

 private:
  int num_shards_;
  int num_tables_;
};

template <>
void HashTableFusedLookupOp::ComputeH(OpKernelContext* ctx) {
  auto ids_flat = ctx->input(num_tables_ + 0).flat().data();
  auto slot_size_vec = ctx->input(num_tables_ + 1).vec().data();
  auto slot_size_cnt = num_tables_ * num_shards_;

  Tensor *embeddings_ts, *emb_splits_ts, *key_offsets_ts, *emb_offsets_ts;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {num_shards_}, &emb_splits_ts));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(2, {slot_size_cnt + 1},
                                           &key_offsets_ts));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(3, {slot_size_cnt + 1},
                                           &emb_offsets_ts));
  ctx->set_output(4, ctx->input(num_tables_ + 0));
  auto key_offsets = key_offsets_ts->vec().data();
  auto emb_offsets = emb_offsets_ts->vec().data();
  auto emb_splits = emb_splits_ts->vec().data();

  // Keep one reference per table for the whole call. Dropping the reference
  // right after the lookup (as a loop-local ScopedUnref would) leaves the
  // pointers used by the sharded lookup below without an owner.
  std::vector<core::RefCountPtr<EmbeddingHashTableTfBridge>> hash_tables(
      num_tables_);
  std::vector hash_table_dims(num_tables_, 0);
  for (int table_id = 0; table_id < num_tables_; table_id++) {
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, table_id),
                                       &hash_tables[table_id]));
    hash_table_dims[table_id] = hash_tables[table_id]->dim_size();
  }

  int total_keys, total_embs;
  std::tie(total_keys, total_embs) = monolith::hash_table::ComputeFusedOffsets(
      slot_size_vec, hash_table_dims.data(), num_tables_, num_shards_,
      key_offsets, emb_offsets, nullptr, emb_splits);
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {total_embs}, &embeddings_ts));
  auto embeddings = embeddings_ts->vec().data();

  // Work is sharded over shard ids; each work item walks all tables of its
  // shards.
  auto lookup = [&](const int begin, const int end) {
    for (int shard_id = begin; shard_id < end; shard_id++) {
      for (int table_id = 0; table_id < num_tables_; table_id++) {
        int curr_idx = shard_id * num_tables_ + table_id;
        int64_t hit_fid_count = 0;
        hash_tables[table_id]->BatchLookup(
            ctx, slot_size_vec[curr_idx],
            const_cast(ids_flat) + key_offsets[curr_idx],
            embeddings + emb_offsets[curr_idx], &hit_fid_count);
      }
    }
  };
  // TODO(zouxuan): tweak this number for optimization.
  const int64 kCostPerUnit = 1000000;
  const DeviceBase::CpuWorkerThreads& worker_threads =
      *ctx->device()->tensorflow_cpu_worker_threads();
  Shard(worker_threads.num_threads, worker_threads.workers, num_shards_,
        kCostPerUnit, lookup);
}

REGISTER_OP("MonolithHashTableLookup")
    .Input("table_handle: resource")
    .Input("ids: int64")
    .Output("embeddings: float32")
    .Attr("dim_size: int")
    .Attr("use_multi_threads: bool = false")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int64 dim_size;
      TF_RETURN_IF_ERROR(c->GetAttr("dim_size", &dim_size));
      shape_inference::DimensionHandle len_ids = c->Dim(c->input(1), 0);
      c->set_output(0, c->Matrix(len_ids, dim_size));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(Name("MonolithHashTableLookup").Device(DEVICE_CPU),
                        HashTableLookupOp);

REGISTER_OP("MonolithHashTableLookupEntry")
    .Input("table_handle: resource")
    .Input("ids: int64")
    .Output("entry_str: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::DimensionHandle len_ids = c->Dim(c->input(1), 0);
      c->set_output(0, c->Vector(len_ids));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(Name("MonolithHashTableLookupEntry").Device(DEVICE_CPU),
                        HashTableLookupEntryOp);

REGISTER_OP("MonolithHashTableLookupGradient")
    .Input("id_indices: int64")
    .Input("id_values: int64")
    .Input("input_grads : float")
    .Output("ids: int64")
    .Output("output_grads: float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0);
      shape_inference::DimensionHandle embedding_size = c->Dim(c->input(2), 1);
      c->set_output(0, c->MakeShape({batch_size}));
      c->set_output(1, c->MakeShape({batch_size, embedding_size}));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(
    Name("MonolithHashTableLookupGradient").Device(DEVICE_CPU),
    HashTableLookupGradientOp);

REGISTER_OP("MonolithHashTableFusedLookup")
    .Input("table_handles: N * resource")
    .Input("ids: int64")
    .Input("fused_slot_size: int32")
    .Input("req_time: int64")
    .Output("embeddings: float32")
    .Output("embedding_splits: int32")
    .Output("id_offsets: int32")
    .Output("embedding_offsets: int32")
    .Output("indices: int64")
    .Attr("N: int")
    .Attr("num_of_shards: int")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int num_tables, num_shards;
      TF_RETURN_IF_ERROR(c->GetAttr("num_of_shards", &num_shards));
      TF_RETURN_IF_ERROR(c->GetAttr("N", &num_tables));
      c->set_output(0, c->Vector(c->UnknownDim()));
      c->set_output(1, c->Vector(num_shards));
      c->set_output(2, c->Vector(num_tables * num_shards + 1));
      c->set_output(3, c->Vector(num_tables * num_shards + 1));
      c->set_output(4, c->input(num_tables));
      // fused_slot_size must be a vector with one entry per (shard, table).
      auto shape = c->input(num_tables + 1);
      TF_RETURN_IF_ERROR(c->WithRank(shape, 1, &shape));
      auto dim = c->Dim(shape, 0);
      TF_RETURN_IF_ERROR(c->WithValue(dim, num_tables * num_shards, &dim));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(Name("MonolithHashTableFusedLookup").Device(DEVICE_CPU),
                        HashTableFusedLookupOp);

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/hash_table_op.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h"
#include "monolith/native_training/runtime/ops/parameter_sync_tf_bridge.h"
#include "monolith/native_training/runtime/parameter_sync/parameter_sync_client.h"
#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/default/integral_types.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace monolith_tf {

using ::monolith::hash_table::EmbeddingHashTableConfig;
using ::monolith::hash_table::GpuExtraArgs;
// using ::monolith::hopscotch::HopscotchHashSet;
using ::monolith::parameter_sync::ParameterSyncClient;

using CPUDevice = Eigen::ThreadPoolDevice;

// Creates (once) the embedding hash table resource described by the "config"
// attribute and outputs a handle to it. The kernel keeps a reference on both
// the hash table and the hash filter it was built with until it is destroyed.
template class HashTableOp : public OpKernel {
 public:
  explicit HashTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    std::string config_serialized;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_serialized));
    OP_REQUIRES(ctx, config_.ParseFromString(config_serialized),
                errors::InvalidArgument("Unable to parse config. Make sure it "
                                        "is serialized version of "
                                        "EmbeddingHashTableConfig."));
  }

  ~HashTableOp() override {
    if (hash_table_ != nullptr) {
      if (cinfo_.resource_is_private_to_kernel()) {
        cinfo_.resource_manager()
            ->Delete(cinfo_.container(), cinfo_.name())
            .IgnoreError();
      }
      // here we use different way than ResourceKernelOp. Otherwise,
      // we got crash and I believe it is our compiler's problem.
      hash_table_->Unref();
    }
    if (hash_filter_ != nullptr) {
      hash_filter_->Unref();
    }
  }

  void ComputeH(OpKernelContext* ctx);

  void Compute(OpKernelContext* ctx) override {
    absl::MutexLock l(&mu_);
    if (hash_filter_ == nullptr) {
      OP_REQUIRES_OK(
          ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_filter_));
    }
    if (hash_table_ == nullptr) {
      ComputeH(ctx);
      // OP_REQUIRES_OK inside ComputeH only returns from ComputeH. Stop here
      // when it failed, otherwise a handle would be produced from a |cinfo_|
      // that may not have been initialised.
      if (!ctx->status().ok()) {
        return;
      }
    }
    OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
                            ctx, 0, cinfo_.container(), cinfo_.name(),
                            TypeIndex::Make()));
  }

 private:
  EmbeddingHashTableConfig config_;
  absl::Mutex mu_;
  EmbeddingHashTableTfBridge* hash_table_ ABSL_GUARDED_BY(mu_) = nullptr;
  HashFilterTfBridge* hash_filter_ ABSL_GUARDED_BY(mu_) = nullptr;
  ContainerInfo cinfo_ ABSL_GUARDED_BY(mu_);
};

// Looks up or creates the hash table resource. A newly created table is
// attached to the sync client from input 1 unless that client reports itself
// as a dummy client.
template <>
void HashTableOp::ComputeH(OpKernelContext* ctx) {
  ResourceMgr* rmgr = ctx->resource_manager();
  OP_REQUIRES_OK(ctx, cinfo_.Init(rmgr, def()));
  core::RefCountPtr sync_client = nullptr;
  OP_REQUIRES_OK(ctx,
                 LookupResource(ctx, HandleFromInput(ctx, 1), &sync_client));
  auto sync_client_ptr = sync_client.get();
  auto creator = [&, this](EmbeddingHashTableTfBridge** out_hash_table) {
    TF_RETURN_IF_ERROR(EmbeddingHashTableTfBridge::New(
        config_, hash_filter_, out_hash_table, cinfo_.name()));
    if (sync_client_ptr->IsDummySyncClient()) {
      LOG(INFO) << absl::StrFormat(
          "Hash table %s will not be attached to the sync client",
          cinfo_.name());
    } else {
      // TODO(zhangbiao.david) Make hopscotch hash set configurable
      auto* touched_key_set = sync_client_ptr->GetTouchedKeySet();
      (*out_hash_table)->SetHopscotchHashSet(touched_key_set);
      sync_client_ptr->AddHashTableResource(cinfo_.name(), *out_hash_table);
      LOG(INFO) << absl::StrFormat(
          "Hash table %s will be attached to the sync client", cinfo_.name());
    }
    return Status::OK();
  };
  OP_REQUIRES_OK(ctx, rmgr->LookupOrCreate(
                          cinfo_.container(), cinfo_.name(), &hash_table_,
                          creator));
}

REGISTER_OP("MonolithHashTable")
    .Input("filter_handle: resource")
    .Input("sync_client_handle: resource")
    .Output("handle: resource")
    .Attr("config: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithHashTable").Device(DEVICE_CPU),
                        HashTableOp);

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/hash_table_restore_op.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include
#include
#include
#include

#include "absl/strings/str_cat.h"
#include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h"
#include "monolith/native_training/runtime/ops/file_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/threadpool.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

// Carries the data through async process.
// It will ref and unref |p_hash_table|.
struct AsyncPack { AsyncPack(OpKernelContext* p_ctx, EmbeddingHashTableTfBridge* p_hash_table, std::string p_basename, std::function p_done, int p_thread_num) : ctx(p_ctx), basename(std::move(p_basename)), record_count(0), hash_table(p_hash_table), done(std::move(p_done)), thread_num(p_thread_num), finish_num(0), status(thread_num) { hash_table->Ref(); } ~AsyncPack() { hash_table->Unref(); } OpKernelContext* ctx; std::string basename; std::atomic_long record_count; EmbeddingHashTableTfBridge* hash_table; std::function done; const int thread_num; std::atomic_int finish_num; std::vector status; }; } // namespace class HashTableRestoreOp : public AsyncOpKernel { public: explicit HashTableRestoreOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { EmbeddingHashTableTfBridge* hash_table = nullptr; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_table), done); core::ScopedUnref unref(hash_table); const Tensor& basename_tensor = ctx->input(1); const std::string basename = basename_tensor.scalar()(); std::vector files; OP_REQUIRES_OK_ASYNC( ctx, ctx->env()->GetMatchingPaths(absl::StrCat(basename, "-*"), &files), done); FileSpec file_spec; OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files, &file_spec), done); OP_REQUIRES_ASYNC(ctx, file_spec.nshards() > 0, errors::NotFound("Unable to find the dump files for: ", name(), " in ", basename), done); ctx->set_output(0, ctx->input(0)); hash_table->Clear(); int nshards = files.size(); auto pack = new AsyncPack(ctx, hash_table, basename, std::move(done), nshards); for (int i = 0; i < nshards; ++i) { ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, pack, i, nshards] { WorkerThread({i, nshards}, pack); }); } } private: void WorkerThread(EmbeddingHashTableTfBridge::DumpShard shard, AsyncPack* p) { p->status[shard.idx] = RestoreOneShard(shard, p); if (p->finish_num.fetch_add(1) == p->thread_num - 1) 
{ auto summary = p->hash_table->Summary(); auto basename = tensorflow::io::Basename(p->basename); LOG(INFO) << absl::StrFormat( "Hash table: %s, summary: %s, restore read %ld records, skip %ld " "zero embeddings", basename, summary, p->record_count, p->record_count - p->hash_table->Size()); Cleanup(p); } } Status RestoreOneShard(EmbeddingHashTableTfBridge::DumpShard shard, AsyncPack* p) { std::string filename = GetShardedFileName(p->basename, shard.idx, shard.total); std::unique_ptr f; TF_RETURN_IF_ERROR(p->ctx->env()->NewRandomAccessFile(filename, &f)); io::RecordReaderOptions opts; opts.buffer_size = 10 * 1024 * 1024; io::SequentialRecordReader reader(f.get(), opts); Status restore_status; auto get_fn = [&reader, &restore_status, &p]( EmbeddingHashTableTfBridge::EntryDump* dump, int64_t* max_update_ts) { Status s = GetRecord(&reader, dump); if (TF_PREDICT_FALSE(!s.ok())) { if (!errors::IsOutOfRange(s)) { restore_status = s; } return false; } p->record_count.fetch_add(1); if (!dump->has_last_update_ts_sec()) { dump->set_last_update_ts_sec(0); } *max_update_ts = std::max(dump->last_update_ts_sec(), *max_update_ts); return true; }; TF_RETURN_IF_ERROR(p->hash_table->Restore(p->ctx, shard, get_fn)); TF_RETURN_IF_ERROR(restore_status); return Status::OK(); } static Status GetRecord(io::SequentialRecordReader* reader, EmbeddingHashTableTfBridge::EntryDump* dump) { tstring s; TF_RETURN_IF_ERROR(reader->ReadRecord(&s)); if (!dump->ParseFromArray(s.data(), s.size())) { return errors::FailedPrecondition( "Unable to parse data. Data might be corrupted"); } return Status::OK(); } // Clean up when all shards are done. void Cleanup(AsyncPack* p) { auto done = [p]() { // We want to delete p first and then call done. 
auto done = std::move(p->done); delete p; done(); }; for (int i = 0; i < p->thread_num; ++i) { OP_REQUIRES_OK_ASYNC(p->ctx, p->status[i], done); } done(); } }; REGISTER_OP("MonolithHashTableRestore") .Input("handle: resource") .Input("basename: string") .Output("output_handle: resource") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashTableRestore").Device(DEVICE_CPU), HashTableRestoreOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/hash_table_save_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include "absl/random/random.h" #include "absl/strings/str_cat.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "monolith/native_training/runtime/ops/file_utils.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/threadpool.h" namespace tensorflow { namespace monolith_tf { namespace { // Carries the data through async process. // It will ref and unref |p_hash_table|. struct AsyncPack { AsyncPack(OpKernelContext* p_ctx, EmbeddingHashTableTfBridge* p_hash_table, std::string p_basename, std::unique_ptr p_lock_ctx, std::function p_done, int p_thread_num) : ctx(p_ctx), basename(p_basename), hash_table(p_hash_table), lock_ctx(std::move(p_lock_ctx)), done(std::move(p_done)), thread_num(p_thread_num), finish_num(0), status(thread_num) { hash_table->Ref(); } ~AsyncPack() { hash_table->Unref(); } OpKernelContext* ctx; std::string basename; EmbeddingHashTableTfBridge* hash_table; std::unique_ptr lock_ctx; std::function done; const int thread_num; std::atomic_int finish_num; std::vector status; }; const int kAutoTune = -1; } // namespace class HashTableSaveOp : public AsyncOpKernel { public: explicit HashTableSaveOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("nshards", &nshards_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("random_sleep_ms", &random_sleep_ms_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("slot_expire_time_config", &slot_expire_time_config_serialized_)); if (!slot_expire_time_config_serialized_.empty()) { OP_REQUIRES( ctx, slot_expire_time_config_.ParseFromString( slot_expire_time_config_serialized_), 
errors::InvalidArgument("Unable to parse config. Make sure it " "is serialized version of " "SlotExpireTimeConfig.")); } slot_to_expire_time_.resize(get_max_slot_number(), slot_expire_time_config_.default_expire_time()); for (const auto& slot_expire_time : slot_expire_time_config_.slot_expire_times()) { slot_to_expire_time_[slot_expire_time.slot()] = slot_expire_time.expire_time(); } } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { EmbeddingHashTableTfBridge* hash_table = nullptr; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_table), done); core::ScopedUnref unref(hash_table); const Tensor& basename_tensor = ctx->input(1); const std::string basename = basename_tensor.scalar()(); const std::string dirname = std::string(io::Dirname(basename)); OP_REQUIRES_OK_ASYNC(ctx, ctx->env()->RecursivelyCreateDir(dirname), done); ctx->set_output(0, ctx->input(0)); int real_nshards = PickNshards(hash_table); std::unique_ptr lock_ctx; OP_REQUIRES_OK_ASYNC(ctx, hash_table->LockAll(&lock_ctx), done); auto pack = new AsyncPack(ctx, hash_table, basename, std::move(lock_ctx), std::move(done), real_nshards); for (int i = 0; i < real_nshards; ++i) { // !important: When using GPU, tensorflow_cpu_worker_threads' are bound to // device 0 regardless of the correct device id of the current process. // That means one has to use the CUDA_VISIBLE_DEVICES // environment vairable to make sure that only one GPU is visible each // process. 
Otherwise something like this will happen: some device memory // is allocated on device 3, but accessing that memory in this thread as // device 0 will cause illegal memory errors ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, pack, i, real_nshards] { WorkerThread({i, real_nshards}, pack); }); } } int PickNshards(EmbeddingHashTableTfBridge* table) { if (nshards_ >= 0) return nshards_; const int64 size = table->Size(); const int64 kBaseline = 1000000ll; return std::min(4LL, std::max(1LL, size / kBaseline)); } private: void WorkerThread(EmbeddingHashTableTfBridge::DumpShard shard, AsyncPack* p) { absl::BitGen bitgen; absl::SleepFor( absl::Milliseconds(absl::Uniform(bitgen, 0, random_sleep_ms_))); p->status[shard.idx] = SaveOneShard(shard, p); if (p->finish_num.fetch_add(1) == p->thread_num - 1) { Cleanup(p); } } Status SaveOneShard(EmbeddingHashTableTfBridge::DumpShard shard, AsyncPack* p) { std::string filename = GetShardedFileName(p->basename, shard.idx, shard.total); std::string tmp_filename = absl::StrCat(filename, "-tmp-", random::New64()); std::unique_ptr f; TF_RETURN_IF_ERROR(p->ctx->env()->NewWritableFile(tmp_filename, &f)); io::RecordWriter writer(f.get()); Status write_status; int64_t max_update_ts_sec = p->hash_table->max_update_ts_sec(); auto write_fn = [this, &max_update_ts_sec, &writer, &write_status]( EmbeddingHashTableTfBridge::EntryDump dump) { int64_t slot_id = slot_id_v2(dump.id()); // Elements of slot_to_expire_time_ are in days. // last_update_ts_sec is seconds since the Epoch. if (max_update_ts_sec - dump.last_update_ts_sec() >= slot_to_expire_time_[slot_id] * 24 * 3600) { return true; } Status s = writer.WriteRecord(dump.SerializeAsString()); if (TF_PREDICT_FALSE(!s.ok())) { // OK to throw here since it will be catched. 
write_status = s; return false; } return true; }; EmbeddingHashTableTfBridge::DumpIterator iter; TF_RETURN_IF_ERROR(p->hash_table->Save(p->ctx, shard, write_fn, &iter)); TF_RETURN_IF_ERROR(writer.Close()); TF_RETURN_IF_ERROR(f->Close()); TF_RETURN_IF_ERROR(p->ctx->env()->RenameFile(tmp_filename, filename)); return Status::OK(); } // Clean up when all shards are done. void Cleanup(AsyncPack* p) { auto done = [p]() { // We want to delete p first and then call done. auto done = std::move(p->done); delete p; done(); }; for (int i = 0; i < p->thread_num; ++i) { OP_REQUIRES_OK_ASYNC(p->ctx, p->status[i], done); } done(); } int nshards_; int64 random_sleep_ms_; std::string slot_expire_time_config_serialized_; monolith::hash_table::SlotExpireTimeConfig slot_expire_time_config_; std::vector slot_to_expire_time_; }; REGISTER_OP("MonolithHashTableSave") .Input("handle: resource") .Input("basename: string") .Output("output_handle: resource") .Attr("nshards: int=-1") .Attr("random_sleep_ms: int=0") .Attr("slot_expire_time_config: string = ''") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashTableSave").Device(DEVICE_CPU), HashTableSaveOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/hash_table_update_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/concurrency/queue.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace monolith_tf { using monolith::concurrency::Queue; using CPUDevice = Eigen::ThreadPoolDevice; class HashTableAssignOp : public OpKernel { public: explicit HashTableAssignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { EmbeddingHashTableTfBridge* hash_table = nullptr; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_table)); core::ScopedUnref unref(hash_table); const Tensor& id_values = ctx->input(1); const Tensor& id_updates = ctx->input(2); const Tensor& update_time_tensor = ctx->input(3); int64_t update_time = update_time_tensor.scalar()(); auto id_values_vec = id_values.vec(); const int num_updates = id_values_vec.dimension(0); OP_REQUIRES_OK( ctx, hash_table->Assign( ctx, num_updates, static_cast(id_values.data()), static_cast(id_updates.data()), update_time)); ctx->set_output(0, ctx->input(0)); } }; class HashTableAssignAddOp : public OpKernel { public: explicit HashTableAssignAddOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { EmbeddingHashTableTfBridge* hash_table = nullptr; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_table)); core::ScopedUnref unref(hash_table); const Tensor& id_values = ctx->input(1); const Tensor& id_updates = ctx->input(2); const Tensor& update_time_tensor = ctx->input(3); int64_t update_time = update_time_tensor.scalar()(); auto id_values_vec = id_values.vec(); 
const int64 num_updates = id_values_vec.dimension(0); for (int64 i = 0; i < num_updates; ++i) { OP_REQUIRES_OK( ctx, hash_table->AssignAdd(ctx, id_values_vec(i), id_updates.SubSlice(i), update_time)); } ctx->set_output(0, ctx->input(0)); } }; class HashTableOptimizeOp : public OpKernel { public: explicit HashTableOptimizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_multi_threads", &use_multi_threads_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_dedup", &enable_dedup_)); int queue_size = 0; OP_REQUIRES_OK(ctx, ctx->GetAttr("queue_size", &queue_size)); CHECK_GE(queue_size, 0); queue_ = queue_size > 0 ? std::make_unique< Queue>>>( queue_size) : nullptr; } void Compute(OpKernelContext* ctx) override { EmbeddingHashTableTfBridge* hash_table = nullptr; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &hash_table)); core::ScopedUnref unref(hash_table); const Tensor& id_values = ctx->input(1); const Tensor& id_updates = ctx->input(2); const Tensor& learning_rate_tensor = ctx->input(3); const Tensor& update_time_tensor = ctx->input(4); const Tensor& global_step = ctx->input(5); int64_t update_time = update_time_tensor.scalar()(); size_t num_updates = id_values.NumElements(); auto ids_flat = id_values.flat(); auto* ids = const_cast(ids_flat.data()); absl::Span learning_rate_values = absl::MakeSpan(static_cast(learning_rate_tensor.data()), learning_rate_tensor.NumElements()); int64_t global_step_value = global_step.scalar()(); if (use_multi_threads_) { auto dim_size = hash_table->dim_size(); auto update = [&](const int64 begin, const int64 end) { OP_REQUIRES_OK( ctx, hash_table->BatchOptimize( ctx, (end - begin), (ids + begin), static_cast(id_updates.data()) + begin * dim_size, learning_rate_values, update_time, enable_dedup_, global_step_value)); }; // TODO(zhangbiao.david, youlong.cheng): tweak this number for // optimization. 
const int64 kCostPerUnit = 20 * dim_size; const DeviceBase::CpuWorkerThreads& worker_threads = *ctx->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, num_updates, kCostPerUnit, update); } else { std::chrono::milliseconds timeout(1); // Optimize using this thread if operation timing out if (queue_ && queue_->try_push({id_values, id_updates, learning_rate_values}, timeout)) { auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers; thread_pool->Schedule( [this, ctx, update_time, hash_table, global_step_value]() { auto ids_and_grads = queue_->pop(); const auto& id_values = std::get<0>(ids_and_grads); const auto& tensor = std::get<1>(ids_and_grads); auto& learning_rate_values = std::get<2>(ids_and_grads); size_t num_updates = id_values.NumElements(); auto ids_flat = id_values.flat(); hash_table->BatchOptimize( ctx, num_updates, const_cast(ids_flat.data()), static_cast(tensor.data()), learning_rate_values, update_time, enable_dedup_, global_step_value); }); } else { OP_REQUIRES_OK(ctx, hash_table->BatchOptimize( ctx, num_updates, ids, static_cast(id_updates.data()), learning_rate_values, update_time, enable_dedup_, global_step_value)); } } ctx->set_output(0, ctx->input(0)); } private: bool use_multi_threads_; bool enable_dedup_; mutable std::unique_ptr< Queue>>> queue_; }; template class HashTableFusedOptimizeOp : public OpKernel { public: explicit HashTableFusedOptimizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &num_tables_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("num_of_shards", &num_shards_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_grad_accumulation", &enable_grad_accumulation_)); } void ComputeH(OpKernelContext* ctx); void Compute(OpKernelContext* ctx) override { ComputeH(ctx); for (int table_id = 0; table_id < num_tables_; table_id++) { ctx->set_output(table_id, ctx->input(table_id)); } } private: bool enable_grad_accumulation_; int num_tables_; int 
num_shards_; }; template <> void HashTableFusedOptimizeOp::ComputeH(OpKernelContext* ctx) { auto ids = ctx->input(num_tables_).vec().data(); auto num_ids = ctx->input(num_tables_).NumElements(); auto indices = ctx->input(num_tables_ + 1).vec().data(); auto slot_size_vec = ctx->input(num_tables_ + 2).vec().data(); auto id_grads = ctx->input(num_tables_ + 3).vec().data(); auto num_grads = ctx->input(num_tables_ + 3).NumElements(); auto key_offsets = ctx->input(num_tables_ + 4).vec().data(); auto emb_offsets = ctx->input(num_tables_ + 5).vec().data(); auto learning_rates = ctx->input(num_tables_ + 6).vec().data(); auto update_time = ctx->input(num_tables_ + 7).scalar()(); auto global_step = ctx->input(num_tables_ + 8).scalar()(); std::vector hash_tables(num_tables_, nullptr); for (int table_id = 0; table_id < num_tables_; table_id++) { EmbeddingHashTableTfBridge* hash_table = nullptr; OP_REQUIRES_OK( ctx, LookupResource(ctx, HandleFromInput(ctx, table_id), &hash_table)); core::ScopedUnref unref(hash_table); hash_tables[table_id] = hash_table; } auto optimize = [&](const int begin, const int end) { for (int shard_id = begin; shard_id < end; shard_id++) { int learning_rate_offset = 0; for (int table_id = 0; table_id < num_tables_; table_id++) { int curr_idx = shard_id * num_tables_ + table_id; auto table = hash_tables[table_id]; auto learning_rate = absl::MakeConstSpan( learning_rates + learning_rate_offset, table->slice_size()); learning_rate_offset += table->slice_size(); table->BatchOptimize( ctx, slot_size_vec[curr_idx], ids + key_offsets[curr_idx], id_grads + emb_offsets[curr_idx], learning_rate, update_time, enable_grad_accumulation_, global_step); } } }; // TODO(zouxuan): tweak this number for optimization. 
const int64 kCostPerUnit = 10000000; const DeviceBase::CpuWorkerThreads& worker_threads = *ctx->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, num_shards_, kCostPerUnit, optimize); } REGISTER_OP("MonolithHashTableAssign") .Input("table_handle: resource") .Input("id_values: int64") .Input("id_updates: float") .Input("req_time: int64") .Output("table_handle_output: resource") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashTableAssign").Device(DEVICE_CPU), HashTableAssignOp); REGISTER_OP("MonolithHashTableAssignAdd") .Input("table_handle: resource") .Input("id_values: int64") .Input("id_updates: float") .Input("req_time: int64") .Output("table_handle_output: resource") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashTableAssignAdd").Device(DEVICE_CPU), HashTableAssignAddOp); REGISTER_OP("MonolithHashTableOptimize") .Input("table_handle: resource") .Input("id_values: int64") .Input("id_updates: float") .Input("learning_rate_tensor: float") .Input("req_time: int64") .Input("global_step: int64") .Output("table_handle_output: resource") .Attr("use_multi_threads: bool = false") .Attr("queue_size: int = 0") .Attr("enable_dedup: bool = false") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithHashTableOptimize").Device(DEVICE_CPU), HashTableOptimizeOp); REGISTER_OP("MonolithHashTableFusedOptimize") .Input("table_handles: N * resource") .Input("ids: int64") .Input("indices: int64") .Input("fused_slot_size: int32") .Input("id_grads: float") .Input("id_offsets: int32") .Input("grad_offsets: int32") .Input("learning_rate_tensors: float") .Input("req_time: int64") .Input("global_step: int64") .Output("table_handles_output: N * resource") .Attr("N: int") .Attr("num_of_shards: int") .Attr("enable_grad_accumulation: bool = false") .SetShapeFn([](shape_inference::InferenceContext* c) { int num_tables, num_shards; 
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_tables)); TF_RETURN_IF_ERROR(c->GetAttr("num_of_shards", &num_shards)); for (int i = 0; i < num_tables; ++i) { c->set_output(i, c->Scalar()); } auto shape = c->input(num_tables + 2); TF_RETURN_IF_ERROR(c->WithRank(shape, 1, &shape)); auto dim = c->Dim(shape, 0); TF_RETURN_IF_ERROR(c->WithValue(dim, num_tables * num_shards, &dim)); return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("MonolithHashTableFusedOptimize").Device(DEVICE_CPU), HashTableFusedOptimizeOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/inbatch_auc_loss.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include #include #include #include #include #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { namespace monolith_tf { class InbatchAucLossOp : public OpKernel { public: explicit InbatchAucLossOp(OpKernelConstruction *ctx) : OpKernel(ctx) { float neg_weight; OP_REQUIRES_OK(ctx, ctx->GetAttr("neg_weight", &neg_weight)); CHECK_GT(neg_weight, 0); CHECK_LE(neg_weight, 1.0); } void Compute(OpKernelContext *ctx) override { const Tensor *label_tensor; OP_REQUIRES_OK(ctx, ctx->input("label", &label_tensor)); const Tensor *logit_tensor; OP_REQUIRES_OK(ctx, ctx->input("logit", &logit_tensor)); OP_REQUIRES(ctx, label_tensor->NumElements() == logit_tensor->NumElements(), errors::InvalidArgument("the label and logit not match")); std::vector positive, negative; auto label_flat = label_tensor->flat(); for (size_t i = 0; i < label_flat.size(); ++i) { if (label_flat(i) > 0) { positive.push_back(i); } else if (label_flat(i) > -10000) { negative.push_back(i); } } float loss = 0; auto logit_flat = logit_tensor->flat(); for (const size_t &i : positive) { float pos_logit = logit_flat(i); for (const size_t &j : negative) { float diff = pos_logit - logit_flat(j); if (diff > -87 && diff < 88) { loss += diff - log(1.0 + exp(diff)); } else if (diff <= -87) { loss += diff; } } } Tensor *loss_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &loss_tensor)); loss_tensor->scalar()() = loss; } }; class InbatchAucLossGradOp : public OpKernel { public: explicit InbatchAucLossGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("neg_weight", &neg_weight_)); } void Compute(OpKernelContext *ctx) override { const Tensor *label_tensor; OP_REQUIRES_OK(ctx, ctx->input("label", &label_tensor)); const Tensor *logit_tensor; OP_REQUIRES_OK(ctx, ctx->input("logit", &logit_tensor)); 
OP_REQUIRES(ctx, label_tensor->NumElements() == logit_tensor->NumElements(), errors::InvalidArgument("the label and logit not match")); const Tensor *grad_tensor; OP_REQUIRES_OK(ctx, ctx->input("grad", &grad_tensor)); float grad = grad_tensor->scalar()(); Tensor *logit_grad_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, logit_tensor->shape(), &logit_grad_tensor)); auto logit_grad_float = logit_grad_tensor->flat(); logit_grad_float.setZero(); std::vector positive, negative; auto label_flat = label_tensor->flat(); for (size_t i = 0; i < label_flat.size(); ++i) { if (label_flat(i) > 0) { positive.push_back(i); } else if (label_flat(i) > -10000) { negative.push_back(i); } } auto logit_flat = logit_tensor->flat(); for (const size_t &i : positive) { float pos_logit = logit_flat(i); for (const size_t &j : negative) { float diff = pos_logit - logit_flat(j); float grad_ij; if (diff > -87 && diff < 88) { grad_ij = 1.0 - 1.0 / (1.0 + exp(-diff)); } else if (diff <= -87) { grad_ij = 1; } else { grad_ij = 0; } logit_grad_float(i) += grad_ij; logit_grad_float(j) -= neg_weight_ * grad_ij; } } if (grad != 1) { for (size_t i = 0; i < logit_grad_float.size(); ++i) { logit_grad_float(i) *= grad; } } } private: float neg_weight_; }; namespace { REGISTER_KERNEL_BUILDER(Name("InbatchAucLoss").Device(DEVICE_CPU), InbatchAucLossOp) REGISTER_KERNEL_BUILDER(Name("InbatchAucLossGrad").Device(DEVICE_CPU), InbatchAucLossGradOp) REGISTER_OP("InbatchAucLoss") .Input("label: float") .Input("logit: float") .Attr("neg_weight: float") .Output("loss: float") .SetDoNotOptimize() .SetShapeFn([](shape_inference::InferenceContext *ctx) { ctx->set_output(0, ctx->Scalar()); return Status::OK(); }); REGISTER_OP("InbatchAucLossGrad") .Input("label: float") .Input("logit: float") .Input("grad: float") .Attr("neg_weight: float") .Output("logit_grad: float") .SetDoNotOptimize() .SetShapeFn([](shape_inference::InferenceContext *ctx) { ctx->set_output(0, ctx->input(1)); return Status::OK(); }); } 
// namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/logging_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "monolith/native_training/runtime/common/metrics.h" #include "monolith/native_training/runtime/ops/logging_ops.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { namespace monolith_tf { namespace { class TensorTimestampOp : public OpKernel { public: explicit TensorTimestampOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &types_)); } void Compute(OpKernelContext* ctx) override { for (int i = 0; i < static_cast(types_.size()); ++i) { ctx->set_output(i, ctx->input(i)); } Tensor* ts; OP_REQUIRES_OK(ctx, ctx->allocate_output(types_.size(), {}, &ts)); auto ts_scalar = ts->scalar(); ts_scalar() = absl::ToUnixMicros(absl::Now()); } private: std::vector types_; }; REGISTER_OP("MonolithTensorsTimestamp") .Attr("T: list(type)") .Input("tensors_in: T") .Output("tensors_out: T") 
.Output("timestamp: int64") .SetShapeFn([](shape_inference::InferenceContext* ctx) { std::vector types; TF_RETURN_IF_ERROR(ctx->GetAttr("T", &types)); for (int i = 0; i < static_cast(types.size()); ++i) { ctx->set_output(i, ctx->input(i)); } ctx->set_output(types.size(), ctx->Scalar()); return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithTensorsTimestamp").Device(DEVICE_CPU), TensorTimestampOp); // Deprecated. class MetricOp : public OpKernel { public: explicit MetricOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("tags", &tags_)); } void Compute(OpKernelContext* ctx) override { for (int i = 0; i < static_cast(types_.size()); ++i) { ctx->set_output(i, ctx->input(i)); } const Tensor& value_tensor = ctx->input(types_.size()); const float value = value_tensor.scalar()(); monolith::GetMetrics()->emit_timer(key_, value, tags_); } private: std::vector types_; std::string key_; std::string tags_; }; REGISTER_OP("MonolithMetric") .Attr("T: list(type)") .Attr("key: string") .Attr("tags: string") .Input("tensors_in: T") .Input("value: float") .Output("tensors_out: T") .SetShapeFn([](shape_inference::InferenceContext* ctx) { std::vector types; TF_RETURN_IF_ERROR(ctx->GetAttr("T", &types)); for (int i = 0; i < static_cast(types.size()); ++i) { ctx->set_output(i, ctx->input(i)); } return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithMetric").Device(DEVICE_CPU), MetricOp); class MetricV2Op : public OpKernel { public: explicit MetricV2Op(OpKernelConstruction* ctx) : OpKernel(ctx), stat_last_1_min_(60), stat_last_5_min_(300), stat_last_15_min_(900) { OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("tags", &tags_)); } void Compute(OpKernelContext* ctx) override { const Tensor& value_tensor = ctx->input(0); const float value = value_tensor.scalar()(); monolith::GetMetrics()->emit_timer(key_, value, 
tags_); tensorflow::mutex_lock l(mu_); auto now = absl::Now(); stat_last_1_min_.PushOne(value, now); stat_last_5_min_.PushOne(value, now); stat_last_15_min_.PushOne(value, now); LOG_EVERY_N_SEC(INFO, 600) << absl::StrFormat( "%s last_1_min: %s", key_, stat_last_1_min_.DebugString()); LOG_EVERY_N_SEC(INFO, 600) << absl::StrFormat( "%s last_5_min: %s", key_, stat_last_5_min_.DebugString()); LOG_EVERY_N_SEC(INFO, 600) << absl::StrFormat( "%s last_15_min: %s", key_, stat_last_15_min_.DebugString()); } private: class MovingStat { public: explicit MovingStat(int64_t time_window_in_sec) : time_window_in_sec_(time_window_in_sec), min_(std::numeric_limits::max()), max_(std::numeric_limits::min()), sum_(0.f) {} void PushOne(float value, absl::Time t) { min_ = std::min(min_, value); max_ = std::max(max_, value); sum_ += value; buffer_.emplace_back(value, t); while (!buffer_.empty()) { auto start = buffer_.front().second; int64_t delta = absl::ToInt64Seconds(t - start); if (delta > time_window_in_sec_) { sum_ -= buffer_.front().first; buffer_.pop_front(); } else { break; } } } std::string DebugString() const { float avg = 0, p99 = 0; if (!buffer_.empty()) { avg = sum_ / buffer_.size(); std::vector values; values.reserve(buffer_.size()); for (const auto& p : buffer_) { values.push_back(p.first); } std::sort(values.begin(), values.end()); int64_t p99_index = values.size() * 0.99f; p99 = values[p99_index]; } return absl::StrFormat("min: %f, max: %f, avg: %f, p99: %f", min_, max_, avg, p99); } private: int64_t time_window_in_sec_; float min_; float max_; float sum_; std::deque> buffer_; }; private: std::string key_; std::string tags_; tensorflow::mutex mu_; MovingStat stat_last_1_min_ TF_GUARDED_BY(mu_); MovingStat stat_last_5_min_ TF_GUARDED_BY(mu_); MovingStat stat_last_15_min_ TF_GUARDED_BY(mu_); }; REGISTER_OP("MonolithMetricV2") .Attr("key: string") .Attr("tags: string") .SetIsStateful() .Input("value: float") .SetShapeFn(shape_inference::NoOutputs); 
REGISTER_KERNEL_BUILDER(Name("MonolithMetricV2").Device(DEVICE_CPU), MetricV2Op); struct MachineInfo : ResourceBase { int64 mem_limit = 0; std::string DebugString() const { return absl::StrFormat("mem_limit: %lld", mem_limit); } }; class MachineInfoOp : public ResourceOpKernel { public: explicit MachineInfoOp(OpKernelConstruction* c) : ResourceOpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("mem_limit", &mem_limit_)); } private: Status CreateResource(MachineInfo** info_out) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { auto* info = new MachineInfo(); info->mem_limit = mem_limit_; *info_out = info; return Status::OK(); } int64 mem_limit_; // Unit bytes }; REGISTER_OP("MonolithMachineInfo") .Output("handle: resource") .Attr("mem_limit: int") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithMachineInfo").Device(DEVICE_CPU), MachineInfoOp); class MonolithCheckMachineHealthOp : public OpKernel { public: explicit MonolithCheckMachineHealthOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { int64 current_mem = GetCurrentUsage(); OP_REQUIRES(c, current_mem > 0, errors::Internal("Unable to get the current process usage.")); MachineInfo* info = nullptr; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &info)); core::ScopedUnref unref(info); Tensor* result_tensor; OP_REQUIRES_OK(c, c->allocate_output(0, {}, &result_tensor)); auto result_scalar = result_tensor->scalar(); MachineHealthResult result; if (current_mem >= info->mem_limit) { result.set_status(MachineHealthResult::OUT_OF_MEMORY); result.set_message( absl::StrFormat("Memory limit exceeded. 
// Returns the resident set size of the current process in bytes, read from
// the "VmRSS:" line of /proc/self/status (which reports it in kB).
//
// Returns 0 when the file cannot be opened or has no VmRSS line, so callers
// can treat a non-positive result as "unknown".
int64_t GetCurrentUsage() {
  FILE* file = fopen("/proc/self/status", "r");
  if (file == nullptr) {
    // E.g. /proc is not available; fgets()/fclose() must not see a null FILE*.
    return 0;
  }
  int64_t rss_kb = 0;
  char line[128];
  while (fgets(line, sizeof(line), file) != nullptr) {
    if (std::strncmp(line, "VmRSS:", 6) == 0) {
      // The line is like `VmRSS: 708 kB`
      rss_kb = std::strtol(line + 6, nullptr, 10);
      break;
    }
  }
  fclose(file);
  return rss_kb * 1024;
}
message MachineHealthResult {
  enum MachineHealthStatus {
    OK = 0;
    OUT_OF_MEMORY = 1;
  }
  optional MachineHealthStatus status = 1;
  optional string message = 2;
}

================================================
FILE: monolith/native_training/runtime/ops/map_id_to_embedding.cu.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "monolith/native_training/runtime/ops/alloc_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace monolith_tf {

typedef Eigen::GpuDevice GPUDevice;

// Gathers rows out of `fused_embeddings` into the per-input output buffers.
//
// The outputs are treated as one flat workload of `size` elements. Input `i`
// owns the global indices [offsets[i], offsets[i + 1]); inside that range
// element `local_idx` belongs to row `local_idx / dim` and column
// `local_idx % dim`, and is read from
// `fused_embeddings[input_ptrs[i][row] + column]`.
template __global__ void FusedGatherKernel(
    const T* __restrict__ fused_embeddings,
    GpuDeviceArrayStruct input_ptr_data,
    GpuDeviceArrayStruct output_ptr_data,
    GpuDeviceArrayStruct embedding_dims_data,
    GpuDeviceArrayStruct offsets_data, int32 size) {
  const int32** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
  T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data);
  int32* offsets = GetGpuDeviceArrayOnDevice(&offsets_data);
  int32* embedding_dims = GetGpuDeviceArrayOnDevice(&embedding_dims_data);

  // if using shared memory
  // Ref:
  // https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/core/kernels/split_lib_gpu.cu.cc#L124
  GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(int32), unsigned char, smem);
  int32* smem_offsets = reinterpret_cast(smem);
  int32* smem_embedding_dims = smem_offsets + offsets_data.size;
  // Stage both lookup tables in the dynamic shared buffer; the host side
  // sizes it as offsets (M + 1 entries) followed by embedding_dims (M).
  for (int x = threadIdx.x; x < offsets_data.size; x += blockDim.x) {
    smem_offsets[x] = offsets[x];
  }
  for (int x = threadIdx.x; x < embedding_dims_data.size; x += blockDim.x) {
    smem_embedding_dims[x] = embedding_dims[x];
  }
  __syncthreads();
  offsets = smem_offsets;
  embedding_dims = smem_embedding_dims;

  int i = 0;
  for (int32 idx : GpuGridRangeX(size)) {
    // safe offsets read: when idx == size - 1, i+1 == num_inputs
    // since num_inputs := number of merged slot < 100,
    // linear search would be sufficient here
    while (offsets[i + 1] <= idx) ++i;
    int32 local_idx = idx - offsets[i];
    int32 dim = embedding_dims[i];
    int j = local_idx / dim;
    int k = local_idx % dim;
    int32 emb_offset = input_ptrs[i][j];
    output_ptrs[i][local_idx] = ldg(fused_embeddings + emb_offset + k);
  }
}

// Scatter-adds the per-input gradients back into the flat `output_ptr`.
//
// Uses the same flat-workload indexing as FusedGatherKernel: element
// `local_idx` of input `i` is multiplied by `*_scale` and atomically added
// to `output_ptr[offset_ptrs[i][row] + column]`.
template __global__ void FusedGatherGradKernel(
    T* output_ptr, GpuDeviceArrayStruct input_ptr_data,
    GpuDeviceArrayStruct offset_ptr_data,
    GpuDeviceArrayStruct embedding_dims_data,
    GpuDeviceArrayStruct offsets_data, int32 size, const float* _scale) {
  float scale = *_scale;
  const T** input_ptrs = GetGpuDeviceArrayOnDevice(&input_ptr_data);
  const int32** offset_ptrs = GetGpuDeviceArrayOnDevice(&offset_ptr_data);
  int32* offsets = GetGpuDeviceArrayOnDevice(&offsets_data);
  int32* embedding_dims = GetGpuDeviceArrayOnDevice(&embedding_dims_data);

  // if using shared memory
  // Ref:
  // https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/core/kernels/split_lib_gpu.cu.cc#L124
  GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(int32), unsigned char, smem);
  int32* smem_offsets = reinterpret_cast(smem);
  int32* smem_embedding_dims = smem_offsets + offsets_data.size;
  for (int x = threadIdx.x; x < offsets_data.size; x += blockDim.x) {
    smem_offsets[x] = offsets[x];
  }
  for (int x = threadIdx.x; x < embedding_dims_data.size; x += blockDim.x) {
    smem_embedding_dims[x] = embedding_dims[x];
  }
  __syncthreads();
  offsets = smem_offsets;
  embedding_dims = smem_embedding_dims;

  int i = 0;
  for (int32 idx : GpuGridRangeX(size)) {
    // safe offsets read: when idx == size - 1, i+1 == num_inputs
    // since num_inputs := number of merged slot < 100,
    // linear search would be sufficient here
    while (offsets[i + 1] <= idx) ++i;
    int32 local_idx = idx - offsets[i];
    int32 dim = embedding_dims[i];
    int j = local_idx / dim;
    int k = local_idx % dim;
    const int32 emb_offset = offset_ptrs[i][j];
    GpuAtomicAdd(output_ptr + emb_offset + k,
                 input_ptrs[i][local_idx] * scale);
  }
}

// Fills a flat tensor with zeros on the GPU device.
template struct SetZeroFunctor {
  void operator()(const GPUDevice& d, typename TTypes::Flat out) {
    To32Bit(out).device(d) = To32Bit(out).constant(T(0));
  }
};

// GPU kernel for MonolithFusedGatherEmbeddingsByInput: for each of the M
// `embedding_offsets` inputs, produces an output whose row `j` is the
// `embedding_dims[i]` values starting at `embedding_offsets[i][j]` inside
// `fused_embeddings`.
template class FusedGatherEmbeddingsByInputOpGPU : public OpKernel {
 public:
  explicit FusedGatherEmbeddingsByInputOpGPU(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("M", &num_of_inputs_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("embedding_dims", &embedding_dims_));
    // Compute() indexes embedding_dims_[i] for every i < M.
    OP_REQUIRES(
        ctx, static_cast<int>(embedding_dims_.size()) >= num_of_inputs_,
        errors::InvalidArgument("embedding_dims has ", embedding_dims_.size(),
                                " entries but M is ", num_of_inputs_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto fused_embeddings_flat = ctx->input(0).flat();
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("embedding_offsets", &inputs));

    GpuDeviceArrayOnHost input_ptrs(ctx, num_of_inputs_);
    GpuDeviceArrayOnHost embedding_dims(ctx, num_of_inputs_);
    GpuDeviceArrayOnHost offsets(ctx, num_of_inputs_ + 1);
    OP_REQUIRES_OK(ctx, input_ptrs.Init());
    OP_REQUIRES_OK(ctx, embedding_dims.Init());
    OP_REQUIRES_OK(ctx, offsets.Init());
    int smem_usage =
        sizeof(int32) * (num_of_inputs_ + 1 +
                         num_of_inputs_);  // smem: offsets + embedding_dims

    FusedAlignedOutputAllocator fao_alloc(ctx);
    for (int i = 0; i < num_of_inputs_; ++i) {
      auto dim = embedding_dims_[i];
      auto s = inputs[i].NumElements();  // == input[i].shape().dim_size(0)
      embedding_dims.Set(i, dim);
      offsets.Set(i, fao_alloc.get_unaligned_total());
      input_ptrs.Set(i, inputs[i].flat().data());
      fao_alloc.add_slice(s * dim);
    }
    // offset val here is total workload
    offsets.Set(num_of_inputs_, fao_alloc.get_unaligned_total());
    OP_REQUIRES_OK(ctx, offsets.Finalize());
    OP_REQUIRES_OK(ctx, input_ptrs.Finalize());
    OP_REQUIRES_OK(ctx, embedding_dims.Finalize());

    GpuDeviceArrayOnHost output_ptrs(ctx, num_of_inputs_);
    OP_REQUIRES_OK(ctx, output_ptrs.Init());
    fao_alloc.allocate(ctx->expected_output_dtype(0));
    for (int i = 0; i < num_of_inputs_; ++i) {
      auto dim = embedding_dims_[i];
      auto s = inputs[i].NumElements();  // == input[i].shape().dim_size(0)
      Tensor out = fao_alloc.get_slice({s, dim});
      output_ptrs.Set(i, out.flat().data());
      ctx->set_output(i, std::move(out));
    }
    OP_REQUIRES_OK(ctx, output_ptrs.Finalize());

    GPUDevice gpu_device = ctx->eigen_device();
    // We use a 2D LaunchConfig here to make thread (x, y) of every
    // input tensor y better benefit from the ldg local cache read
    // for multiple x of x + n * grid_stride.
    // >>> auto config = GetGpu2DLaunchConfig(max_input_size, num_of_inputs_,
    //                                        gpu_device);
    //
    // However, across inputs the distribution of elements thus thread workload
    // can be imbalanced in this implementation.
    //
    // One alternative implementation for this Op is based on ComputeAsync +
    // Multiple Kernel Calls, for example similar to
    // https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc#L454-L469
    //
    // The chosen implementation is to distribute the output workload balanced
    // on threads,
    // while searching the idx input bucket to which the output val belongs to.
    const auto total_workload = fao_alloc.get_unaligned_total();
    if (total_workload <= 0) {
      // All outputs are empty and already set; GetGpuLaunchConfig() expects a
      // positive element count, so there is nothing to launch.
      return;
    }
    auto config = GetGpuLaunchConfig(total_workload, gpu_device);
    auto grid_offset = 24;
    char* ptr = std::getenv("MONOLITH_GT_OVERSUB_SM");
    if (ptr) grid_offset = std::atoi(ptr);
    grid_offset += 2;
    // For small workloads block_count can be <= grid_offset, which would
    // yield a zero or negative grid size and an invalid launch. The kernel
    // iterates with GpuGridRangeX, so a single block still covers all work.
    int grid_size = config.block_count - grid_offset;
    if (grid_size < 1) grid_size = 1;
    OP_REQUIRES_OK(
        ctx, GpuLaunchKernel(
                 FusedGatherKernel, grid_size, config.thread_per_block,
                 /*shared_memory_size_bytes=*/smem_usage, gpu_device.stream(),
                 fused_embeddings_flat.data(), input_ptrs.data(),
                 output_ptrs.data(), embedding_dims.data(), offsets.data(),
                 total_workload));
  }

 private:
  int num_of_inputs_;
  std::vector embedding_dims_;
};

// GPU kernel for MonolithFusedGatherEmbeddingsByInputGradient: accumulates
// the M `grads` inputs, scaled by the trailing `scale` input, into a
// zero-initialized vector of length `fused_embeddings_size` at the positions
// given by `embedding_offsets`.
template class FusedGatherEmbeddingsByInputGradientOpGPU : public OpKernel {
 public:
  explicit FusedGatherEmbeddingsByInputGradientOpGPU(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("M", &num_of_inputs_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("embedding_dims", &embedding_dims_));
    // Compute() indexes embedding_dims_[i] for every i < M.
    OP_REQUIRES(
        ctx, static_cast<int>(embedding_dims_.size()) >= num_of_inputs_,
        errors::InvalidArgument("embedding_dims has ", embedding_dims_.size(),
                                " entries but M is ", num_of_inputs_));
  }

  void Compute(OpKernelContext* ctx) override {
    GPUDevice gpu_device = ctx->eigen_device();
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("grads", &inputs));
    OpInputList embedding_offsets;
    OP_REQUIRES_OK(ctx,
                   ctx->input_list("embedding_offsets", &embedding_offsets));

    GpuDeviceArrayOnHost input_ptrs(ctx, num_of_inputs_);
    GpuDeviceArrayOnHost emb_offset_ptrs(ctx, num_of_inputs_);
    OP_REQUIRES_OK(ctx, input_ptrs.Init());
    OP_REQUIRES_OK(ctx, emb_offset_ptrs.Init());
    GpuDeviceArrayOnHost embedding_dims(ctx, num_of_inputs_);
    OP_REQUIRES_OK(ctx, embedding_dims.Init());
    int32 offset = 0;
    GpuDeviceArrayOnHost offsets(ctx, num_of_inputs_ + 1);  // input_offsets
    OP_REQUIRES_OK(ctx, offsets.Init());
    int smem_usage =
        sizeof(int32) * (num_of_inputs_ + 1 +
                         num_of_inputs_);  // smem: offsets + embedding_dims

    for (int i = 0; i < num_of_inputs_; ++i) {
      input_ptrs.Set(i, inputs[i].flat().data());
      emb_offset_ptrs.Set(i, embedding_offsets[i].flat().data());
      auto s = embedding_offsets[i].NumElements();
      auto dim = embedding_dims_[i];
      embedding_dims.Set(i, dim);
      offsets.Set(i, offset);
      offset += s * dim;
    }
    offsets.Set(num_of_inputs_, offset);  // offset val here is total workload
    OP_REQUIRES_OK(ctx, offsets.Finalize());
    OP_REQUIRES_OK(ctx, input_ptrs.Finalize());
    OP_REQUIRES_OK(ctx, emb_offset_ptrs.Finalize());
    OP_REQUIRES_OK(ctx, embedding_dims.Finalize());

    int32 fused_embeddings_size = ctx->input(0).scalar().data()[0];
    Tensor* output_tensor;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({fused_embeddings_size}),
                                  &output_tensor));
    SetZeroFunctor zero_functor;
    auto output = output_tensor->flat();
    zero_functor(gpu_device, output);

    if (offset <= 0) {
      // No gradient elements: the zero-filled output is already the result,
      // and GetGpuLaunchConfig() expects a positive element count.
      return;
    }
    auto config = GetGpuLaunchConfig(offset, gpu_device);
    OP_REQUIRES_OK(
        ctx,
        GpuLaunchKernel(
            FusedGatherGradKernel, config.block_count, config.thread_per_block,
            /*shared_memory_size_bytes=*/smem_usage, gpu_device.stream(),
            output.data(), input_ptrs.data(), emb_offset_ptrs.data(),
            embedding_dims.data(), offsets.data(), offset,
            ctx->input(2 * num_of_inputs_ + 1).flat().data()));
  }

 private:
  int num_of_inputs_;
  std::vector embedding_dims_;
};

REGISTER_KERNEL_BUILDER(
    Name("MonolithFusedGatherEmbeddingsByInput").Device(DEVICE_GPU),
    FusedGatherEmbeddingsByInputOpGPU);

REGISTER_KERNEL_BUILDER(Name("MonolithFusedGatherEmbeddingsByInputGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("fused_embeddings_size"),
                        FusedGatherEmbeddingsByInputGradientOpGPU);

}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // GOOGLE_CUDA

================================================
FILE: monolith/native_training/runtime/ops/map_id_to_embedding_op.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/util/work_sharder.h" #include "monolith/native_training/runtime/hash_table/optimizer/avx_utils.h" namespace tensorflow { namespace monolith_tf { namespace { // The input embeddings are a list of 2D tensors. // This represents the embedding: embeddings[index].chip<0>(pos) struct EmbeddingLocation { int64 index; int64 pos; }; } // namespace // Maps input ids into embeddings. 
class MapIdToEmbeddingOp : public OpKernel { public: explicit MapIdToEmbeddingOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_splits", &num_splits_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("use_multi_threads", &use_multi_threads_)); } void Compute(OpKernelContext* ctx) override { absl::flat_hash_map id_to_loc; int64 total_split_ids = 0; for (int i = 0; i < num_splits_; ++i) { total_split_ids += ctx->input(i).flat().dimension(0); } id_to_loc.reserve(total_split_ids); for (int i = 0; i < num_splits_; ++i) { auto ids = ctx->input(i).flat(); for (int64 j = 0; j < ids.dimension(0); ++j) { id_to_loc.insert({ids(j), {i, j}}); } } std::vector::Matrix> embeddings; embeddings.reserve(num_splits_); for (int i = 0; i < num_splits_; ++i) { embeddings.emplace_back(ctx->input(num_splits_ + i).matrix()); } int64 embedding_dim = embeddings[0].dimension(1); const Tensor& input = ctx->input(2 * num_splits_); auto input_flat = input.flat(); Tensor* output; TensorShape output_shape = input.shape(); output_shape.AddDim(embedding_dim); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output)); auto output_mat = output->shaped({input.NumElements(), embedding_dim}); auto map_fn = [&](const int64 begin, const int64 end) { for (int64 k = begin; k < end; ++k) { auto iter = id_to_loc.find(input_flat(k)); if (iter == id_to_loc.end()) { return ctx->SetStatus( errors::InvalidArgument("Unable to map id ", input_flat(k))); } const EmbeddingLocation& loc = iter->second; output_mat.chip<0>(k) = embeddings[loc.index].chip<0>(loc.pos); } }; int64 total = input_flat.dimension(0); if (use_multi_threads_) { auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); auto workers = worker_threads->workers; int num_shards = std::min(5LL, std::max(1LL, total / 10000)); int64 block_size = total / num_shards; workers->ParallelFor( total, thread::ThreadPool::SchedulingParams( thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, absl::nullopt, block_size), 
map_fn); } else { map_fn(0, total); } } private: int num_splits_; bool use_multi_threads_; }; class MapIdToEmbeddingGradientOp : public OpKernel { public: explicit MapIdToEmbeddingGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_splits", &num_splits_)); } void Compute(OpKernelContext* ctx) override { const Tensor& input = ctx->input(num_splits_); const Tensor& grads = ctx->input(num_splits_ + 1); const int64 embedding_size = grads.dim_size(grads.dims() - 1); absl::flat_hash_map id_to_loc; std::vector::Matrix> embedding_grads_mats; embedding_grads_mats.reserve(num_splits_); for (int i = 0; i < num_splits_; ++i) { auto ids = ctx->input(i).flat(); int64 len_ids = ids.dimension(0); for (int64 j = 0; j < ids.dimension(0); ++j) { id_to_loc.insert({ids(j), {i, j}}); } Tensor* output; OP_REQUIRES_OK( ctx, ctx->allocate_output(i, {len_ids, embedding_size}, &output)); std::memset(output->data(), 0, output->AllocatedBytes()); embedding_grads_mats.emplace_back(output->matrix()); } auto input_flat = input.flat(); auto grads_mat = grads.shaped({input.NumElements(), embedding_size}); for (int64 k = 0; k < input_flat.dimension(0); ++k) { const EmbeddingLocation& loc = id_to_loc.find(input_flat(k))->second; embedding_grads_mats[loc.index].chip<0>(loc.pos) += grads_mat.chip<0>(k); } } private: int num_splits_; }; // Maps input ids into embeddings with only 1 tensor. 
class GatherEmbeddingsByInputOp : public OpKernel { public: explicit GatherEmbeddingsByInputOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_multi_threads", &use_multi_threads_)); } void Compute(OpKernelContext* ctx) override { absl::flat_hash_map id_to_loc; auto ids = ctx->input(0).flat(); for (int i = 0; i < ids.dimension(0); ++i) { id_to_loc.insert({ids(i), i}); } TTypes::Matrix embeddings = ctx->input(1).matrix(); OP_REQUIRES(ctx, embeddings.dimension(0) == ids.dimension(0), errors::InvalidArgument("See unmatched embedding dim ", embeddings.dimension(0), " and id dim ", ids.dimension(0))); int64 embedding_dim = embeddings.dimension(1); const Tensor& input = ctx->input(2); auto input_flat = input.flat(); Tensor *output, *output_index_mapping; TensorShape output_shape = input.shape(); output_shape.AddDim(embedding_dim); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output)); OP_REQUIRES_OK( ctx, ctx->allocate_output(1, input.shape(), &output_index_mapping)); auto output_mat = output->shaped({input.NumElements(), embedding_dim}); auto output_index_mapping_flat = output_index_mapping->flat(); auto fill_fn = [&](const int64 begin, const int64 end) { for (int64 k = begin; k < end; ++k) { auto iter = id_to_loc.find(input_flat(k)); if (iter == id_to_loc.end()) { return ctx->SetStatus( errors::InvalidArgument("Unable to map id ", input_flat(k))); } const int64& pos = iter->second; output_mat.chip<0>(k) = embeddings.chip<0>(pos); output_index_mapping_flat(k) = pos; } }; if (use_multi_threads_) { // TODO(zouxuan): tune this for performance. 
const int64 kCostPerUnit = 4 * embedding_dim; const DeviceBase::CpuWorkerThreads& worker_threads = *ctx->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, input_flat.dimension(0), kCostPerUnit, fill_fn); } else { fill_fn(0, input_flat.dimension(0)); } } private: bool use_multi_threads_; }; class GatherEmbeddingsByInputGradientOp : public OpKernel { public: explicit GatherEmbeddingsByInputGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& ids = ctx->input(0); const Tensor& grads = ctx->input(1); auto index_mapping_flat = ctx->input(2).flat(); auto ids_flat = ids.flat(); // Reshape it to len(input):embedding_size shape. const int64 embedding_size = grads.dim_size(grads.dims() - 1); const int64 input_size = index_mapping_flat.dimension(0); auto grads_mat = grads.shaped({input_size, embedding_size}); const int64 len_ids = ids_flat.dimension(0); Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {len_ids, embedding_size}, &output)); std::memset(output->data(), 0, output->AllocatedBytes()); TTypes::Matrix embedding_grads_mats = output->matrix(); for (int64 k = 0; k < input_size; ++k) { const int64 loc = index_mapping_flat(k); embedding_grads_mats.chip<0>(loc) += grads_mat.chip<0>(k); } } }; class FusedGatherEmbeddingsByInputOp : public OpKernel { public: explicit FusedGatherEmbeddingsByInputOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("M", &num_of_inputs_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("embedding_dims", &embedding_dims_)); } void Compute(OpKernelContext* ctx) override { auto fused_embeddings_flat = ctx->input(0).flat(); OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("embedding_offsets", &inputs)); OpOutputList outputs; OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs)); std::vector output_ptrs(outputs.size()); DCHECK_EQ(num_of_inputs_, outputs.size()); for (int i = 0; i < num_of_inputs_; 
++i) { TensorShape output_shape = inputs[i].shape(); output_shape.AddDim(embedding_dims_[i]); Tensor* out; OP_REQUIRES_OK(ctx, outputs.allocate(i, output_shape, &out)); output_ptrs[i] = out->flat().data(); } auto fill_fn = [&](const int64 begin, const int64 end) { for (int i = begin; i < end; ++i) { auto embedding_offset_vec = inputs[i].vec(); int embedding_dim = embedding_dims_[i]; for (int j = 0; j < embedding_offset_vec.size(); ++j) { auto offset = embedding_offset_vec(j); std::memcpy(output_ptrs[i] + j * embedding_dim, fused_embeddings_flat.data() + offset, embedding_dim * sizeof(float)); } } }; auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); worker_threads->workers->ParallelFor( num_of_inputs_, thread::ThreadPool::SchedulingParams( thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, absl::nullopt, 1), // block_size fill_fn); } private: int num_of_inputs_; std::vector embedding_dims_; }; class FusedGatherEmbeddingsByInputGradientOp : public OpKernel { public: explicit FusedGatherEmbeddingsByInputGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("embedding_dims", &embedding_dims_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("M", &num_of_inputs_)); } void Compute(OpKernelContext* ctx) override { int32 fused_embeddings_size = ctx->input(0).flat()(0); Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output( 0, TensorShape({fused_embeddings_size}), &output)); auto output_flat = output->flat(); std::memset(output->data(), 0, output->AllocatedBytes()); // By design, different inputs from num_of_inputs_ are sharded into // different positions in the flattened gradients, and thus simply do a // parallel fill function. 
auto fill_fn = [&](const int64 begin, const int64 end) { for (int i = begin; i < end; ++i) { auto input_flat = ctx->input(1 + i).flat(); auto embedding_offset_vec = ctx->input(num_of_inputs_ + 1 + i).vec(); int embedding_dim = embedding_dims_[i]; for (int j = 0; j < embedding_offset_vec.dimension(0); ++j) { int32 offset = embedding_offset_vec(j); const float* input_a = const_cast(input_flat.data()) + j * embedding_dim; float* output_b = static_cast(output_flat.data()) + offset; // Use AVX acceleration for reducesum. ::monolith::hash_table::ReduceSum(input_a, output_b, output_b, embedding_dim); } } }; auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); worker_threads->workers->ParallelFor( num_of_inputs_, thread::ThreadPool::SchedulingParams( thread::ThreadPool::SchedulingStrategy::kFixedBlockSize, absl::nullopt, 1), // block_size fill_fn); } private: std::vector embedding_dims_; int num_of_inputs_; }; REGISTER_OP("MonolithMapIdToEmbedding") .Input("ids: num_splits * int64") .Input("embeddings: num_splits * float") .Input("input: int64") .Output("output: float") .Attr("num_splits: int") .Attr("use_multi_threads: bool = true") .SetShapeFn([](shape_inference::InferenceContext* c) { int num_splits; TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits)); shape_inference::ShapeHandle embedding_shape = c->MakeShape({c->Dim(c->input(num_splits), -1)}); shape_inference::ShapeHandle input_shape = c->input(2 * num_splits); shape_inference::ShapeHandle output_shape; TF_RETURN_IF_ERROR( c->Concatenate(input_shape, embedding_shape, &output_shape)); c->set_output(0, output_shape); return Status::OK(); }); REGISTER_OP("MonolithMapIdToEmbeddingGradient") .Input("ids: num_splits * int64") .Input("input: int64") .Input("grads: float") .Output("embedding_grads: num_splits * float") .Attr("num_splits: int") .SetShapeFn([](shape_inference::InferenceContext* ctx) { int num_splits; TF_RETURN_IF_ERROR(ctx->GetAttr("num_splits", &num_splits)); 
shape_inference::DimensionHandle embedding_size = ctx->Dim(ctx->input(num_splits + 1), -1); for (int i = 0; i < num_splits; ++i) { shape_inference::DimensionHandle len_ids = ctx->Dim(ctx->input(i), 0); ctx->set_output(i, ctx->MakeShape({len_ids, embedding_size})); } return Status::OK(); }); REGISTER_OP("MonolithGatherEmbeddingsByInput") .Input("ids: int64") .Input("embeddings: float") .Input("input: int64") .Output("output: float") .Output("output_index_mapping: int64") .SetDoNotOptimize() // Crash with grappler. .Attr("use_multi_threads: bool = false") .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle embedding_shape = c->MakeShape({c->Dim(c->input(1), -1)}); shape_inference::ShapeHandle input_shape = c->input(2); shape_inference::ShapeHandle output_shape; TF_RETURN_IF_ERROR( c->Concatenate(input_shape, embedding_shape, &output_shape)); c->set_output(0, output_shape); c->set_output(1, input_shape); return Status::OK(); }); REGISTER_OP("MonolithGatherEmbeddingsByInputGradient") .Input("ids: int64") .Input("grads: float") .Input("index_mapping: int64") .Output("embedding_grads: float") .SetDoNotOptimize() // Crash with grappler. .SetShapeFn([](shape_inference::InferenceContext* ctx) { shape_inference::DimensionHandle embedding_size = ctx->Dim(ctx->input(1), -1); shape_inference::DimensionHandle len_ids = ctx->Dim(ctx->input(0), 0); ctx->set_output(0, ctx->MakeShape({len_ids, embedding_size})); return Status::OK(); }); REGISTER_OP("MonolithFusedGatherEmbeddingsByInput") .Input("fused_embeddings: float") .Input("embedding_offsets: M * int32") .Output("output: M * float") .Attr("embedding_dims: list(int)") .Attr("M: int") .SetDoNotOptimize() // Crash with grappler. 
.SetShapeFn([](shape_inference::InferenceContext* c) { int M; std::vector embedding_dims; TF_RETURN_IF_ERROR(c->GetAttr("embedding_dims", &embedding_dims)); TF_RETURN_IF_ERROR(c->GetAttr("M", &M)); for (int i = 0; i < M; ++i) { shape_inference::DimensionHandle dim = c->Dim(c->input(1 + i), 0); c->set_output(i, c->MakeShape({dim, embedding_dims[i]})); } return Status::OK(); }); REGISTER_OP("MonolithFusedGatherEmbeddingsByInputGradient") .Input("fused_embeddings_size: int32") .Input("grads: M * float") .Input("embedding_offsets: M * int32") .Input("scale: float") .Output("output: float") .Attr("embedding_dims: list(int)") .Attr("M: int") .SetDoNotOptimize() // Crash with grappler. .SetShapeFn([](shape_inference::InferenceContext* ctx) { shape_inference::DimensionHandle output_dim; TF_RETURN_IF_ERROR(ctx->MakeDimForScalarInput(0, &output_dim)); ctx->set_output(0, ctx->Vector(output_dim)); return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithMapIdToEmbedding").Device(DEVICE_CPU), MapIdToEmbeddingOp); REGISTER_KERNEL_BUILDER( Name("MonolithMapIdToEmbeddingGradient").Device(DEVICE_CPU), MapIdToEmbeddingGradientOp); REGISTER_KERNEL_BUILDER( Name("MonolithGatherEmbeddingsByInput").Device(DEVICE_CPU), GatherEmbeddingsByInputOp); REGISTER_KERNEL_BUILDER( Name("MonolithGatherEmbeddingsByInputGradient").Device(DEVICE_CPU), GatherEmbeddingsByInputGradientOp); REGISTER_KERNEL_BUILDER( Name("MonolithFusedGatherEmbeddingsByInput").Device(DEVICE_CPU), FusedGatherEmbeddingsByInputOp); REGISTER_KERNEL_BUILDER( Name("MonolithFusedGatherEmbeddingsByInputGradient").Device(DEVICE_CPU), FusedGatherEmbeddingsByInputGradientOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/multi_hash_table.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_MULTI_HASH_TABLE_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_MULTI_HASH_TABLE_H_ #include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "tensorflow/core/framework/resource_mgr.h" namespace tensorflow { namespace monolith_tf { class MultiHashTable : public ResourceBase { public: explicit MultiHashTable(absl::string_view shared_name) : shared_name_(std::string(shared_name)) {} void add_table(absl::string_view name, core::RefCountPtr table) { names_.push_back(std::string(name)); tables_.push_back(std::move(table)); } EmbeddingHashTableTfBridge* table(int i) const { return tables_[i].get(); } const std::vector& names() const { return names_; } const std::string& name(int i) { return names_[i]; } int size() const { return names_.size(); } const std::string& shared_name() const { return shared_name_; } std::string DebugString() const override { std::string ret; for (int i = 0; i < size(); ++i) { ret += absl::StrCat("name: ", names_[i], ":", tables_[i]->DebugString(), ";"); } return ret; } int64 MemoryUsed() const override { int64 ret = 0; for (const auto& table : tables_) { ret += table->MemoryUsed(); } return ret; } private: std::string shared_name_; std::vector> tables_; std::vector names_; }; } // namespace monolith_tf } 
// namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_MULTI_HASH_TABLE_H_ ================================================ FILE: monolith/native_training/runtime/ops/multi_hash_table_lookup_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/strings/str_format.h" #include "monolith/native_training/runtime/common/metrics.h" #include "monolith/native_training/runtime/hash_table/utils.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "monolith/native_training/runtime/ops/multi_hash_table.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/default/integral_types.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace monolith_tf { using CPUDevice = Eigen::ThreadPoolDevice; class MultiHashTableLookupOp : public OpKernel { public: explicit MultiHashTableLookupOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { core::RefCountPtr mtable; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &mtable)); auto id_vec = ctx->input(1).flat(); auto id_split_vec = ctx->input(2).flat(); // OP_REQUIRES(ctx, id_split_vec.size() == mtable->size() + 1, 
// errors::InvalidArgument("id_split must be ", mtable->size() + 1, // ". Current: ", id_split_vec.size())); int req_size = (id_split_vec.size() - 1) / mtable->size(); OP_REQUIRES(ctx, id_split_vec.size() == mtable->size() * req_size + 1, errors::InvalidArgument("table size: ", mtable->size(), ". Error id_split size: ", id_split_vec.size())); std::vector each_emb_offset(req_size); int emb_size = 0; for (int req_i = 0; req_i < req_size; ++req_i) { each_emb_offset[req_i] = emb_size; for (int i = 0; i < mtable->size(); ++i) { int id_split_idx = req_i * mtable->size() + i; const int num_ids = id_split_vec(id_split_idx + 1) - id_split_vec(id_split_idx); emb_size += num_ids * mtable->table(i)->dim_size(); } } Tensor* emb_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {emb_size}, &emb_tensor)); int64_t* id_ptr = const_cast(id_vec.data()); // int emb_offset = 0; float* emb_data = reinterpret_cast(emb_tensor->data()); int64_t total_hit_fid_count = 0, total_num_ids = 0; for (int i = 0; i < mtable->size(); ++i) { EmbeddingHashTableTfBridge* table = mtable->table(i); for (int req_i = 0; req_i < req_size; ++req_i) { int id_split_idx = req_i * mtable->size() + i; const int num_ids = id_split_vec(id_split_idx + 1) - id_split_vec(id_split_idx); total_num_ids += num_ids; int64_t hit_fid_count = 0; OP_REQUIRES_OK(ctx, table->BatchLookup(ctx, num_ids, id_ptr + id_split_vec(i), emb_data + each_emb_offset[req_i], &hit_fid_count)); total_hit_fid_count += hit_fid_count; each_emb_offset[req_i] += num_ids * table->dim_size(); } } if (mtable->size() && mtable->table(0)->IsServingEntryType() && total_num_ids) { const std::string tagkv = "name=all"; float hit_rate = total_hit_fid_count / static_cast(total_num_ids); monolith::GetMetrics()->emit_timer("lookup_fid_hit_rate", hit_rate, tagkv); } } }; class MultiHashTableLookupEntryOp : public OpKernel { public: explicit MultiHashTableLookupEntryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { 
core::RefCountPtr mtable; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &mtable)); auto id_vec = ctx->input(1).flat(); auto id_split_vec = ctx->input(2).flat(); OP_REQUIRES(ctx, id_split_vec.size() == mtable->size() + 1, errors::InvalidArgument("id_split must be ", mtable->size() + 1, ". Current: ", id_split_vec.size())); Tensor* entry_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {id_vec.size()}, &entry_tensor)); auto entry = entry_tensor->vec(); int64_t* id_ptr = const_cast(id_vec.data()); for (int i = 0; i < mtable->size(); ++i) { EmbeddingHashTableTfBridge* table = mtable->table(i); const int num_ids = id_split_vec(i + 1) - id_split_vec(i); std::vector entries(num_ids); OP_REQUIRES_OK( ctx, table->BatchLookupEntry(ctx, num_ids, id_ptr + id_split_vec(i), entries.data())); for (int j = 0; j < num_ids; ++j) { entry(j + id_split_vec(i)) = entries[j].SerializeAsString(); } } } private: int num_tables_; bool enable_inter_table_parallelism_; int64 cost_per_table_; }; template class MultiHashTableFusedLookupOp : public OpKernel { public: explicit MultiHashTableFusedLookupOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_of_shards", &num_shards_)); } void ComputeH(OpKernelContext* ctx); void Compute(OpKernelContext* ctx) override { ComputeH(ctx); } private: int num_shards_; }; template <> void MultiHashTableFusedLookupOp::ComputeH(OpKernelContext* ctx) { auto ids_flat = ctx->input(1).flat().data(); auto slot_size_vec = ctx->input(2).vec().data(); auto slot_size_cnt = ctx->input(2).NumElements(); Tensor *embeddings_ts, *emb_splits_ts, *key_offsets_ts, *emb_offsets_ts; OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {num_shards_}, &emb_splits_ts)); OP_REQUIRES_OK(ctx, ctx->allocate_output(2, {slot_size_cnt + 1}, &key_offsets_ts)); OP_REQUIRES_OK(ctx, ctx->allocate_output(3, {slot_size_cnt + 1}, &emb_offsets_ts)); ctx->set_output(4, ctx->input(1)); auto key_offsets = key_offsets_ts->vec().data(); auto emb_offsets = 
emb_offsets_ts->vec().data(); auto emb_splits = emb_splits_ts->vec().data(); core::RefCountPtr mtable; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &mtable)); int num_tables_ = mtable->size(); std::vector hash_table_dims(num_tables_, 0); for (int table_id = 0; table_id < num_tables_; table_id++) { hash_table_dims[table_id] = mtable->table(table_id)->dim_size(); } int total_keys, total_embs; std::tie(total_keys, total_embs) = monolith::hash_table::ComputeFusedOffsets( slot_size_vec, hash_table_dims.data(), num_tables_, num_shards_, key_offsets, emb_offsets, nullptr, emb_splits); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {total_embs}, &embeddings_ts)); auto embeddings = embeddings_ts->vec().data(); auto lookup = [&](const int begin, const int end) { for (int shard_id = begin; shard_id < end; shard_id++) { for (int table_id = 0; table_id < num_tables_; table_id++) { int curr_idx = shard_id * num_tables_ + table_id; int64_t hit_fid_count = 0; mtable->table(table_id)->BatchLookup( ctx, slot_size_vec[curr_idx], const_cast(ids_flat) + key_offsets[curr_idx], embeddings + emb_offsets[curr_idx], &hit_fid_count); } } }; // TODO(zouxuan): tweak this number for optimization. 
const int64 kCostPerUnit = 1000000; const DeviceBase::CpuWorkerThreads& worker_threads = *ctx->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, num_shards_, kCostPerUnit, lookup); } REGISTER_OP("MonolithMultiHashTableLookup") .Input("mtable: resource") .Input("id: int64") .Input("id_split: int64") .Output("embedding: float") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithMultiHashTableLookup").Device(DEVICE_CPU), MultiHashTableLookupOp); REGISTER_OP("MonolithMultiHashTableLookupEntry") .Input("mtable: resource") .Input("id: int64") .Input("id_split: int64") .Output("serialized_entries: string") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("MonolithMultiHashTableLookupEntry").Device(DEVICE_CPU), MultiHashTableLookupEntryOp); REGISTER_OP("MonolithMultiHashTableFusedLookup") .Input("mtable: resource") .Input("ids: int64") .Input("fused_slot_size: int32") .Input("req_time: int64") .Output("embeddings: float32") .Output("embedding_splits: int32") .Output("id_offsets: int32") .Output("embedding_offsets: int32") .Output("indices: int64") .Attr("num_of_shards: int") .SetShapeFn([](shape_inference::InferenceContext* c) { int num_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_of_shards", &num_shards)); c->set_output(0, c->Vector(c->UnknownDim())); c->set_output(1, c->Vector(num_shards)); auto shape = c->input(2); TF_RETURN_IF_ERROR(c->WithRank(shape, 1, &shape)); auto dim = c->Dim(shape, 0); shape_inference::DimensionHandle out; TF_RETURN_IF_ERROR(c->Add(dim, 1, &out)); c->set_output(2, c->Vector(out)); c->set_output(3, c->Vector(out)); c->set_output(4, c->input(1)); return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("MonolithMultiHashTableFusedLookup").Device(DEVICE_CPU), MultiHashTableFusedLookupOp); } // 
namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/multi_hash_table_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "monolith/native_training/runtime/ops/multi_hash_table.h" #include "monolith/native_training/runtime/ops/parameter_sync_tf_bridge.h" #include "monolith/native_training/runtime/parameter_sync/parameter_sync_client.h" #include "tensorflow/core/framework/lookup_interface.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/default/integral_types.h" #include "tensorflow/core/platform/mutex.h" namespace tensorflow { namespace monolith_tf { using ::monolith::hash_table::MultiEmbeddingHashTableConfig; using ::monolith::parameter_sync::ParameterSyncClient; template class CreateMultiHashTableOp : public ResourceOpKernel { public: explicit CreateMultiHashTableOp(OpKernelConstruction* ctx) : ResourceOpKernel(ctx) {} void 
Compute(OpKernelContext* ctx) override { absl::MutexLock l(&mu_); if (ResourceOpKernel::resource_ == nullptr) { const tstring& serialized_config = ctx->input(0).scalar()(); OP_REQUIRES(ctx, config_.ParseFromArray(serialized_config.data(), serialized_config.size()), errors::InvalidArgument("Unable to parse config.")); n_ = config_.names_size(); OP_REQUIRES( ctx, config_.configs_size() == n_, errors::InvalidArgument( "`table_configs` size must equal to `N`, got filter_handles (", config_.names_size(), ") vs N (", n_, ")")); const auto& filter_handle = ctx->input(1).scalar()(); OP_REQUIRES_OK(ctx, LookupResource(ctx, filter_handle, &hash_filter_)); const auto& sync_client_handle = ctx->input(2).scalar()(); OP_REQUIRES_OK(ctx, LookupResource(ctx, sync_client_handle, &sync_client_)); SetupStream(ctx); } ResourceOpKernel::Compute(ctx); } void SetupStream(OpKernelContext* ctx); Status CreateResource(T** resource) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override; private: absl::Mutex mu_; int n_ ABSL_GUARDED_BY(mu_); MultiEmbeddingHashTableConfig config_ ABSL_GUARDED_BY(mu_); core::RefCountPtr hash_filter_ ABSL_GUARDED_BY(mu_); core::RefCountPtr sync_client_ ABSL_GUARDED_BY(mu_); }; template <> void CreateMultiHashTableOp::SetupStream(OpKernelContext* ctx) { } template <> Status CreateMultiHashTableOp::CreateResource( MultiHashTable** resource) { auto* mtable = new MultiHashTable(cinfo_.name()); for (int i = 0; i < n_; i++) { EmbeddingHashTableTfBridge* hash_table; TF_RETURN_IF_ERROR(EmbeddingHashTableTfBridge::New( config_.configs(i), hash_filter_.get(), &hash_table, config_.names(i))); hash_table->SetHopscotchHashSet(sync_client_->GetTouchedKeySet()); mtable->add_table( config_.names(i), core::RefCountPtr(hash_table)); } sync_client_->SetMultiHashTableResource(mtable); *resource = mtable; return Status::OK(); } REGISTER_OP("CreateMonolithMultiHashTable") .Input("config: string") .Input("filter_handle: resource") .Input("sync_client_handle: resource") 
.Output("multi_hash_table_handle: resource") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("CreateMonolithMultiHashTable").Device(DEVICE_CPU), CreateMultiHashTableOp); template class ReadMultiHashTableOp : public OpKernel { public: explicit ReadMultiHashTableOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* ctx) override { absl::MutexLock l(&mu_); if (cinfo_.name().empty()) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); } OP_REQUIRES_OK( ctx, MakeResourceHandleToOutput(ctx, 0, cinfo_.container(), cinfo_.name(), TypeIndex::Make())); } private: absl::Mutex mu_; ContainerInfo cinfo_ TF_GUARDED_BY(mu_); }; REGISTER_OP("ReadMonolithMultiHashTable") .Output("multi_hash_table_handle: resource") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("ReadMonolithMultiHashTable").Device(DEVICE_CPU), ReadMultiHashTableOp); template class IsHashTableInitializedOp : public OpKernel { public: explicit IsHashTableInitializedOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* ctx) override { core::RefCountPtr mtable; Status s = LookupResource(ctx, HandleFromInput(ctx, 0), &mtable); Tensor* output_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output_tensor)); output_tensor->scalar()() = s.ok(); } }; REGISTER_OP("IsHashTableInitialized") .Input("handle: resource") .Output("is_initialized: bool") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("IsHashTableInitialized").Device(DEVICE_CPU), IsHashTableInitializedOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/multi_hash_table_save_restore_ops.cc ================================================ // Copyright 2022 
ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/random/random.h" #include "absl/strings/str_cat.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "monolith/native_training/data/training_instance/cc/reader_util.h" #include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "monolith/native_training/runtime/ops/file_utils.h" #include "monolith/native_training/runtime/ops/multi_hash_table.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/io/record_reader.h" #include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/threadpool.h" #include "third_party/nlohmann/json.hpp" namespace tensorflow { namespace monolith_tf { namespace { using tensorflow::strings::HumanReadableNumBytes; // Carries the data through async process. 
template struct AsyncPack { AsyncPack(OpKernelContext* p_ctx, core::RefCountPtr p_mtable, std::string p_basename, std::vector> p_lock_ctxs, std::function p_done, int p_thread_num) : ctx(p_ctx), basename(std::move(p_basename)), record_count(0), mtable(std::move(p_mtable)), lock_ctxs(std::move(p_lock_ctxs)), done(std::move(p_done)), thread_num(p_thread_num), finish_num(0), status(p_thread_num) {} ~AsyncPack() { for (const auto& s : status) { OP_REQUIRES_OK_ASYNC(ctx, s, done); } done(); } OpKernelContext* ctx; std::string basename; mutable std::atomic_long record_count; core::RefCountPtr mtable; std::vector> lock_ctxs; std::function done; const int thread_num; mutable std::atomic_int finish_num; mutable std::vector status; }; template struct EntryDumpIter { explicit EntryDumpIter(io::SequentialRecordReader* reader_, int64_t limit_) : reader(reader_), limit(limit_), offset(0) {} bool GetNext(const AsyncPack* p, EmbeddingHashTableTfBridge::EntryDump* dump, Status* status) { *status = Status::OK(); if (offset >= limit) return false; tstring s; *status = reader->ReadRecord(&s); if (!status->ok() || !dump->ParseFromArray(s.data(), s.size())) { *status = errors::DataLoss("Parse entry failed!"); return false; } offset++; p->record_count.fetch_add(1); return true; } io::SequentialRecordReader* reader; int64_t limit; int64_t offset; }; const char* const kShardedMetadataFileFormat = "%s.meta-%05d-of-%05d"; std::string GetShardedMetadataFileName(absl::string_view basename, int shard, int nshards) { return absl::StrFormat(kShardedMetadataFileFormat, basename, shard, nshards); } } // namespace template class MultiHashTableSaveOp : public AsyncOpKernel { public: explicit MultiHashTableSaveOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("nshards", &nshards_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("slot_expire_time_config", &slot_expire_time_config_serialized_)); if (!slot_expire_time_config_serialized_.empty()) { OP_REQUIRES( ctx, 
slot_expire_time_config_.ParseFromString( slot_expire_time_config_serialized_), errors::InvalidArgument("Unable to parse config. Make sure it " "is serialized version of " "SlotExpireTimeConfig.")); } slot_to_expire_time_.resize(get_max_slot_number(), slot_expire_time_config_.default_expire_time()); for (const auto& slot_expire_time : slot_expire_time_config_.slot_expire_times()) { slot_to_expire_time_[slot_expire_time.slot()] = slot_expire_time.expire_time(); } } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { core::RefCountPtr mtable; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &mtable), done); const Tensor& basename_tensor = ctx->input(1); const std::string basename = basename_tensor.scalar()(); const std::string dirname = std::string(io::Dirname(basename)); OP_REQUIRES_OK_ASYNC(ctx, ctx->env()->RecursivelyCreateDir(dirname), done); int real_nshards = PickNshards(*mtable); std::vector> lock_ctxs; for (int i = 0; i < mtable->size(); ++i) { std::unique_ptr lock_ctx; OP_REQUIRES_OK_ASYNC(ctx, mtable->table(i)->LockAll(&lock_ctx), done); lock_ctxs.push_back(std::move(lock_ctx)); } auto pack = std::make_shared>( ctx, std::move(mtable), basename, std::move(lock_ctxs), std::move(done), real_nshards); for (int i = 0; i < real_nshards; ++i) { ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, pack, i, real_nshards] { WorkerThread({i, real_nshards}, pack); }); } ctx->set_output(0, ctx->input(0)); } private: void WorkerThread(EmbeddingHashTableTfBridge::DumpShard shard, std::shared_ptr> p) { p->status[shard.idx] = SaveOneShard(shard, p.get()); } Status SaveOneShard(EmbeddingHashTableTfBridge::DumpShard shard, const AsyncPack* p) { const std::string filename = GetShardedFileName(p->basename, shard.idx, shard.total); const std::string meta_filename = GetShardedMetadataFileName(p->basename, shard.idx, shard.total); const std::string tmp_filename = absl::StrCat(filename, "-tmp-", random::New64()); const 
std::string tmp_meta_filename = absl::StrCat(meta_filename, "-tmp-", random::New64()); std::unique_ptr fp; TF_RETURN_IF_ERROR(p->ctx->env()->NewWritableFile(tmp_filename, &fp)); std::unique_ptr fp_meta; TF_RETURN_IF_ERROR( p->ctx->env()->NewWritableFile(tmp_meta_filename, &fp_meta)); io::RecordWriterOptions options; options.compression_type = io::RecordWriterOptions::SNAPPY_COMPRESSION; io::RecordWriterOptions options_meta; io::RecordWriter writer(fp.get(), options); io::RecordWriter meta_writer(fp_meta.get(), options_meta); Status write_status; for (int table_idx = 0; table_idx < p->mtable->size(); table_idx++) { int64_t num_entries = 0; const EmbeddingHashTableTfBridge* table = p->mtable->table(table_idx); const std::string& table_name = p->mtable->name(table_idx); int64_t max_update_ts_sec = table->max_update_ts_sec(); auto write_fn = [&](EmbeddingHashTableTfBridge::EntryDump dump) { int64_t slot_id = slot_id_v2(dump.id()); // Elements of slot_to_expire_time_ are in days. // last_update_ts_sec is seconds since the Epoch. if (max_update_ts_sec - dump.last_update_ts_sec() >= slot_to_expire_time_[slot_id] * 24 * 3600) { return true; } Status s = writer.WriteRecord(dump.SerializeAsString()); if (TF_PREDICT_FALSE(!s.ok())) { // OK to throw here since it will be catched. 
write_status = s; return false; } num_entries++; return true; }; EmbeddingHashTableTfBridge::DumpIterator iter; TF_RETURN_IF_ERROR(table->Save(p->ctx, shard, write_fn, &iter)); monolith::hash_table::MultiHashTableMetadata meta; meta.set_table_name(table_name); meta.set_num_entries(num_entries); TF_RETURN_IF_ERROR(meta_writer.WriteRecord(meta.SerializeAsString())); } TF_RETURN_IF_ERROR(writer.Close()); TF_RETURN_IF_ERROR(meta_writer.Close()); TF_RETURN_IF_ERROR(fp->Close()); TF_RETURN_IF_ERROR(fp_meta->Close()); TF_RETURN_IF_ERROR(p->ctx->env()->RenameFile(tmp_filename, filename)); TF_RETURN_IF_ERROR( p->ctx->env()->RenameFile(tmp_meta_filename, meta_filename)); return Status::OK(); } int PickNshards(const TableType& mtable) { if (nshards_ >= 0) return nshards_; int64 total_size = 0; const int64 kBaseline = 1000000ll; for (size_t i = 0; i < mtable.size(); i++) { total_size += mtable.table(i)->Size(); } return std::min(4LL, std::max(1LL, total_size / kBaseline)); } int nshards_; std::string slot_expire_time_config_serialized_; monolith::hash_table::SlotExpireTimeConfig slot_expire_time_config_; std::vector slot_to_expire_time_; }; template class MultiHashTableRestoreOp : public AsyncOpKernel { public: explicit MultiHashTableRestoreOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { core::RefCountPtr mtable; OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &mtable), done); const Tensor& basename_tensor = ctx->input(1); const std::string basename = basename_tensor.scalar()(); std::vector files; OP_REQUIRES_OK_ASYNC( ctx, ctx->env()->GetMatchingPaths(absl::StrCat(basename, "-*"), &files), done); FileSpec file_spec; OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files, &file_spec), done); OP_REQUIRES_ASYNC(ctx, file_spec.nshards() > 0, errors::NotFound("Unable to find the dump files for: ", name(), " in ", basename), done); int nshards = file_spec.nshards(); auto pack 
= std::make_shared>( ctx, std::move(mtable), basename, std::vector>(), std::move(done), nshards); for (int i = 0; i < nshards; ++i) { ctx->device()->tensorflow_cpu_worker_threads()->workers->Schedule( [this, pack, i, nshards] { WorkerThread({i, nshards}, pack); }); } ctx->set_output(0, ctx->input(0)); } private: void WorkerThread(EmbeddingHashTableTfBridge::DumpShard shard, std::shared_ptr> p) { p->status[shard.idx] = RestoreOneShard(shard, p.get()); if (p->finish_num.fetch_add(1) == p->thread_num - 1) { int64_t total_byte_size = 0, total_uncompressed_byte_size = 0, total_size = 0; for (int i = 0; i < p->mtable->size(); ++i) { auto t = p->mtable->table(i); auto name = p->mtable->name(i); auto summary = t->Summary(); LOG(INFO) << absl::StrFormat("Hash table: %s, summary: %s", name, summary); LogSummary(summary, &total_byte_size, &total_uncompressed_byte_size); total_size += t->Size(); } LOG(INFO) << absl::StrFormat( "Restore read %ld records, skip %ld zero embeddings", p->record_count, p->record_count - total_size); LOG(INFO) << absl::StrFormat( "total memory: %s, total memory if not compressed: %s", HumanReadableNumBytes(total_byte_size), HumanReadableNumBytes(total_uncompressed_byte_size)); } } Status RestoreOneShard(EmbeddingHashTableTfBridge::DumpShard shard, const AsyncPack* p) { std::string filename = GetShardedFileName(p->basename, shard.idx, shard.total); std::string meta_filename = GetShardedMetadataFileName(p->basename, shard.idx, shard.total); std::unique_ptr fp; std::unique_ptr fp_meta; TF_RETURN_IF_ERROR(p->ctx->env()->NewRandomAccessFile(filename, &fp)); TF_RETURN_IF_ERROR( p->ctx->env()->NewRandomAccessFile(meta_filename, &fp_meta)); io::RecordReaderOptions options; options.compression_type = io::RecordReaderOptions::SNAPPY_COMPRESSION; options.buffer_size = 10 * 1024 * 1024; io::SequentialRecordReader reader(fp.get(), options); io::RecordReaderOptions options_meta; io::SequentialRecordReader meta_reader(fp_meta.get(), options_meta); 
absl::flat_hash_set tables_in_shard; absl::flat_hash_map name_to_idx; for (int i = 0; i < p->mtable->size(); ++i) { name_to_idx.insert({p->mtable->name(i), i}); } bool eof = false; Status restore_status; while (!eof) { tstring meta_pb; Status meta_status = meta_reader.ReadRecord(&meta_pb); if (!meta_status.ok()) { if (errors::IsOutOfRange(meta_status)) { eof = true; break; } else { return errors::DataLoss("Read table metadata failed!"); } } monolith::hash_table::MultiHashTableMetadata meta; if (!meta.ParseFromArray(meta_pb.data(), meta_pb.size())) { return errors::DataLoss("Parse table metadata failed!"); } auto name_iter = name_to_idx.find(meta.table_name()); if (name_iter == name_to_idx.end()) { if (shard.idx == 0) { LOG(INFO) << "Table " << meta.table_name() << " in checkpoint. skipped."; } tstring dummy_str; for (int64_t i = 0; i < meta.num_entries(); i++) { TF_RETURN_IF_ERROR(reader.ReadRecord(&dummy_str)); } continue; } tables_in_shard.insert(meta.table_name()); EmbeddingHashTableTfBridge* table = p->mtable->table(name_iter->second); EntryDumpIter entry_iter(&reader, meta.num_entries()); auto get_fn = [&](EmbeddingHashTableTfBridge::EntryDump* dump, int64_t* max_update_ts) { if (!entry_iter.GetNext(p, dump, &restore_status)) return false; if (!dump->has_last_update_ts_sec()) { dump->set_last_update_ts_sec(0); } *max_update_ts = std::max(dump->last_update_ts_sec(), *max_update_ts); return true; }; TF_RETURN_IF_ERROR(table->Restore(p->ctx, shard, get_fn)); TF_RETURN_IF_ERROR(restore_status); } if (shard.idx == 0) { for (const std::string& table_name : p->mtable->names()) { if (!tables_in_shard.contains(table_name)) { LOG(WARNING) << "Table " << table_name << " not found checkpoint."; } } } if (!eof) return errors::DataLoss("Couldn't read all of checkpoint shard ", shard.idx); return Status::OK(); } private: template ::value> inline typename std::enable_if::type LogSummary( const std::string& summary, int64_t* total_byte_size, int64_t* 
total_uncompressed_byte_size) { nlohmann::json json = nlohmann::json::parse(summary); CHECK(json.contains("memory")); CHECK(json.contains("memory_if_not_compressed")); *total_byte_size += int64_t(json["memory"]); *total_uncompressed_byte_size += int64_t(json["memory_if_not_compressed"]); } template ::value> inline typename std::enable_if::type LogSummary( const std::string& summary, int64_t* total_byte_size, int64_t* total_uncompressed_byte_size) {} }; class MultiHashTableFeatureStatOp : public OpKernel { public: explicit MultiHashTableFeatureStatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& basename_tensor = ctx->input(0); const std::string basename = basename_tensor.scalar()(); std::vector files; OP_REQUIRES_OK(ctx, ctx->env()->GetMatchingPaths( absl::StrCat(basename, "-*"), &files)); OP_REQUIRES_OK(ctx, ValidateShardedFiles(basename, files)); OP_REQUIRES(ctx, !files.empty(), errors::NotFound("Unable to find the dump files for: ", name(), " in ", basename)); absl::flat_hash_map feature_count; int nshards = files.size(); for (int idx = 0; idx < nshards; ++idx) { std::string filename = GetShardedMetadataFileName(basename, idx, nshards); std::unique_ptr fp; OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &fp)); io::RecordReaderOptions options; io::SequentialRecordReader reader(fp.get(), options); bool eof = false; while (!eof) { tstring meta_pb; Status s = reader.ReadRecord(&meta_pb); if (!s.ok()) { if (errors::IsOutOfRange(s)) { eof = true; break; } else { OP_REQUIRES(ctx, s.ok(), errors::DataLoss("Read table metadata failed!")); } } monolith::hash_table::MultiHashTableMetadata meta; OP_REQUIRES(ctx, meta.ParseFromArray(meta_pb.data(), meta_pb.size()), errors::DataLoss("Parse table metadata failed!")); if (!feature_count.contains(meta.table_name())) { feature_count[meta.table_name()] = 0; } feature_count[meta.table_name()] += meta.num_entries(); } OP_REQUIRES(ctx, eof, errors::DataLoss( 
"Couldn't read all of checkpoint shard ", idx)); } int num_tables = feature_count.size(); Tensor* features; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({ num_tables, }), &features)); auto features_vec = features->vec(); Tensor* counts; OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({ num_tables, }), &counts)); auto counts_vec = counts->vec(); int feature_iter = 0; for (const auto& it : feature_count) { features_vec(feature_iter) = it.first; counts_vec(feature_iter) = it.second; feature_iter++; } } }; REGISTER_OP("MonolithMultiHashTableSave") .Input("mtable: resource") .Input("basename: string") .Output("output_mtable: resource") .Attr("nshards: int=-1") .Attr("slot_expire_time_config: string = ''") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithMultiHashTableSave").Device(DEVICE_CPU), MultiHashTableSaveOp); REGISTER_OP("MonolithMultiHashTableRestore") .Input("mtable: resource") .Input("basename: string") .Output("output_mtable: resource") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER( Name("MonolithMultiHashTableRestore").Device(DEVICE_CPU), MultiHashTableRestoreOp); REGISTER_OP("MonolithMultiHashTableFeatureStat") .Input("basename: string") .Output("features: string") .Output("counts: uint64") .SetShapeFn([](shape_inference::InferenceContext* ctx) { ctx->set_output(0, ctx->Vector(ctx->UnknownDim())); ctx->set_output(1, ctx->Vector(ctx->UnknownDim())); return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("MonolithMultiHashTableFeatureStat").Device(DEVICE_CPU), MultiHashTableFeatureStatOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/multi_hash_table_update_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "absl/types/span.h" #include "monolith/native_training/runtime/concurrency/queue.h" #include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h" #include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h" #include "monolith/native_training/runtime/ops/multi_hash_table.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace monolith_tf { using CPUDevice = Eigen::ThreadPoolDevice; namespace { using monolith::concurrency::Queue; Status MismatchLength(absl::string_view tensor_name, int tensor_size, int expected_size) { return errors::InvalidArgument("The length of tensor `", tensor_name, "` doesn't equal to table num. ", tensor_size, "v.s.", expected_size); } Status LengthTooShort(absl::string_view tensor_name, int tensor_size) { return errors::InvalidArgument("The length of tensor `", tensor_name, "` is too short. 
Currently value", tensor_size); } class MultiHashTableOptimizeOp : public OpKernel { public: explicit MultiHashTableOptimizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* c) override { core::RefCountPtr mtable; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &mtable)); auto id_vec = c->input(1).flat(); auto id_split = c->input(2).flat(); OP_REQUIRES(c, id_split.size() - 1 == mtable->size(), MismatchLength("id", id_split.size() - 1, mtable->size())); auto value_vec = c->input(3).flat(); auto learning_rate_vec = c->input(4).flat(); int64 update_time = c->input(5).scalar()(); int64 global_step = c->input(6).scalar()(); int n = mtable->size(); int value_offset = 0; int learning_rate_offset = 0; for (int i = 0; i < n; ++i) { EmbeddingHashTableTfBridge* table = mtable->table(i); const int num_ids = id_split(i + 1) - id_split(i); const int value_size = (id_split(i + 1) - id_split(i)) * table->dim_size(); OP_REQUIRES(c, value_offset + value_size <= value_vec.size(), LengthTooShort("value", value_vec.size())); auto learning_rate = absl::MakeConstSpan( learning_rate_vec.data() + learning_rate_offset, table->slice_size()); learning_rate_offset += table->slice_size(); OP_REQUIRES(c, learning_rate_offset <= learning_rate_vec.size(), LengthTooShort("learning_rate", learning_rate_vec.size())); OP_REQUIRES_OK( c, table->BatchOptimize( c, num_ids, reinterpret_cast(id_vec.data() + id_split(i)), value_vec.data() + value_offset, learning_rate, update_time, false, global_step)); value_offset += value_size; } c->set_output(0, c->input(0)); } }; REGISTER_OP("MonolithMultiHashTableOptimize") .Input("mtable: resource") .Input("id: int64") .Input("id_split: int64") .Input("value: float") .Input("learning_rate: float") .Input("update_time: int64") .Input("global_step: int64") .Output("updated_table: resource") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER( Name("MonolithMultiHashTableOptimize").Device(DEVICE_CPU), 
MultiHashTableOptimizeOp);

// NOTE(review): in this copy of the file the text between angle brackets
// (e.g. the template arguments of RefCountPtr, flat(), scalar(), vec(),
// std::vector and reinterpret_cast) appears to be missing. The code below is
// kept exactly as found; restore the arguments from the upstream source.

// Kernel of "MonolithMultiHashTableAssign".
//
// Inputs: 0 = multi hash table resource, 1 = ids of all sub-tables
// concatenated, 2 = id_split (sub-table i owns the ids in
// [id_split(i), id_split(i + 1))), 3 = values of all ids concatenated,
// 4 = update_time. Output 0 forwards the input resource handle.
class MultiHashTableAssignOp : public OpKernel {
 public:
  explicit MultiHashTableAssignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr mtable;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &mtable));
    auto id_vec = c->input(1).flat();
    auto id_split = c->input(2).flat();
    // id_split must hold exactly one more entry than there are sub-tables.
    // NOTE(review): the values of id_split are not checked against
    // id_vec.size() anywhere in this kernel.
    OP_REQUIRES(c, id_split.size() - 1 == mtable->size(),
                MismatchLength("id", id_split.size() - 1, mtable->size()));
    auto value_vec = c->input(3).flat();
    int64 update_time = c->input(4).scalar()();
    int n = mtable->size();
    // Position of the current sub-table's first value inside value_vec.
    int value_offset = 0;
    for (int i = 0; i < n; ++i) {
      EmbeddingHashTableTfBridge* table = mtable->table(i);
      const int num_ids = id_split(i + 1) - id_split(i);
      // Each id of this sub-table consumes dim_size() values.
      const int value_size = num_ids * table->dim_size();
      OP_REQUIRES(c, value_offset + value_size <= value_vec.size(),
                  LengthTooShort("value", value_vec.size()));
      OP_REQUIRES_OK(
          c, table->Assign(
                 c, num_ids,
                 reinterpret_cast(id_vec.data() + id_split(i)),
                 value_vec.data() + value_offset, update_time));
      value_offset += value_size;
    }
    c->set_output(0, c->input(0));
  }
};

REGISTER_OP("MonolithMultiHashTableAssign")
    .Input("mtable: resource")
    .Input("id: int64")
    .Input("id_split: int64")
    .Input("value: float")
    .Input("update_time: int64")
    .Output("updated_table: resource")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithMultiHashTableAssign").Device(DEVICE_CPU),
                        MultiHashTableAssignOp);

// Kernel of "MonolithMultiHashTableAssignAdd".
//
// Takes the same inputs as MultiHashTableAssignOp, but calls AssignAdd2 once
// per id, handing it a span of dim_size() values. Output 0 forwards the input
// resource handle.
class MultiHashTableAssignAddOp : public OpKernel {
 public:
  explicit MultiHashTableAssignAddOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr mtable;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &mtable));
    auto id_vec = c->input(1).flat();
    auto id_split = c->input(2).flat();
    OP_REQUIRES(c, id_split.size() - 1 == mtable->size(),
                MismatchLength("id", id_split.size() - 1, mtable->size()));
    auto value_vec = c->input(3).flat();
    int64 update_time = c->input(4).scalar()();
    int n = mtable->size();
    int value_offset = 0;
    for (int i = 0; i < n; ++i) {
      EmbeddingHashTableTfBridge* table = mtable->table(i);
      const int num_ids = id_split(i + 1) - id_split(i);
      const int value_size = num_ids * table->dim_size();
      // Checked once per sub-table, before any of its ids is applied.
      OP_REQUIRES(c, value_offset + value_size <= value_vec.size(),
                  LengthTooShort("value", value_vec.size()));
      for (int j = id_split(i); j < id_split(i + 1); ++j) {
        auto value = absl::MakeConstSpan(value_vec.data() + value_offset,
                                         table->dim_size());
        OP_REQUIRES_OK(c, table->AssignAdd2(id_vec(j), value, update_time));
        value_offset += table->dim_size();
      }
    }
    c->set_output(0, c->input(0));
  }
};

REGISTER_OP("MonolithMultiHashTableAssignAdd")
    .Input("mtable: resource")
    .Input("id: int64")
    .Input("id_split: int64")
    .Input("value: float")
    .Input("update_time: int64")
    .Output("updated_table: resource")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(
    Name("MonolithMultiHashTableAssignAdd").Device(DEVICE_CPU),
    MultiHashTableAssignAddOp);

// Kernel of "MonolithMultiHashTableReinitialize".
//
// Inputs: 0 = multi hash table resource, 1 = name of the sub-table,
// 2 = ids to reinitialize. Output 0 forwards the resource handle; output 1
// holds one status value per id (see the legend inside Compute).
// An unknown table name is only logged: the op still succeeds and every
// status stays -1.
class MultiHashTableReinitializeOp : public OpKernel {
 public:
  explicit MultiHashTableReinitializeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* c) override {
    core::RefCountPtr mtable;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &mtable));
    auto table_name = c->input(1).scalar()();
    auto id_vec = c->input(2).flat();
    Tensor* status_tensor;
    OP_REQUIRES_OK(c, c->allocate_output(1, {id_vec.size()}, &status_tensor));
    auto status_vec = status_tensor->vec();
    // -1: table_name does not exist, and the id will not be processed
    // 0: the id was inserted and is initialized
    // 1: the id was already in the table and is reinitialized
    status_vec.setConstant(-1);
    // Raw view of the status tensor's buffer, passed to Reinitialize below.
    int* status = reinterpret_cast(status_tensor->data());
    std::vector names = mtable->names();
    auto it = std::find_if(
        names.begin(), names.end(),
        [&table_name](const std::string& name) { return name == table_name; });
    if (it == names.end()) {
      LOG(ERROR) << "table " << table_name << " does not exist!";
    } else {
      // The position of the name is used as the index of the sub-table.
      int index = std::distance(names.begin(), it);
      EmbeddingHashTableTfBridge* table = mtable->table(index);
      OP_REQUIRES_OK(c, table->Reinitialize(
                            reinterpret_cast(id_vec.data()),
                            id_vec.size(), status));
    }
    c->set_output(0, c->input(0));
  }
};

REGISTER_OP("MonolithMultiHashTableReinitialize")
    .Input("mtable: resource")
    .Input("table_name: string")
    .Input("id: int64")
    .Output("updated_table: resource")
    .Output("id_status: int32")
    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
      ctx->set_output(0, ctx->Scalar());
      // id_status has the same shape as the id input.
      ctx->set_output(1, ctx->input(2));
      return Status::OK();
    });

REGISTER_KERNEL_BUILDER(
    Name("MonolithMultiHashTableReinitialize").Device(DEVICE_CPU),
    MultiHashTableReinitializeOp);

// Kernel of "MonolithMultiHashTableFusedOptimize".
//
// Reads the attributes "num_of_shards" and "enable_grad_accumulation" at
// construction time. Compute() delegates to ComputeH() and then forwards the
// input resource handle as output 0.
// NOTE(review): Compute() calls set_output() even when an OP_REQUIRES_OK
// inside ComputeH() has returned early with an error status.
template class MultiHashTableFusedOptimizeOp : public OpKernel {
 public:
  explicit MultiHashTableFusedOptimizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_of_shards", &num_shards_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_grad_accumulation",
                                     &enable_grad_accumulation_));
  }

  void ComputeH(OpKernelContext* ctx);

  void Compute(OpKernelContext* ctx) override {
    ComputeH(ctx);
    ctx->set_output(0, ctx->input(0));
  }

 private:
  bool enable_grad_accumulation_;
  int num_shards_;
};

// Specialization of ComputeH. The work is split over num_shards_ units with
// Shard(); for every (shard, table) pair one BatchOptimize call is issued on
// the slice of ids and gradients addressed by key_offsets and emb_offsets.
template <>
void MultiHashTableFusedOptimizeOp::ComputeH(OpKernelContext* ctx) {
  auto ids = ctx->input(1).vec().data();
  // NOTE(review): num_ids, indices and num_grads are not used below.
  auto num_ids = ctx->input(1).NumElements();
  auto indices = ctx->input(2).vec().data();
  auto slot_size_vec = ctx->input(3).vec().data();
  auto id_grads = ctx->input(4).vec().data();
  auto num_grads = ctx->input(4).NumElements();
  auto key_offsets = ctx->input(5).vec().data();
  auto emb_offsets = ctx->input(6).vec().data();
  auto learning_rates = ctx->input(7).vec().data();
  auto req_time = ctx->input(8).scalar()();
  auto global_step = ctx->input(9).scalar()();
  core::RefCountPtr mtable;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &mtable));
  int num_tables_ = mtable->size();
  auto optimize = [&](const int begin, const int end) {
    for (int shard_id = begin; shard_id < end; shard_id++) {
      // The learning rates of all tables are concatenated; table t uses the
      // next slice_size() entries. The offset restarts for every shard.
      int learning_rate_offset = 0;
      for (int table_id = 0; table_id < num_tables_; table_id++) {
        // Flat index into the per-(shard, table) arrays.
        int curr_idx = shard_id * num_tables_ + table_id;
        auto table = mtable->table(table_id);
        auto learning_rate = absl::MakeConstSpan(
            learning_rates + learning_rate_offset, table->slice_size());
        learning_rate_offset += table->slice_size();
        // NOTE(review): the result of BatchOptimize (if any) is not inspected.
        table->BatchOptimize(ctx, slot_size_vec[curr_idx],
                             ids + key_offsets[curr_idx],
                             id_grads + emb_offsets[curr_idx], learning_rate,
                             req_time, enable_grad_accumulation_, global_step);
      }
    }
  };
  // TODO(zouxuan): tweak this number for optimization.
  const int64 kCostPerUnit = 10000000;
  const DeviceBase::CpuWorkerThreads& worker_threads =
      *ctx->device()->tensorflow_cpu_worker_threads();
  Shard(worker_threads.num_threads, worker_threads.workers, num_shards_,
        kCostPerUnit, optimize);
}

REGISTER_OP("MonolithMultiHashTableFusedOptimize")
    .Input("mtable: resource")
    .Input("ids: int64")
    .Input("indices: int64")
    .Input("fused_slot_size: int32")
    .Input("id_grads: float")
    .Input("id_offsets: int32")
    .Input("grad_offsets: int32")
    .Input("learning_rate_tensors: float")
    .Input("req_time: int64")
    .Input("global_step: int64")
    .Output("mtable_out: resource")
    .Attr("num_of_shards: int")
    .Attr("enable_grad_accumulation: bool = false")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(
    Name("MonolithMultiHashTableFusedOptimize").Device(DEVICE_CPU),
    MultiHashTableFusedOptimizeOp);

}  // namespace
}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/net_utils.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include namespace tensorflow { namespace monolith_tf { std::multimap GetLocalIpAddreeses() { std::multimap addresses; ifaddrs *ifaddrs_list = nullptr; getifaddrs(&ifaddrs_list); for (ifaddrs *ifa = ifaddrs_list; ifa != nullptr; ifa = ifa->ifa_next) { if (!ifa->ifa_addr) { continue; } void *tmp_addr = nullptr; if (ifa->ifa_addr->sa_family == AF_INET) { // check it is IP4 tmp_addr = &(reinterpret_cast(ifa->ifa_addr)->sin_addr); char buffer[INET_ADDRSTRLEN]; inet_ntop(AF_INET, tmp_addr, buffer, INET_ADDRSTRLEN); addresses.insert({ifa->ifa_name, buffer}); } else if (ifa->ifa_addr->sa_family == AF_INET6) { // check it is IP6 // is a valid IP6 Address tmp_addr = &(reinterpret_cast(ifa->ifa_addr)->sin6_addr); char buffer[INET6_ADDRSTRLEN]; inet_ntop(AF_INET6, tmp_addr, buffer, INET6_ADDRSTRLEN); addresses.insert({ifa->ifa_name, buffer}); } } if (ifaddrs_list != nullptr) freeifaddrs(ifaddrs_list); return addresses; } std::string GetMyHostIp() { // If we are in TCE, env var will provide ip to us. char *ip = getenv("MY_HOST_IP"); if (ip != nullptr) { return ip; } auto addresses = GetLocalIpAddreeses(); auto it = addresses.find("eth0"); if (it == addresses.end()) { return ""; } return it->second; } } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/net_utils.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_NET_UTILS_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_NET_UTILS_H_ #include #include namespace tensorflow { namespace monolith_tf { // Gets network interface name to ip addresses mapping. std::multimap GetLocalIpAddreeses(); // Gets a string represents ip address of eth0. std::string GetMyHostIp(); } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_NET_UTILS_H_ ================================================ FILE: monolith/native_training/runtime/ops/net_utils_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/ops/net_utils.h"

#include "gtest/gtest.h"

namespace tensorflow {
namespace monolith_tf {

// Smoke test: only verifies that both functions can be called. The returned
// values depend on the host and are not inspected.
TEST(NetUtilsTest, Basic) {
  GetLocalIpAddreeses();
  GetMyHostIp();
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/normalize_merged_split_op.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/threadpool.h" namespace tensorflow { namespace monolith_tf { namespace { class NormalizeMergedSplitOp : public OpKernel { public: explicit NormalizeMergedSplitOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor *row_split_input; OP_REQUIRES_OK(ctx, ctx->input("row_split", &row_split_input)); const Tensor *row_split_size_input; OP_REQUIRES_OK(ctx, ctx->input("row_split_size", &row_split_size_input)); const auto row_split_vec = row_split_input->flat(); int split_num = row_split_input->dim_size(0); const auto row_split_size_vec = row_split_size_input->flat(); int merge_num = row_split_size_input->dim_size(0); int output_size = split_num + 1 - merge_num; Tensor *normed_row_split_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output("normed_row_split", TensorShape({ output_size, }), &normed_row_split_tensor)); auto normed_row_split_flat = normed_row_split_tensor->flat(); int offset = 0; int pre_size = 0; int output_idx = 0; /* row_split: 0, 2, 5, 5, 9, 0, 0, 3, 4 row_split_size: 5, 4 offset = 0, pre_size = 0 output: 0, 2, 5, 5, 9 offset = 5, pre_size = 9 output: 0, 2, 5, 5, 9, 9, 12, 13 offset = 9, pre_size = 13 */ for (size_t i = 0; i < merge_num; ++i) { if (i == 0) { for (int j = offset; j < offset + row_split_size_vec(i); ++j) { normed_row_split_flat(output_idx) = pre_size + row_split_vec(j); output_idx++; } } else { for (int j = offset + 1; j < offset + row_split_size_vec(i); ++j) { normed_row_split_flat(output_idx) = pre_size + row_split_vec(j); output_idx++; } } offset += row_split_size_vec(i); pre_size += row_split_vec(offset - 1); } } }; REGISTER_OP("MonolithNormalizeMergedSplit") .Input("row_split: 
int64") .Input("row_split_size: int32") .Output("normed_row_split: int64") .SetShapeFn([](shape_inference::InferenceContext* ctx) { ctx->set_output(0, ctx->Vector(ctx->UnknownDim())); return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithNormalizeMergedSplit").Device(DEVICE_CPU), NormalizeMergedSplitOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/parameter_sync_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// NOTE(review): in this copy of the file the first include target and the
// text between angle brackets (template arguments of ResourceOpKernel, vec(),
// scalar(), flat() and std::make_unique) appear to be missing. The code is
// kept exactly as found.
#include
#include "grpcpp/ext/proto_server_reflection_plugin.h"
#include "grpcpp/grpcpp.h"
#include "grpcpp/health_check_service_interface.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/platform/tstring.h"

#include "monolith/native_training/runtime/ops/parameter_sync_tf_bridge.h"
#include "monolith/native_training/runtime/parameter_sync/dummy_sync_client.h"
#include "monolith/native_training/runtime/parameter_sync/parameter_sync.pb.h"
#include "monolith/native_training/runtime/parameter_sync/parameter_sync_client.h"

namespace tensorflow {
namespace monolith_tf {

// Resource op that creates a DummySyncServerTfBridge for the address given by
// the "address" attribute.
class DummySyncServerOp : public ResourceOpKernel {
 public:
  explicit DummySyncServerOp(OpKernelConstruction* ctx)
      : ResourceOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("address", &address_));
  }

  ~DummySyncServerOp() override = default;

 private:
  Status CreateResource(DummySyncServerTfBridge** server_bridge)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    *server_bridge = new DummySyncServerTfBridge(address_);
    return Status::OK();
  };

  std::string address_;
};

// Shuts down the dummy sync server behind the input handle and emits a
// one-element tensor holding the constant 100.
class DummySyncServerShutdownOp : public OpKernel {
 public:
  explicit DummySyncServerShutdownOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    DummySyncServerTfBridge* server = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &server));
    // LookupResource handed out a reference; release it when leaving scope.
    core::ScopedUnref unref(server);
    server->Shutdown();
    // TODO(zhangbiao.david): remove
    LOG(INFO) << server->DebugString() << " has been shutdown";
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {1}, &output));
    auto output_vec = output->vec();
    output_vec(0) = 100;
  }
};

// Emits the port selected by the dummy sync server behind the input handle.
// NOTE(review): the output is allocated with shape {1} while the registered
// shape function below is ScalarShape.
class DummySyncServerGetPortOp : public OpKernel {
 public:
  explicit DummySyncServerGetPortOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    DummySyncServerTfBridge* server = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &server));
    core::ScopedUnref unref(server);
    int port = server->GetSelectedPort();
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {1}, &output));
    output->scalar()() = port;
  }
};

// Resource op that creates a ParameterSyncClientTfBridge flagged as a dummy
// sync client (first constructor argument is true).
class DummySyncClientOp : public ResourceOpKernel {
 public:
  explicit DummySyncClientOp(OpKernelConstruction* ctx)
      : ResourceOpKernel(ctx) {}

  ~DummySyncClientOp() override = default;

 private:
  Status CreateResource(ParameterSyncClientTfBridge** client_bridge)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    *client_bridge =
        new ParameterSyncClientTfBridge(true, [](const std::string& target) {
          return std::make_unique(
              target);
        });
    return Status::OK();
  };
};

// Resource op that creates a non-dummy ParameterSyncClientTfBridge. The
// "config" attribute must be a serialized ClientConfig; its targets are
// installed right after the bridge is created.
class ParameterSyncClientOp : public ResourceOpKernel {
 public:
  explicit ParameterSyncClientOp(OpKernelConstruction* ctx)
      : ResourceOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_serialized_));
    OP_REQUIRES(ctx, config_.ParseFromString(config_serialized_),
                errors::InvalidArgument("Unable to parse config. Make "
                                        "sure it is serialized version of "
                                        "ClientConfig"));
  }

  ~ParameterSyncClientOp() override = default;

 private:
  Status CreateResource(ParameterSyncClientTfBridge** client_bridge)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
    *client_bridge =
        new ParameterSyncClientTfBridge(false, [](const std::string& target) {
          return std::make_unique<
              monolith::parameter_sync::ParameterSyncClient>(target);
        });
    // NOTE(review): the Status returned by TryReplace is discarded here.
    (*client_bridge)
        ->TryReplace(config_.targets(), config_.targets_extra_info());
    return Status::OK();
  };

  std::string config_serialized_;
  monolith::parameter_sync::ClientConfig config_;
};

// Kernel of "MonolithParameterSync": parses a ClientConfig from the
// "config_str" input, updates the client's targets, pushes through the client
// and emits the PushResult rendered as JSON in a one-element string tensor.
class ParameterSyncOp : public OpKernel {
 public:
  explicit ParameterSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    ParameterSyncClientTfBridge* client = nullptr;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client));
    core::ScopedUnref unref(client);
    const Tensor* config_str;
    monolith::parameter_sync::ClientConfig config;
    OP_REQUIRES_OK(ctx, ctx->input("config_str", &config_str));
    // NOTE(review): element 0 is read without first checking that config_str
    // is non-empty.
    OP_REQUIRES(ctx, config.ParseFromString(config_str->flat()(0)),
                errors::InvalidArgument("Unable to parse config. Make "
                                        "sure it is serialized version of "
                                        "ClientConfig"));
    // NOTE(review): the Status returned by TryReplace is discarded here.
    client->TryReplace(config.targets(), config.targets_extra_info());
    LOG_EVERY_N_SEC(INFO, 600) << client->DebugString();
    LOG_EVERY_N_SEC(INFO, 600)
        << "ClientConfig: " << config.ShortDebugString() << std::endl;
    monolith::parameter_sync::PushResult result;
    OP_REQUIRES_OK(ctx,
                   client->Push(config.model_name(), config.signature_name(),
                                config.timeout_in_ms(), &result));
    std::string json;
    auto option = google::protobuf::util::JsonOptions();
    option.add_whitespace = true;
    option.preserve_proto_field_names = true;
    // NOTE(review): the result of MessageToJsonString is not checked.
    google::protobuf::util::MessageToJsonString(result, &json, option);
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {1}, &output));
    auto result_vec = output->vec();
    result_vec(0) = json;
  }
};

REGISTER_OP("MonolithDummySyncServer")
    .Output("handle: resource")
    .Attr("address: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithDummySyncServer").Device(DEVICE_CPU),
                        DummySyncServerOp);

REGISTER_OP("MonolithDummySyncServerShutdown")
    .Input("handle: resource")
    .Output("size: int64")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(
    Name("MonolithDummySyncServerShutdown").Device(DEVICE_CPU),
    DummySyncServerShutdownOp);

REGISTER_OP("MonolithDummySyncServerGetPort")
    .Input("handle: resource")
    .Output("size: int32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(
    Name("MonolithDummySyncServerGetPort").Device(DEVICE_CPU),
    DummySyncServerGetPortOp);

REGISTER_OP("MonolithParameterSyncClient")
    .Output("handle: resource")
    .Attr("config: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithParameterSyncClient").Device(DEVICE_CPU),
                        ParameterSyncClientOp);

REGISTER_OP("MonolithDummySyncClient")
    .Output("handle: resource")
    .Attr("config: string = ''")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithDummySyncClient").Device(DEVICE_CPU),
                        DummySyncClientOp);

REGISTER_OP("MonolithParameterSync")
    .Input("handle: resource")
    .Input("config_str: string")
    .Output("result: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithParameterSync").Device(DEVICE_CPU),
                        ParameterSyncOp);

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/parameter_sync_tf_bridge.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/ops/parameter_sync_tf_bridge.h"

#include "absl/strings/str_cat.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

using ::monolith::parameter_sync::ClientConfig_TargetExtraInfo;
using ::monolith::parameter_sync::PushRequest;
using ::monolith::parameter_sync::PushResult;

// NOTE(review): in this copy of the file the text between angle brackets
// (template arguments of std::vector, std::unordered_map, static_cast and
// RepeatedPtrField) appears to be missing. The code is kept exactly as found.

// Fills `delta` with the current embeddings of `ids` taken from `table`:
// sets the unique id to `name` and the dim size to table.dim_size(), then
// appends each id together with the dim_size values that Lookup writes for it.
void AddIdToDelta(
    const std::string& name, const EmbeddingHashTableTfBridge& table,
    const std::vector& ids,
    monolith::parameter_sync::PushRequest_DeltaEmbeddingHashTable* delta) {
  int dim_size = static_cast(table.dim_size());
  // Scratch buffer reused for every lookup.
  std::vector embedding(dim_size);
  delta->set_unique_id(name);
  delta->set_dim_size(dim_size);
  int delta_size = static_cast(ids.size());
  auto* mutable_fids = delta->mutable_fids();
  auto* embeddings = delta->mutable_embeddings();
  mutable_fids->Reserve(delta_size);
  embeddings->Reserve(delta_size * dim_size);
  for (int64_t id : ids) {
    mutable_fids->Add(id);
    // NOTE(review): the result of Lookup (if any) is not inspected.
    table.Lookup(nullptr, id, embedding.data());
    embeddings->Add(embedding.data(), embedding.data() + dim_size);
  }
}

}  // namespace

// Builds a PushRequest from the keys recorded in touched_key_set_ (which is
// cleared by this call) and sends it through sync_client_manager_ when at
// least one key was recorded. One delta entry is added per registered table,
// even when none of its keys were touched.
// Any std::exception thrown on the way is converted to InvalidArgument.
//
// NOTE(review): touched_key_set_ is dereferenced unconditionally, but the
// class only creates it for non-dummy clients - confirm that Push is never
// reached on a dummy client.
// NOTE(review): mtable_ and hash_tables_ are read here without holding mu_.
Status ParameterSyncClientTfBridge::Push(const std::string& model_name,
                                         const std::string& signature_name,
                                         int64_t timeout_in_ms,
                                         PushResult* result) const {
  try {
    PushRequest request;
    const bool is_mtable = mtable_ != nullptr;
    request.set_model_name(model_name);
    if (is_mtable) {
      // TODO(leqi.zou): Currently it is hard coded.
      // Will revisit this part later.
      request.set_signature_name(
          absl::StrCat(mtable_->shared_name(), "/raw_assign"));
      request.mutable_delta_multi_hash_tables()->Reserve(mtable_->size());
    } else {
      request.set_signature_name(signature_name);
      request.mutable_delta_hash_tables()->Reserve(hash_tables_.size());
    }
    request.set_timeout_in_ms(timeout_in_ms);

    std::vector> fids_and_tables =
        touched_key_set_->GetAndClear();
    // Group the touched ids by the table they belong to.
    std::unordered_map> table_to_fids;
    for (const auto& fid_and_table : fids_and_tables) {
      table_to_fids[fid_and_table.second].push_back(fid_and_table.first);
    }

    if (is_mtable) {
      for (int i = 0; i < mtable_->size(); ++i) {
        AddIdToDelta(mtable_->name(i), *mtable_->table(i),
                     table_to_fids[mtable_->table(i)],
                     request.mutable_delta_multi_hash_tables()->Add());
      }
    } else {
      for (const auto& kv : hash_tables_) {
        const std::string& name = kv.first;
        const auto* table = kv.second;
        AddIdToDelta(name, *table, table_to_fids[table],
                     request.mutable_delta_hash_tables()->Add());
      }
    }

    if (fids_and_tables.size() > 0) {
      *result = sync_client_manager_->Push(request, model_name, signature_name);
      LOG_EVERY_N_SEC(INFO, 600) << "Response: " << result->ShortDebugString();
    } else {
      // Nothing was touched: *result is left unmodified.
      LOG_EVERY_N_SEC(INFO, 600) << "No updated FIDs!";
    }

    return Status::OK();
  } catch (const std::exception& e) {
    return errors::InvalidArgument(e.what());
  }
}

// Forwards the new targets to sync_client_manager_. Any std::exception thrown
// by the manager is converted to InvalidArgument.
Status ParameterSyncClientTfBridge::TryReplace(
    const google::protobuf::RepeatedPtrField& targets,
    const google::protobuf::RepeatedPtrField& targets_extra_info) {
  try {
    sync_client_manager_->TryReplace(targets, targets_extra_info);
    return Status::OK();
  } catch (const std::exception& e) {
    return errors::InvalidArgument(e.what());
  }
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/parameter_sync_tf_bridge.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_PARAMETER_SYNC_TF_BRIDGE_H_
#define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_PARAMETER_SYNC_TF_BRIDGE_H_

// NOTE(review): in this copy of the file the two include targets below and the
// text between angle brackets (template arguments of std::make_unique,
// std::unique_ptr, std::function, std::map, std::vector, HopscotchHashSet and
// RepeatedPtrField) appear to be missing. The code is kept exactly as found.
// NOTE(review): absl::StrJoin is used below but only
// "absl/strings/str_format.h" is visibly included.
#include
#include

#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "monolith/native_training/runtime/ops/embedding_hash_table_tf_bridge.h"
#include "monolith/native_training/runtime/ops/multi_hash_table.h"
#include "monolith/native_training/runtime/parameter_sync/dummy_sync_server.h"
#include "monolith/native_training/runtime/parameter_sync/sync_client_interface.h"
#include "monolith/native_training/runtime/parameter_sync/sync_client_manager.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"

namespace tensorflow {
namespace monolith_tf {

// 64MB
// Capacity (in keys) passed to the touched-key set: 64 MiB divided by
// 8 * 4 bytes - presumably the per-key memory budget; confirm.
const size_t MAX_TOUCHED_KEYS = 64 * 1024 * 1024 / (8 * 4);

// TensorFlow resource that owns a DummySyncServer created for `target` and
// forwards Shutdown / GetTarget / GetSelectedPort to it.
class DummySyncServerTfBridge : public ResourceBase {
 public:
  using DummySyncServer = monolith::parameter_sync::DummySyncServer;

  explicit DummySyncServerTfBridge(const std::string& target) {
    server_ = std::make_unique(target);
  }

  void Shutdown() const { server_->Shutdown(); }

  std::string GetTarget() const { return server_->GetTarget(); }

  int GetSelectedPort() const { return server_->GetSelectedPort(); }

  std::string DebugString() const override {
    return absl::StrFormat("DummySyncServerTfBridge target = %s",
                           server_->GetTarget());
  }

 private:
  std::unique_ptr server_;
};

// TensorFlow resource that owns a SyncClientManager and the tables whose
// touched keys are pushed through it. Either a set of named hash tables or a
// single MultiHashTable can be attached, never both.
class ParameterSyncClientTfBridge : public ResourceBase {
 public:
  using SyncClientInterface = monolith::parameter_sync::SyncClientInterface;
  using PushResult = monolith::parameter_sync::PushResult;
  using SyncClientManager = monolith::parameter_sync::SyncClientManager;

  // `client_factory` is handed to the SyncClientManager. The touched-key set
  // is only created when `is_dummy_sync_client` is false.
  ParameterSyncClientTfBridge(
      bool is_dummy_sync_client,
      std::function(const std::string&)>
          client_factory)
      : is_dummy_sync_client_(is_dummy_sync_client) {
    sync_client_manager_ =
        std::make_unique(std::move(client_factory));
    if (!IsDummySyncClient()) {
      touched_key_set_ = std::move(
          std::make_unique>>(
              MAX_TOUCHED_KEYS, 1024));
    }
  }

  Status Push(const std::string& model_name, const std::string& signature_name,
              int64_t timeout_in_ms, PushResult* result) const;

  Status TryReplace(
      const google::protobuf::RepeatedPtrField& targets,
      const google::protobuf::RepeatedPtrField<
          monolith::parameter_sync::ClientConfig_TargetExtraInfo>&
          targets_extra_info);

  // Registers `hash_table` under `name`. Fails with InvalidArgument when a
  // MultiHashTable has already been set.
  // NOTE(review): the annotation says the caller must hold mu_, yet the body
  // acquires mu_ itself - the annotation looks inverted; confirm.
  Status AddHashTableResource(const std::string& name,
                              EmbeddingHashTableTfBridge* hash_table)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    absl::WriterMutexLock l(&mu_);
    DCHECK(!hash_tables_.count(name));
    if (mtable_ != nullptr) {
      return errors::InvalidArgument(
          "Only one type of tables can be set. MultiHashTable is set.");
    }
    hash_tables_[name] = hash_table;
    return Status::OK();
  }

  // Attaches `mtable`. Fails with AlreadyExists when a MultiHashTable is
  // already set and with InvalidArgument when hash tables were registered.
  // NOTE(review): same lock-annotation concern as AddHashTableResource.
  Status SetMultiHashTableResource(MultiHashTable* mtable)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    absl::WriterMutexLock l(&mu_);
    if (mtable_ != nullptr) {
      return errors::AlreadyExists(
          "The sync client is set a mtable resource already.");
    }
    if (hash_tables_.size() > 0) {
      return errors::InvalidArgument(
          "Only one type of tables can be set. HashTable is set.");
    }
    mtable_ = mtable;
    return Status::OK();
  }

  // Lists the names of the registered hash tables.
  // NOTE(review): hash_tables_ is read here without holding mu_.
  std::string DebugString() const override {
    std::vector hash_table_names;
    hash_table_names.reserve(hash_tables_.size());
    std::transform(hash_tables_.begin(), hash_tables_.end(),
                   std::back_inserter(hash_table_names),
                   [](const auto& kv) { return kv.first; });
    return absl::StrFormat("hash tables = [%s]",
                           absl::StrJoin(hash_table_names, ", "));
  }

  bool IsDummySyncClient() const { return is_dummy_sync_client_; }

  // Returns the touched-key set; null for a dummy sync client (see the
  // constructor).
  HopscotchHashSet>* GetTouchedKeySet() {
    return touched_key_set_.get();
  }

 private:
  // hash table name -> hash table resource
  std::map hash_tables_ ABSL_GUARDED_BY(mu_);

  MultiHashTable* mtable_ = nullptr;

  std::unique_ptr sync_client_manager_;

  std::unique_ptr>> touched_key_set_;

  mutable absl::Mutex mu_;

  bool is_dummy_sync_client_;
};

}  // namespace monolith_tf
}  // namespace tensorflow

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_PARAMETER_SYNC_TF_BRIDGE_H_

================================================
FILE: monolith/native_training/runtime/ops/prediction_service_grpc.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/* Copyright 2020 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "monolith/native_training/runtime/ops/prediction_service_grpc.h" #include "absl/time/clock.h" #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" #include "grpcpp/support/channel_arguments.h" namespace tensorflow { namespace monolith_tf { namespace { absl::Status FromGrpcStatus(const ::grpc::Status& s) { if (s.ok()) { return absl::Status(); } return absl::Status(static_cast(s.error_code()), s.error_message()); } int GetCallbackThreadNum() { const char* thread_num_str = std::getenv("MONOLITH_GRPC_REMOTE_CALLBACK_THREADS"); if (thread_num_str == nullptr) { return 10; } return std::stoi(std::string(thread_num_str)); } } // namespace ::grpc::CompletionQueue* GetSharedCompletionQueue() { static CompletionQueueWithThreads* cq_with_threads = new CompletionQueueWithThreads(GetCallbackThreadNum()); return cq_with_threads->GetCompletionQueue(); } CompletionQueueWithThreads::CompletionQueueWithThreads( const size_t thread_num) { queues_ = std::vector<::grpc::CompletionQueue>(thread_num); for (size_t i = 0; i < thread_num; ++i) { auto* cq = &queues_[i]; auto pooling_fn = [cq]() { void* p_tag; bool ok; while (cq->Next(&p_tag, &ok)) { RemotePredictCQTag* cq_tag = static_cast(p_tag); cq_tag->OnCompleted(ok); } }; cq_threads_.emplace_back(std::make_unique(pooling_fn)); } } CompletionQueueWithThreads::~CompletionQueueWithThreads() { for (size_t i = 0; i < queues_.size(); ++i) { queues_[i].Shutdown(); } for (size_t i = 0; i < cq_threads_.size(); ++i) { cq_threads_[i]->join(); } } 
// Returns the next completion queue in round-robin order: the shared counter
// is post-incremented and reduced modulo the number of queues.
::grpc::CompletionQueue* CompletionQueueWithThreads::GetCompletionQueue() {
  return &queues_[queue_idx_++ % queues_.size()];
}

// Creates an insecure channel to `target_address` with the maximum send and
// receive message sizes raised to INT32_MAX, and a PredictionService stub on
// top of it.
PredictionServiceGrpcPerAddress::PredictionServiceGrpcPerAddress(
    const std::string& target_address) {
  // TODO(b/159739577): Set security channel from incoming rpc request.
  // auto channel = ::grpc::CreateChannel(target_address,
  // ::grpc::InsecureChannelCredentials());
  ::grpc::ChannelArguments arg;
  arg.SetMaxReceiveMessageSize(INT32_MAX);
  arg.SetMaxSendMessageSize(INT32_MAX);
  auto channel = ::grpc::CreateCustomChannel(
      target_address, ::grpc::InsecureChannelCredentials(), arg);
  stub_ = tensorflow::serving::PredictionService::NewStub(channel);
}

// Starts an asynchronous Predict RPC on one of the shared completion queues.
// When the RPC finishes, `callback` is invoked with the converted status and a
// done-callback that deletes the ClientContext and then runs `op_done`.
//
// NOTE(review): max_rpc_deadline_millis is not used - no deadline is set on
// the ClientContext; confirm whether that is intended.
// NOTE(review): the template arguments of std::function and std::forward
// appear to be missing in this copy of the file; the code is kept as found.
void PredictionServiceGrpcPerAddress::Predict(
    tensorflow::serving::PredictRequest* request,
    tensorflow::serving::PredictResponse* response,
    std::function callback,
    int64_t max_rpc_deadline_millis, DoneCallback op_done) {
  // Owned by rpc_done below, which deletes it.
  ::grpc::ClientContext* rpc = new ::grpc::ClientContext;
  DoneCallback rpc_done = [rpc, done = op_done]() {
    delete rpc;
    done();
  };
  std::function wrapped_callback =
      [callback, rpc_done = std::move(rpc_done)](::grpc::Status status) mutable {
        callback(FromGrpcStatus(status), std::forward(rpc_done));
      };
  // The result of `new` is deliberately not stored here; presumably the tag
  // frees itself once its completion is delivered - confirm.
  new RemotePredictCQTag(GetSharedCompletionQueue(), rpc, &stub_, request,
                         response, std::move(wrapped_callback));
}

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/prediction_service_grpc.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2020 Google Inc. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_PREDICTION_SERVICE_GRPC_H_ #define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_PREDICTION_SERVICE_GRPC_H_ #include #include #include "absl/status/status.h" #include "absl/time/time.h" #include "tensorflow_serving/apis/prediction_service.grpc.pb.h" namespace tensorflow { namespace monolith_tf { class RemotePredictCQTag { public: RemotePredictCQTag( ::grpc::CompletionQueue *cq, ::grpc::ClientContext *rpc, std::unique_ptr<::tensorflow::serving::PredictionService::Stub> *stub_, ::tensorflow::serving::PredictRequest *request, ::tensorflow::serving::PredictResponse *response, std::function callback) : response_(response), callback_(std::move(callback)) { std::unique_ptr< grpc::ClientAsyncResponseReader<::tensorflow::serving::PredictResponse>> rpc_call = (*stub_)->AsyncPredict(rpc, *request, cq); rpc_call->Finish(response, &status_, reinterpret_cast(this)); }; ~RemotePredictCQTag() {} // 
OnCompleted is invoked when the RPC has finished. // Implementations of OnCompleted can delete *this. void OnCompleted(bool ok) { callback_(status_); delete this; } private: ::tensorflow::serving::PredictResponse *response_; std::function callback_; grpc::Status status_; }; class CompletionQueueWithThreads { public: explicit CompletionQueueWithThreads(const size_t thread_num); ~CompletionQueueWithThreads(); ::grpc::CompletionQueue *GetCompletionQueue(); private: std::atomic_ullong queue_idx_; std::vector<::grpc::CompletionQueue> queues_; std::vector> cq_threads_; }; ::grpc::CompletionQueue *GetSharedCompletionQueue(); // gRPC based communication point with PredictionService. class PredictionServiceGrpcPerAddress { public: using DoneCallback = std::function; explicit PredictionServiceGrpcPerAddress(const std::string &target_address); void Predict( ::tensorflow::serving::PredictRequest *request, ::tensorflow::serving::PredictResponse *response, std::function callback, int64_t max_rpc_deadline_millis, DoneCallback op_done); private: std::unique_ptr<::tensorflow::serving::PredictionService::Stub> stub_; }; class PredictionServiceGrpc { public: using DoneCallback = std::function; void Predict( tensorflow::serving::PredictRequest *request, tensorflow::serving::PredictResponse *response, std::function callback, int64_t max_rpc_deadline_millis, DoneCallback op_done) { size_t idx = std::rand() % services_.size(); services_[idx]->Predict(request, response, callback, max_rpc_deadline_millis, op_done); } explicit PredictionServiceGrpc(const std::vector &address_list) { size_t n = address_list.size(); services_.reserve(n); for (const auto &addr : address_list) { services_.push_back( std::make_unique(addr)); } } private: std::vector> services_; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_PREDICTION_SERVICE_GRPC_H_ ================================================ FILE: monolith/native_training/runtime/ops/reduce_op.cc 
================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/util/work_sharder.h" #include "monolith/native_training/runtime/hash_table/optimizer/avx_utils.h" namespace tensorflow { namespace monolith_tf { // The difference between this reduce sum op and tf.sparse.reduce_sum is that // this supports sparse values which are vectors. 
class ReduceSumOp : public OpKernel { public: explicit ReduceSumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const Tensor& id_values = ctx->input(1); const int64 value_size = id_values.shape().dim_size(1); auto id_values_mat = id_values.matrix(); const Tensor& id_dense_shape = ctx->input(2); const int64 batch_size = id_dense_shape.flat()(0); Tensor* reduced; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {batch_size, value_size}, &reduced)); std::memset(reduced->data(), 0, reduced->AllocatedBytes()); auto reduced_mat = reduced->matrix(); for (int64 i = 0; i < id_indices_mat.dimension(0); ++i) { int64 batch = id_indices_mat(i, 0); reduced_mat.chip<0>(batch) += id_values_mat.chip<0>(i); } } }; // The difference between this reduce mean op and tf.sparse.reduce_mean is that // this supports sparse values which are vectors. class ReduceMeanOp : public OpKernel { public: explicit ReduceMeanOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const Tensor& id_values = ctx->input(1); const int64 value_size = id_values.shape().dim_size(1); auto id_values_mat = id_values.matrix(); const Tensor& id_dense_shape = ctx->input(2); const int64 batch_size = id_dense_shape.flat()(0); Tensor* reduced; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {batch_size, value_size}, &reduced)); std::memset(reduced->data(), 0, reduced->AllocatedBytes()); auto reduced_mat = reduced->matrix(); std::vector counter(batch_size, 0); for (int64 i = 0; i < id_indices_mat.dimension(0); ++i) { int64 batch = id_indices_mat(i, 0); reduced_mat.chip<0>(batch) += id_values_mat.chip<0>(i); counter[batch] += 1; } for (int64 i = 0; i < batch_size; ++i) { float multiply = 1.0 / static_cast(counter[i]); for (int64 j = 0; j < value_size; ++j) { reduced_mat(i, j) 
*= multiply; } } } }; // The difference between this reduce square norm and tf.sparse.segment_sqrt_n // is that this supports sparse values which are vectors. class ReduceSquareNormOp : public OpKernel { public: explicit ReduceSquareNormOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const Tensor& id_values = ctx->input(1); const int64 value_size = id_values.shape().dim_size(1); auto id_values_mat = id_values.matrix(); const Tensor& id_dense_shape = ctx->input(2); const int64 batch_size = id_dense_shape.flat()(0); Tensor* reduced; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {batch_size, value_size}, &reduced)); std::memset(reduced->data(), 0, reduced->AllocatedBytes()); auto reduced_mat = reduced->matrix(); for (int64 i = 0; i < id_indices_mat.dimension(0); ++i) { int64 batch = id_indices_mat(i, 0); for (int64 j = 0; j < value_size; ++j) { reduced_mat(batch, j) += (id_values_mat(i, j) * id_values_mat(i, j)); } } for (int64 i = 0; i < batch_size; ++i) { for (int64 j = 0; j < value_size; ++j) { reduced_mat(i, j) = std::sqrt(reduced_mat(i, j)); } } } }; class ReduceSumGradientOp : public OpKernel { public: explicit ReduceSumGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const int64 len_ids = id_indices_mat.dimension(0); const Tensor& grads = ctx->input(1); auto grads_mat = grads.matrix(); const int64 grad_size = grads_mat.dimension(1); Tensor* id_value_grads; OP_REQUIRES_OK( ctx, ctx->allocate_output(0, {len_ids, grad_size}, &id_value_grads)); auto id_value_grads_flat = id_value_grads->flat(); // Single thread is actually more efficient. 
for (int64 i = 0; i < len_ids; ++i) { int64 batch = id_indices_mat(i, 0); std::memcpy( static_cast(id_value_grads_flat.data()) + i * grad_size, const_cast(grads_mat.data()) + batch * grad_size, sizeof(float) * grad_size); } } }; class ReduceMeanGradientOp : public OpKernel { public: explicit ReduceMeanGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const int64 len_ids = id_indices_mat.dimension(0); const Tensor& grads = ctx->input(1); auto grads_mat = grads.matrix(); const int64 grad_size = grads_mat.dimension(1); Tensor* id_value_grads; OP_REQUIRES_OK( ctx, ctx->allocate_output(0, {len_ids, grad_size}, &id_value_grads)); auto id_value_grads_mat = id_value_grads->matrix(); // grad_size equals to indices's batch size. std::vector counter(grad_size, 0); for (int64 i = 0; i < len_ids; ++i) { int64 batch = id_indices_mat(i, 0); counter[batch] += 1; } for (int64 i = 0; i < len_ids; ++i) { int64 batch = id_indices_mat(i, 0); float multiply = 1.0 / static_cast(counter[batch]); for (int64 j = 0; j < grad_size; ++j) { id_value_grads_mat(i, j) = grads_mat(batch, j) * multiply; } } } }; class ReduceSquareNormGradientOp : public OpKernel { public: explicit ReduceSquareNormGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const int64 len_ids = id_indices_mat.dimension(0); const Tensor& id_values = ctx->input(1); auto id_values_mat = id_values.matrix(); const Tensor& grads = ctx->input(2); auto grads_mat = grads.matrix(); const int64 batch_size = grads_mat.dimension(0); const int64 grad_size = grads_mat.dimension(1); Tensor* id_value_grads; OP_REQUIRES_OK( ctx, ctx->allocate_output(0, {len_ids, grad_size}, &id_value_grads)); auto id_value_grads_mat = id_value_grads->matrix(); Tensor 
reduced_values(DT_FLOAT, TensorShape({batch_size, grad_size})); std::memset(reduced_values.data(), 0, reduced_values.AllocatedBytes()); auto reduced_mat = reduced_values.matrix(); for (int64 i = 0; i < len_ids; ++i) { int64 batch = id_indices_mat(i, 0); for (int64 j = 0; j < grad_size; ++j) { reduced_mat(batch, j) += (id_values_mat(i, j) * id_values_mat(i, j)); } } for (int64 i = 0; i < batch_size; ++i) { for (int64 j = 0; j < grad_size; ++j) { reduced_mat(i, j) = std::sqrt(reduced_mat(i, j)); } } for (int64 i = 0; i < len_ids; ++i) { int64 batch = id_indices_mat(i, 0); for (int64 j = 0; j < grad_size; ++j) { // dl/dx = x/sqrt(sum(x)) * dl/dy float multiply = (reduced_mat(batch, j) == 0) ? 0.0 : id_values_mat(i, j) / reduced_mat(batch, j); id_value_grads_mat(i, j) = grads_mat(batch, j) * multiply; } } } }; // This is an extended op that supports reducesum and split in one fused op. class ReduceSumAndSplitOp : public OpKernel { public: explicit ReduceSumAndSplitOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("M", &M_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("split_dims", &split_dims_)); } void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const Tensor& id_values = ctx->input(1); const int64 value_size = id_values.shape().dim_size(1); auto id_values_mat = id_values.matrix(); const Tensor& id_dense_shape = ctx->input(2); const int64 batch_size = id_dense_shape.flat()(0); std::vector reduced_list(M_); for (int i = 0; i < M_; ++i) { OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {batch_size, split_dims_[i]}, &reduced_list[i])); std::memset(reduced_list[i]->data(), 0, reduced_list[i]->AllocatedBytes()); } for (int64 i = 0; i < id_indices_mat.dimension(0); ++i) { int64 batch = id_indices_mat(i, 0); int emb_offset = 0; for (int j = 0; j < M_; ++j) { auto reduced_mat = reduced_list[j]->matrix(); int embedding_dim = split_dims_[j]; float* input_a = 
const_cast(id_values_mat.data()) + i * value_size + emb_offset; float* output_b = static_cast(reduced_mat.data()) + batch * split_dims_[j]; ::monolith::hash_table::ReduceSum(input_a, output_b, output_b, split_dims_[j]); emb_offset += split_dims_[j]; } } } private: int M_; std::vector split_dims_; }; class ReduceSumAndSplitGradientOp : public OpKernel { public: explicit ReduceSumAndSplitGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("M", &M_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("split_dims", &split_dims_)); grad_dim_ = 0; for (int i = 0; i < M_; i++) { grad_dim_ += split_dims_[i]; } } void Compute(OpKernelContext* ctx) override { const Tensor& id_indices = ctx->input(0); auto id_indices_mat = id_indices.matrix(); const int64 len_ids = id_indices_mat.dimension(0); Tensor* id_value_grads; OP_REQUIRES_OK( ctx, ctx->allocate_output(0, {len_ids, grad_dim_}, &id_value_grads)); auto id_value_grads_flat = id_value_grads->flat(); int offset = 0; for (int i = 0; i < M_; ++i) { const Tensor& grads = ctx->input(i + 1); auto grads_mat = grads.matrix(); const int64 grad_size = grads_mat.dimension(1); CHECK(grad_size == split_dims_[i]); auto block_size = sizeof(float) * grad_size; for (int64 j = 0; j < len_ids; ++j) { int64 batch = id_indices_mat(j, 0); std::memcpy(static_cast(id_value_grads_flat.data()) + j * grad_dim_ + offset, const_cast(grads_mat.data()) + batch * grad_size, block_size); } offset += grad_size; } } private: int M_; std::vector split_dims_; int grad_dim_; }; Status ReduceShape(shape_inference::InferenceContext* ctx) { shape_inference::ShapeHandle dense_handle; TF_RETURN_IF_ERROR(ctx->MakeShapeFromShapeTensor(2, &dense_handle)); shape_inference::DimensionHandle dim0 = ctx->Dim(dense_handle, 0); shape_inference::DimensionHandle dim1 = ctx->Dim(ctx->input(1), 1); ctx->set_output(0, ctx->MakeShape({dim0, dim1})); return Status::OK(); } Status GradientReduceShape(shape_inference::InferenceContext* ctx) { 
shape_inference::DimensionHandle len_id = ctx->Dim(ctx->input(0), 0); shape_inference::DimensionHandle grad_size = ctx->Dim(ctx->input(1), 1); ctx->set_output(0, ctx->MakeShape({len_id, grad_size})); return Status::OK(); } Status FusedReduceShape(shape_inference::InferenceContext* ctx) { int M; TF_RETURN_IF_ERROR(ctx->GetAttr("M", &M)); std::vector split_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("split_dims", &split_dims)); CHECK_EQ(split_dims.size(), M); shape_inference::ShapeHandle dense_handle; TF_RETURN_IF_ERROR(ctx->MakeShapeFromShapeTensor(2, &dense_handle)); for (int i = 0; i < M; i++) { shape_inference::DimensionHandle dim0 = ctx->Dim(dense_handle, 0); ctx->set_output(i, ctx->MakeShape({dim0, split_dims[i]})); } return Status::OK(); } Status FusedGradientReduceShape(shape_inference::InferenceContext* ctx) { int M; TF_RETURN_IF_ERROR(ctx->GetAttr("M", &M)); std::vector split_dims; TF_RETURN_IF_ERROR(ctx->GetAttr("split_dims", &split_dims)); CHECK_EQ(split_dims.size(), M); int grad_dim = 0; for (int i = 0; i < M; i++) { grad_dim += split_dims[i]; } shape_inference::DimensionHandle len_id = ctx->Dim(ctx->input(0), 0); ctx->set_output(0, ctx->MakeShape({len_id, grad_dim})); return Status::OK(); } REGISTER_OP("MonolithReduceSum") .Input("id_indices: int64") .Input("id_values: float") .Input("id_dense_shape: int64") .Output("reduced: float") .SetShapeFn(ReduceShape); REGISTER_OP("MonolithReduceMean") .Input("id_indices: int64") .Input("id_values: float") .Input("id_dense_shape: int64") .Output("reduced: float") .SetShapeFn(ReduceShape); REGISTER_OP("MonolithReduceSquareNorm") .Input("id_indices: int64") .Input("id_values: float") .Input("id_dense_shape: int64") .Output("reduced: float") .SetShapeFn(ReduceShape); REGISTER_OP("MonolithReduceSumGradient") .Input("id_indices: int64") .Input("grads: float") .Output("id_values_grads: float") .SetShapeFn(GradientReduceShape); REGISTER_OP("MonolithReduceMeanGradient") .Input("id_indices: int64") .Input("grads: float") 
.Output("id_values_grads: float") .SetShapeFn(GradientReduceShape); REGISTER_OP("MonolithReduceSquareNormGradient") .Input("id_indices: int64") .Input("id_values: float") .Input("grads: float") .Output("id_values_grads: float") .SetShapeFn(GradientReduceShape); REGISTER_OP("MonolithFusedReduceSumAndSplit") .Input("id_indices: int64") .Input("id_values: float") .Input("id_dense_shape: int64") .Output("reduced: M * float") .Attr("M: int") .Attr("split_dims: list(int)") .SetShapeFn(FusedReduceShape); REGISTER_OP("MonolithFusedReduceSumAndSplitGradient") .Input("id_indices: int64") .Input("grads: M * float") .Output("output_grad: float") .Attr("M: int") .Attr("split_dims: list(int)") .SetShapeFn(FusedGradientReduceShape); REGISTER_KERNEL_BUILDER(Name("MonolithReduceSum").Device(DEVICE_CPU), ReduceSumOp); REGISTER_KERNEL_BUILDER(Name("MonolithReduceMean").Device(DEVICE_CPU), ReduceMeanOp); REGISTER_KERNEL_BUILDER(Name("MonolithReduceSquareNorm").Device(DEVICE_CPU), ReduceSquareNormOp); REGISTER_KERNEL_BUILDER( Name("MonolithFusedReduceSumAndSplit").Device(DEVICE_CPU), ReduceSumAndSplitOp); REGISTER_KERNEL_BUILDER(Name("MonolithReduceSumGradient").Device(DEVICE_CPU), ReduceSumGradientOp); REGISTER_KERNEL_BUILDER(Name("MonolithReduceMeanGradient").Device(DEVICE_CPU), ReduceMeanGradientOp); REGISTER_KERNEL_BUILDER( Name("MonolithReduceSquareNormGradient").Device(DEVICE_CPU), ReduceSquareNormGradientOp); REGISTER_KERNEL_BUILDER( Name("MonolithFusedReduceSumAndSplitGradient").Device(DEVICE_CPU), ReduceSumAndSplitGradientOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/reduce_op.cu.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #if GOOGLE_CUDA #define EIGEN_USE_GPU #include "monolith/native_training/runtime/ops/alloc_utils.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/gpu_device_array.h" #include "tensorflow/core/kernels/gpu_device_array_gpu.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace monolith_tf { typedef Eigen::GpuDevice GPUDevice; // To run many segment_sum ops on various input lengths and emb dims, // in one single GPU kernel. We define input group i: // * indices with n_i length and s_i segments, // s_i <= n_i as input_outer_dim_size; // s_i <= output_outer_dim_size also; // For example, [1,1,1,2,2,4] with n_i = 6, s_i = 3 // where output_outer_dim_size > 4 >= s_i // * values with n_i input_outer_dim_size and d_i dims // The total computation workload is sum n_i * d_i on i. // For all n_i, we stride with a fixed length k_n, so that // the same stride can have chance to reduce in local thread.
// The total gpu workload is now the sum on i of
// [(n_i // k_n) + 1] * d_i
//
// Kernel arguments (one entry per input group unless noted):
//   input_outer_dim_sizes_data   number of input rows of the group
//   inner_dim_sizes_data         number of columns of the group
//   output_outer_dim_sizes_data  number of output rows of the group
//   segment_idss_data            pointer to the segment ids of the group
//   inputs_data / outputs_data   pointers to the input / output values
//   stripe_offsets_data          first global stripe index of each group;
//                                read at indices [0, N], i.e. N + 1 entries
//   total_stripe_count           number of stripes summed over all groups
// The launch must provide dynamic shared memory for the four Index arrays
// laid out below (3 * N + (N + 1) elements).
template __global__ void FusedSortedSegmentSumCustomKernel(
    GpuDeviceArrayStruct input_outer_dim_sizes_data,
    GpuDeviceArrayStruct inner_dim_sizes_data,
    GpuDeviceArrayStruct output_outer_dim_sizes_data,
    GpuDeviceArrayStruct segment_idss_data,  // __restrict__
    GpuDeviceArrayStruct inputs_data,        // __restrict__
    GpuDeviceArrayStruct outputs_data,       // __restrict__
    GpuDeviceArrayStruct stripe_offsets_data,
    const Index total_stripe_count) {
  Index* input_outer_dim_sizes =
      GetGpuDeviceArrayOnDevice(&input_outer_dim_sizes_data);
  Index* inner_dim_sizes = GetGpuDeviceArrayOnDevice(&inner_dim_sizes_data);
  Index* output_outer_dim_sizes =
      GetGpuDeviceArrayOnDevice(&output_outer_dim_sizes_data);
  const Index* __restrict__* segment_idss =
      GetGpuDeviceArrayOnDevice(&segment_idss_data);
  const T* __restrict__* inputs = GetGpuDeviceArrayOnDevice(&inputs_data);
  T* __restrict__* outputs = GetGpuDeviceArrayOnDevice(&outputs_data);
  Index* stripe_offsets = GetGpuDeviceArrayOnDevice(&stripe_offsets_data);

  // if using shared memory
  // Ref:
  // https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/core/kernels/split_lib_gpu.cu.cc#L124
  GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(Index), unsigned char, smem);
  Index N = input_outer_dim_sizes_data.size;
  // Carve the shared buffer into four consecutive Index arrays: three of
  // length N followed by the stripe offsets (N + 1 entries are copied).
  Index* ptr = reinterpret_cast(smem);
  Index* smem_input_outer_dim_sizes = ptr;
  ptr += N;
  Index* smem_inner_dim_sizes = ptr;
  ptr += N;
  Index* smem_output_outer_dim_sizes = ptr;
  ptr += N;
  Index* smem_stripe_offsets = ptr;
  // The threads of the block cooperatively copy the per-group metadata.
  for (int x = threadIdx.x; x < N; x += blockDim.x) {
    smem_input_outer_dim_sizes[x] = input_outer_dim_sizes[x];
    smem_inner_dim_sizes[x] = inner_dim_sizes[x];
    smem_output_outer_dim_sizes[x] = output_outer_dim_sizes[x];
  }
  for (int x = threadIdx.x; x < N + 1 /*stripe_offsets_data.size*/;
       x += blockDim.x) {
    smem_stripe_offsets[x] = stripe_offsets[x];
  }
  // All copies must be visible before any thread reads the shared arrays.
  __syncthreads();
  // From here on the metadata is read through the shared-memory copies.
  stripe_offsets = smem_stripe_offsets;
  input_outer_dim_sizes = smem_input_outer_dim_sizes;
  inner_dim_sizes = smem_inner_dim_sizes;
  output_outer_dim_sizes = smem_output_outer_dim_sizes;

  // Group cursor. It only moves forward, which relies on the stripe indices
  // seen by one thread being increasing.
  Index i = 0;
  for (Index stripe_index : GpuGridRangeX(total_stripe_count)) {
    // Determine the abstract computation unit and local_stripe_index
    while (stripe_offsets[i + 1] <= stripe_index) ++i;
    Index local_stripe_index = stripe_index - stripe_offsets[i];
    auto input_outer_dim_size = input_outer_dim_sizes[i];
    auto inner_dim_size = inner_dim_sizes[i];
    auto output_outer_dim_size = output_outer_dim_sizes[i];
    if (input_outer_dim_size == 0 || inner_dim_size == 0 ||
        output_outer_dim_size == 0)
      continue;
    auto segment_ids = segment_idss[i];
    auto input = inputs[i];
    auto output = outputs[i];

    // Start computation: segment sum
    // A stripe covers one column (segment_offset) of up to OuterDimTileSize
    // consecutive input rows starting at input_outer_dim_index_base.
    const Index segment_offset = local_stripe_index % inner_dim_size;
    const Index input_outer_dim_index_base =
        local_stripe_index / inner_dim_size * Index(OuterDimTileSize);
    T sum = T(0);
    Index first_segment_id = segment_ids[input_outer_dim_index_base];
    // Starts at output_outer_dim_size so that the "moved to a larger
    // segment" test below is false for the first row of the stripe.
    Index last_output_segment_id = output_outer_dim_size;
    // The last stripe of a group may contain fewer rows than the tile size.
    const Index actual_stripe_height =
        min(Index(OuterDimTileSize),
            input_outer_dim_size - input_outer_dim_index_base);
    // #pragma unroll
    for (Index j = 0; j < actual_stripe_height; j++) {
      Index current_output_segment_id =
          segment_ids[input_outer_dim_index_base + j];
      // Decide whether to write result to global memory.
      // Result is only written to global memory if we move
      // to another segment. Otherwise we can keep accumulating
      // locally.
      if (current_output_segment_id > last_output_segment_id) {
        const Index output_index =
            last_output_segment_id * inner_dim_size + segment_offset;
        // decide whether to write result to global memory using atomic
        // operations
        if (last_output_segment_id == first_segment_id) {
          GpuAtomicAdd(output + output_index, sum);
        } else {
          *(output + output_index) = sum;
        }
        sum = T(0);
      }
      sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
                 segment_offset);
      last_output_segment_id = current_output_segment_id;
    }
    // For the last result in a strip, always write using atomic operations
    // due to possible race conditions with threads computing
    // the following strip.
    const Index output_index =
        last_output_segment_id * inner_dim_size + segment_offset;
    GpuAtomicAdd(output + output_index, sum);
  }
}

// Returns true if the three tensors have valid number of elements
// If shape_input has 0 elements, then we need to have indices and updates with
// exactly 0 elements too, otherwise we should error. If indices has 0 elements
// then updates should also have 0 elements, otherwise we should error.
bool ValidEmptyOutputShape(int64 num_inputs, int64 num_indices,
                           int64 num_updates) {
  if (num_indices == 0 && num_updates == 0) {
    return true;  // regardless of num_inputs ?= 0, covers both cases
  }
  // now we want all 3 tensors to have values
  return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
}

// GPU kernel wrapper for the fused segment-sum op: takes N (indices, updates,
// shape) triples, zero-fills N outputs and launches one
// FusedSortedSegmentSumCustomKernel over all of them.
template class FusedSegmentSumGPU : public OpKernel {
 public:
  // Reads the "N" attribute (number of fused triples).
  explicit FusedSegmentSumGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &N_));
  }

  // Validates each triple, allocates and zeroes the outputs, packs per-tensor
  // pointers and sizes into device arrays, then launches the fused kernel.
  void Compute(OpKernelContext* ctx) override {
    GPUDevice gpu_device = ctx->eigen_device();
    const int OuterDimTileSize = 8;
    // Running sum of stripes over all tensors; its final value is
    // total_stripe_count.
    Index stripe_offset = 0;  // max as total_stripe_count
    GpuDeviceArrayOnHost stripe_offsets(ctx, N_ + 1);
    OP_REQUIRES_OK(ctx, stripe_offsets.Init());
    OpInputList indices_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("indices", &indices_list));
    OpInputList updates_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("updates", &updates_list));
    OpInputList shape_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("shape", &shape_list));
    OpOutputList outputs;
    OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
    GpuDeviceArrayOnHost indices_ptrs(ctx, N_);
    // TODO(peng): concat then memcpy if necessary
    OP_REQUIRES_OK(ctx, indices_ptrs.Init());
    GpuDeviceArrayOnHost updates_ptrs(ctx, N_);
    OP_REQUIRES_OK(ctx, updates_ptrs.Init());
    GpuDeviceArrayOnHost output_ptrs(ctx, N_);
    OP_REQUIRES_OK(ctx, output_ptrs.Init());
    GpuDeviceArrayOnHost input_outer_dim_sizes(ctx, N_);
    OP_REQUIRES_OK(ctx, input_outer_dim_sizes.Init());
    GpuDeviceArrayOnHost inner_dim_sizes(ctx, N_);
    OP_REQUIRES_OK(ctx, inner_dim_sizes.Init());
    GpuDeviceArrayOnHost output_outer_dim_sizes(ctx, N_);
    OP_REQUIRES_OK(ctx, output_outer_dim_sizes.Init());
    // Shared memory used by four typed Device array.
    int smem_usage = sizeof(Index) * (4 * N_ + 1);
    for (int i = 0; i < N_; ++i) {
      const Tensor& indices = indices_list[i];
      const Tensor& updates = updates_list[i];
      const Tensor& shape_input = shape_list[i];
      OP_REQUIRES(ctx, indices.shape().dims() >= 1,
                  errors::InvalidArgument(
                      "Indices shape must have rank at least one. Found:",
                      indices.shape().DebugString()));
      OP_REQUIRES(ctx, updates.shape().dims() >= 1,
                  errors::InvalidArgument(
                      "Updates shape must have rank at least one. Found:",
                      updates.shape().DebugString()));
      // NOTE(review): the shape tensor is flattened and converted here, before
      // the "Shape must be a vector" check further down runs — confirm the
      // ordering is intentional.
      auto vec = shape_input.flat();
      TensorShape output_shape;
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &output_shape));
      OP_REQUIRES(ctx,
                  ValidEmptyOutputShape(shape_input.NumElements(),
                                        indices.shape().num_elements(),
                                        updates.shape().num_elements()),
                  errors::InvalidArgument(
                      "Indices and updates specified for empty output shape"));
      OP_REQUIRES(ctx, shape_input.dims() == 1,
                  errors::InvalidArgument("Shape must be a vector"));
      // Index input_total_size = updates.NumElements();
      auto input_shape = updates.shape();
      Index input_outer_dim_size = input_shape.dim_size(0);
      // Product of every dimension of `updates` except the first.
      Index inner_dim_size = 1;
      for (int j = 1; j < input_shape.dims(); ++j)
        inner_dim_size *= input_shape.dim_size(j);
      input_outer_dim_sizes.Set(i, input_outer_dim_size);
      inner_dim_sizes.Set(i, inner_dim_size);
      // NOTE(review): dim_size(0) presumes output_shape has at least one
      // dimension; an empty shape vector is accepted by
      // ValidEmptyOutputShape when indices and updates are empty — confirm.
      output_outer_dim_sizes.Set(i, output_shape.dim_size(0));
      stripe_offsets.Set(i, stripe_offset);
      // Each tensor contributes ceil(outer / OuterDimTileSize) * inner
      // stripes.
      Index input_outer_dim_num_stripe =
          Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
      stripe_offset += input_outer_dim_num_stripe * inner_dim_size;
      //
      Tensor* out;
      OP_REQUIRES_OK(ctx, outputs.allocate(i, output_shape, &out));
      gpu_device.memset(out->flat().data(), T(0),
                        sizeof(T) * out->NumElements());
      output_ptrs.Set(i, out->flat().data());
      updates_ptrs.Set(i, updates.flat().data());
      indices_ptrs.Set(i, indices.flat().data());
    }
    const Index total_stripe_count = stripe_offset;
    stripe_offsets.Set(N_, stripe_offset);
    OP_REQUIRES_OK(ctx, stripe_offsets.Finalize());
    OP_REQUIRES_OK(ctx, input_outer_dim_sizes.Finalize());
    OP_REQUIRES_OK(ctx, inner_dim_sizes.Finalize());
    OP_REQUIRES_OK(ctx, output_outer_dim_sizes.Finalize());
    OP_REQUIRES_OK(ctx, indices_ptrs.Finalize());
    OP_REQUIRES_OK(ctx, updates_ptrs.Finalize());
    OP_REQUIRES_OK(ctx, output_ptrs.Finalize());
    // NOTE(review): total_stripe_count is 0 when every updates tensor has zero
    // rows; verify that GetGpuLaunchConfig tolerates a zero work count.
    auto config = GetGpuLaunchConfig(total_stripe_count, gpu_device);
    GpuLaunchKernel(
        FusedSortedSegmentSumCustomKernel, config.block_count,
        config.thread_per_block, /*shared_memory_size_bytes=*/smem_usage,
        gpu_device.stream(), input_outer_dim_sizes.data(),
        inner_dim_sizes.data(), output_outer_dim_sizes.data(),
        indices_ptrs.data(), updates_ptrs.data(), output_ptrs.data(),
        stripe_offsets.data(), total_stripe_count);
  }

 private:
  int N_;  // value of the "N" attribute
};

// a struct that stores a mapping from the current column to the base address in
// the tensor after split
template struct __align__(16) ColumnInfo {
  T* __restrict__ base_address;  // this includes the current slice index
  int slice_len;
};

// Per-table record: c_offset indexes into the column-info array, r_offset into
// the row-splits array, and embs points at the table's embedding data.
template struct __align__(16) TableInfo {
  int c_offset;
  int r_offset;
  T* __restrict__ embs;
};

// given the row and column index in output,
// sum up the corresponding column of the rows that need to be reduced
template __device__ __forceinline__ void fused_reduce_and_split(
    int r_group_id, const int* __restrict__ row_prefix, ColumnInfo col_info,
    const float* __restrict__ emb, int emb_len) {
  // Rows [row_prefix[r_group_id], row_prefix[r_group_id + 1]) are summed.
  int row_end = ldg(row_prefix + r_group_id + 1);
  T sum = T(0);
  for (int r = ldg(row_prefix + r_group_id); r < row_end; r++)
    sum += ldg(emb + r * emb_len);
  col_info.base_address[r_group_id * col_info.slice_len] = sum;
}

// initialize shared memory so that
// first (N + 1) * sizeof(TableInfo) byte is the array of TableInfo
// and the following (N + 1) * sizeof(int) byte is the array of int (table
// splits)
// Returns a pointer to the table-splits region; ends with __syncthreads().
template __device__ __forceinline__ const int* init_shared_mem(
    GpuDeviceArrayStruct& _table_splits,         // N + 1
    GpuDeviceArrayStruct, 0>& _table_infos,  // N + 1
    int* shared_mem) {
  // Size of the TableInfo region, measured in ints.
  int table_info_sz = _table_splits.size *
                      (sizeof(TableInfo) / sizeof(int));
  auto s_table_infos = shared_mem;
  auto s_table_splits = shared_mem + table_info_sz;
  auto g_table_splits = GetGpuDeviceArrayOnDevice(&_table_splits);
  auto g_table_infos =
      reinterpret_cast(GetGpuDeviceArrayOnDevice(&_table_infos));
  int total_shared_sz = table_info_sz + _table_splits.size;
  // The threads of a block copy both regions cooperatively, int by int.
  for (int i = threadIdx.x; i < total_shared_sz; i += blockDim.x) {
    if (i < table_info_sz) {
      s_table_infos[i] = g_table_infos[i];
    } else {
      int j = i - table_info_sz;
      s_table_splits[j] = g_table_splits[j];
    }
  }
  __syncthreads();
  return s_table_splits;
}

// GpuDeviceArrayStruct guarantees storing addresses in global memory
// so __ldg works properly
// this kernel works the best when the number of rows to be reduced is
// relatively even across the batch dimension, so that each threads' workload is
// about the same
template __global__ void FusedReduceSumAndSplitKernel(
    GpuDeviceArrayStruct _table_splits,         // N + 1
    GpuDeviceArrayStruct, 0> _table_infos,  // N + 1
    const ColumnInfo* __restrict__ col_infos,
    const int* __restrict__ row_splits,
    int total  // =total number of elements in output (tables * n_rows after
               // reduction * emb_len)
    ) {
  extern __shared__ int shared_mem[];
  auto table_infos = reinterpret_cast*>(shared_mem);
  auto table_splits = init_shared_mem(_table_splits, _table_infos, shared_mem);
  int table_idx = 1;
  for (int outer_tid = threadIdx.x + blockIdx.x * blockDim.x; outer_tid < total;
       outer_tid += blockDim.x * gridDim.x) {
    // Advance to the first split strictly above outer_tid, then step back one
    // to get the table that owns this output element.
    while (outer_tid >= table_splits[table_idx]) table_idx++;
    table_idx -= 1;
    int idx = outer_tid - table_splits[table_idx];
    auto table = table_infos[table_idx];
    auto next_table = table_infos[table_idx + 1];
    // Embedding length is the distance between consecutive column offsets.
    int emb_len = next_table.c_offset - table.c_offset;
    int row_idx = idx / emb_len;
    int col_idx = idx % emb_len;
    fused_reduce_and_split(row_idx, row_splits + table.r_offset,
                           col_infos[table.c_offset + col_idx],  // 16-byte load
                           table.embs + col_idx, emb_len);
  }
}

// Host-side accessor for the value storage of a GpuDeviceArrayStruct: the
// inline buffer when size <= MaxInlineValues, otherwise the out-of-line
// pointer.
template const ValueType* GetGpuDeviceArrayOnHost(
    const GpuDeviceArrayStruct* data) {
  if (data->size <= MaxInlineValues) {
    return data->inline_values;
  } else {
    return data->out_of_line_values;
  }
}

// GPU kernel wrapper that reduces rows of N embedding tables according to the
// row splits in input 0 and writes the result into one output per entry of
// slice_dims.
template class FusedReduceAndSplitGPU : public OpKernel {
 public:
  // Reads the "N", "slice_dims" and "row_split_splits" attributes.
  explicit FusedReduceAndSplitGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &N_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("slice_dims", &slice_dims_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("row_split_splits", &row_split_splits_));
  }

  // Builds the table/column metadata, allocates the outputs through
  // FusedAlignedOutputAllocator and launches FusedReduceSumAndSplitKernel.
  void Compute(OpKernelContext* ctx) override {
    profiler::TraceMe activity(
        []() { return "FusedReduceAndSplitGPUPreprocessing"; });
    const auto& gpu_device = ctx->eigen_gpu_device();
    OpInputList updates_list;
    OpOutputList outputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("updates", &updates_list));
    OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
    GpuDeviceArrayOnHost table_splits(ctx, N_ + 1);
    GpuDeviceArrayOnHost, 0> table_infos(ctx, N_ + 1);
    OP_REQUIRES_OK(ctx, table_splits.Init());
    OP_REQUIRES_OK(ctx, table_infos.Init());
    // precalculate the size we need for storing row splits and column infos
    int col_info_size = 0;
    // number of rows after reduction = batch_size + 1
    // NOTE(review): indexes row_split_splits_[1] without checking the
    // attribute length here or in the constructor — confirm it is validated
    // upstream.
    int batch_size = row_split_splits_[1] - row_split_splits_[0] - 1;
    FusedAlignedOutputAllocator fao_alloc(
        ctx);
    for (int i = 0; i < N_; i++) {
      const Tensor& updates = updates_list[i];
      table_splits.Set(i, fao_alloc.get_unaligned_total());
      table_infos.Set(
          i, {col_info_size, row_split_splits_[i], updates.flat().data()});
      int emb_len = updates.dim_size(1);
      // number of rows after reduction * embedding length
      fao_alloc.add_slice(batch_size * emb_len);
      col_info_size += emb_len;
    }
    // Sentinel entry N_ closes the last table's column and row ranges.
    table_splits.Set(N_, fao_alloc.get_unaligned_total());
    table_infos.Set(N_, {col_info_size, row_split_splits_[N_], nullptr});
    OP_REQUIRES_OK(ctx, table_splits.Finalize());
    OP_REQUIRES_OK(ctx, table_infos.Finalize());
    // since these arrays are used as the backing storage,
    // we don't allow them to store values inline
    // because that will be on the host's stack and will not be
    // device-accessible
    GpuDeviceArrayOnHost, 0> col_infos(ctx, col_info_size);
    OP_REQUIRES_OK(ctx, col_infos.Init());
    // hide some latency by the h2ds above
    fao_alloc.allocate(outputs.expected_output_dtype(0));
    col_info_size = 0;
    for (int i = 0; i < slice_dims_.size(); i++) {
      int slice_len = slice_dims_[i];
      Tensor out = fao_alloc.get_slice({batch_size, slice_len});
      auto data = out.flat().data();
      outputs.set(i, std::move(out));
      // One ColumnInfo per column of the slice, each pointing at that column's
      // first element.
      for (int k = 0; k < slice_len; k++) {
        col_infos.Set(col_info_size++, {data + k, slice_len});
      }
    }
    OP_REQUIRES_OK(ctx, col_infos.Finalize());
    auto smem_sz = (sizeof(TableInfo) + sizeof(int)) * (N_ + 1);
    auto config = GetGpuLaunchConfig(fao_alloc.get_unaligned_total(),
                                     gpu_device, FusedReduceSumAndSplitKernel,
                                     smem_sz, 0);
    // The grid is shrunk by (MONOLITH_GT_OVERSUB_SM or 24) + 2 blocks.
    // NOTE(review): nothing here keeps `config.block_count - grid_offset`
    // positive; for a small launch the block count could presumably be below
    // the offset — confirm and clamp if so. The environment variable is also
    // re-read on every Compute call.
    auto grid_offset = 24;
    char* ptr = std::getenv("MONOLITH_GT_OVERSUB_SM");
    if (ptr) grid_offset = std::atoi(ptr);
    grid_offset += 2;
    GpuLaunchKernel(FusedReduceSumAndSplitKernel,
                    config.block_count - grid_offset, config.thread_per_block,
                    smem_sz, gpu_device.stream(), table_splits.data(),
                    table_infos.data(),
                    GetGpuDeviceArrayOnHost(&col_infos.data()),
                    ctx->input(0).vec().data(),  // row_splits base ptr
                    fao_alloc.get_unaligned_total());
  }

 private:
  int N_;  // value of the "N" attribute
  std::vector slice_dims_, row_split_splits_;
};

// Gradient counterpart of fused_reduce_and_split: binary-searches row_splits
// for the segment containing row r_group_id and copies that segment's value
// from the slice into emb.
template __device__ __forceinline__ void fused_reduce_and_split_grad(
    int r_group_id, const int* __restrict__ row_splits, int count,
    ColumnInfo col_info, float* __restrict__ emb, int emb_len) {
  // from https://en.cppreference.com/w/cpp/algorithm/upper_bound
  int it, step, first = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    // ldg will not be helpful here due to irregular pattern
    if (!(r_group_id < row_splits[it])) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  // `first` is now the index of the first row_splits entry greater than
  // r_group_id, so (first - 1) is the segment that contains this row.
  emb[r_group_id * emb_len] =
      col_info.base_address[(first - 1) * col_info.slice_len];
}

// workload of threads are perfectly balanced: (1 load + 1 store) * (total /
// num_threads) the bottleneck of this kernel is the binary search in the device
// function above
template __global__ void FusedReduceSumAndSplitKernelGrad(
    GpuDeviceArrayStruct _table_splits,         // N + 1
    GpuDeviceArrayStruct, 0> _table_infos,  // N + 1
    const ColumnInfo* __restrict__ col_infos,
    const int* __restrict__ row_splits,
    int total  // =total number of elements in output (tables * n_rows before
               // reduction * emb_len)
    ) {
  extern __shared__ int shared_mem[];
  auto table_infos = reinterpret_cast*>(shared_mem);
  auto table_splits = init_shared_mem(_table_splits, _table_infos, shared_mem);
  int table_idx = 1;
  for (int outer_tid = threadIdx.x + blockIdx.x * blockDim.x; outer_tid < total;
       outer_tid += blockDim.x * gridDim.x) {
    // Same table lookup as the forward kernel: first split above outer_tid,
    // minus one.
    while (outer_tid >= table_splits[table_idx]) table_idx++;
    table_idx -= 1;
    int idx = outer_tid - table_splits[table_idx];
    auto table = table_infos[table_idx];
    auto next_table = table_infos[table_idx + 1];
    int emb_len = next_table.c_offset - table.c_offset;
    // Number of row_splits entries that belong to this table.
    int row_len = next_table.r_offset - table.r_offset;
    int row_idx = idx / emb_len;
    int col_idx = idx % emb_len;
    fused_reduce_and_split_grad(
        row_idx, row_splits + table.r_offset, row_len,
        col_infos[table.c_offset + col_idx],  // 16-byte load
        table.embs + col_idx, emb_len);
  }
}

// GPU kernel wrapper for the gradient of the fused reduce-and-split op: for
// every table it produces an output shaped like the matching `updates` input,
// filled from the incoming `slices`.
template class FusedReduceAndSplitGPUGrad : public OpKernel {
 public:
  // Reads "slice_dims" and "row_split_splits"; the table count is derived from
  // the latter instead of being read from an attribute.
  explicit FusedReduceAndSplitGPUGrad(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("slice_dims", &slice_dims_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("row_split_splits", &row_split_splits_));
    N_ = row_split_splits_.size() - 1;  // =num_tables
  }

  // Builds table/column metadata, allocates the outputs and launches
  // FusedReduceSumAndSplitKernelGrad.
  void Compute(OpKernelContext* ctx) override {
    profiler::TraceMe activity(
        []() { return "FusedReduceAndSplitGPUGradPreprocessing"; });
    const auto& gpu_device = ctx->eigen_gpu_device();
    OpInputList slice_list, updates_list;
    OpOutputList outputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("updates", &updates_list));
    OP_REQUIRES_OK(ctx, ctx->input_list("slices", &slice_list));
    OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
    GpuDeviceArrayOnHost table_splits(ctx, N_ + 1);
    GpuDeviceArrayOnHost, 0> table_infos(ctx, N_ + 1);
    OP_REQUIRES_OK(ctx, table_splits.Init());
    OP_REQUIRES_OK(ctx, table_infos.Init());
    FusedAlignedOutputAllocator fao_alloc(
        ctx);
    int col_info_size = 0;
    for (int i = 0; i < N_; i++) {
      // number of rows before reduction
      int num_rows = updates_list[i].dim_size(0);
      int emb_len = updates_list[i].dim_size(1);
      table_splits.Set(i, fao_alloc.get_unaligned_total());
      fao_alloc.add_slice(num_rows * emb_len);
      col_info_size += emb_len;
    }
    // check for overflow
    // (the kernel receives the element total as an int)
    OP_REQUIRES(ctx, fao_alloc.get_unaligned_total() <= INT_MAX,
                errors::InvalidArgument("There are too many elements to be "
                                        "processed by fused reduce and split"));
    table_splits.Set(N_, fao_alloc.get_unaligned_total());
    OP_REQUIRES_OK(ctx, table_splits.Finalize());
    // col_info stores the base src address of each slice
    GpuDeviceArrayOnHost, 0> col_infos(ctx, col_info_size);
    OP_REQUIRES_OK(ctx, col_infos.Init());
    col_info_size = 0;
    for (int i = 0; i < slice_dims_.size(); i++) {
      int slice_len = slice_dims_[i];
      auto data = slice_list[i].flat().data();
      for (int j = 0; j < slice_len; j++) {
        col_infos.Set(col_info_size++, {data + j, slice_len});
      }
    }
    OP_REQUIRES_OK(ctx, col_infos.Finalize());
    // hide some latency by the h2ds above
    fao_alloc.allocate(outputs.expected_output_dtype(0));
    col_info_size = 0;
    for (int i = 0; i < N_; i++) {
      int num_rows = updates_list[i].dim_size(0);
      int emb_len = updates_list[i].dim_size(1);
      Tensor out = fao_alloc.get_slice({num_rows, emb_len});
      table_infos.Set(
          i, {col_info_size, row_split_splits_[i], out.flat().data()});
      outputs.set(i, std::move(out));
      col_info_size += emb_len;
    }
    // Sentinel entry N_ closes the last table's column and row ranges.
    table_infos.Set(N_, {col_info_size, row_split_splits_[N_], nullptr});
    OP_REQUIRES_OK(ctx, table_infos.Finalize());
    auto smem_sz = (sizeof(TableInfo) + sizeof(int)) * (N_ + 1);
    auto config = GetGpuLaunchConfig(
        fao_alloc.get_unaligned_total(), gpu_device,
        FusedReduceSumAndSplitKernelGrad, smem_sz, 0);
    GpuLaunchKernel(FusedReduceSumAndSplitKernelGrad,
                    config.block_count, config.thread_per_block, smem_sz,
                    gpu_device.stream(), table_splits.data(),
                    table_infos.data(),
                    GetGpuDeviceArrayOnHost(&col_infos.data()),
                    ctx->input(0).vec().data(),
                    fao_alloc.get_unaligned_total());
  }

 private:
  int N_;                     // number of tables, derived in the constructor
  std::vector slice_dims_;  // fused array of slice dimensions
  std::vector row_split_splits_;
};

// Registers the GPU gradient kernel; instantiated for float only below.
#define REGISTER_FUSED_REDUCE_AND_SPLIT_GRAD(type)                     \
  REGISTER_KERNEL_BUILDER(Name("MonolithFusedReduceAndSplitGPUGrad") \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint("T"),                    \
                          FusedReduceAndSplitGPUGrad)

TF_CALL_float(REGISTER_FUSED_REDUCE_AND_SPLIT_GRAD);

// Op definition for the gradient. Every output takes the shape of the
// matching `updates` input.
REGISTER_OP("MonolithFusedReduceAndSplitGPUGrad")
    .Input("splits: int32")   // input of the forward op
    .Input("updates: M * T")  // input of the forward op, needed to do shape
                              // inference
    .Input("slices: N * T")   // output of the forward op
    .Output("outputs: M * T")
    .Attr("slice_dims: list(int)")        // from the forward op
    .Attr("row_split_splits: list(int)")  // from the forward op
    .Attr("T: type")
    .Attr("N: int")
    .Attr("M: int")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int N, M;
      std::vector slice_dims, row_split_splits;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &N));
      TF_RETURN_IF_ERROR(c->GetAttr("M", &M));
      TF_RETURN_IF_ERROR(c->GetAttr("slice_dims", &slice_dims));
      TF_RETURN_IF_ERROR(c->GetAttr("row_split_splits", &row_split_splits));
      // simple sanity checks
      if (slice_dims.size() != N)
        return errors::InvalidArgument(
            "len(slice_dims) must equal to the number of input slices");
      if (row_split_splits.size() != M + 1)
        return errors::InvalidArgument(
            "len(row_split_splits) must equal to M + 1");
      // Input 0 is `splits`, so the `updates` inputs start at index 1.
      for (int i = 0; i < M; i++) {
        c->set_output(i, c->input(1 + i));
      }
      return Status::OK();
    });

// Registers the GPU forward kernel; instantiated for float only below.
#define REGISTER_FUSED_REDUCE_AND_SPLIT(type)                      \
  REGISTER_KERNEL_BUILDER(Name("MonolithFusedReduceAndSplitGPU") \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint("T"),                \
                          FusedReduceAndSplitGPU)

TF_CALL_float(REGISTER_FUSED_REDUCE_AND_SPLIT);
REGISTER_OP("MonolithFusedReduceAndSplitGPU") .Input("splits: int32") .Input("updates: N * T") .Output("outputs: num_slices * T") .Attr("num_slices: int >= 1") .Attr("slice_dims: list(int)") .Attr("row_split_splits: list(int)") .Attr("T: type") .Attr("N: int") .SetShapeFn([](shape_inference::InferenceContext* c) { int N; std::vector slice_dims, row_split_splits; TF_RETURN_IF_ERROR(c->GetAttr("N", &N)); TF_RETURN_IF_ERROR(c->GetAttr("slice_dims", &slice_dims)); TF_RETURN_IF_ERROR(c->GetAttr("row_split_splits", &row_split_splits)); // simple sanity checks int num_outputs = c->num_outputs(); if (slice_dims.size() != num_outputs) return errors::InvalidArgument( "len(slice_dims) must equal to num_slices"); int batch_size = row_split_splits[1] - row_split_splits[0] - 1; for (int i = 0; i < num_outputs; i++) { auto output_shape = c->MakeShape({batch_size, slice_dims[i]}); c->set_output(i, output_shape); } return Status::OK(); }); #define REGISTER_FUSED_SCATTER_ND_KERNEL_INDEX(type, index_type) \ REGISTER_KERNEL_BUILDER(Name("MonolithFusedSegmentSum") \ .Device(DEVICE_GPU) \ .TypeConstraint("T") \ .TypeConstraint("Tindices") \ .HostMemory("shape"), \ FusedSegmentSumGPU) #define REGISTER_FUSED_SCATTER_ND_KERNEL(type) \ REGISTER_FUSED_SCATTER_ND_KERNEL_INDEX(type, int32); \ REGISTER_FUSED_SCATTER_ND_KERNEL_INDEX(type, int64); TF_CALL_float(REGISTER_FUSED_SCATTER_ND_KERNEL); // TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUSED_SCATTER_ND_KERNEL); #undef REGISTER_FUSED_SCATTER_ND_KERNEL #undef REGISTER_FUSED_SCATTER_ND_KERNEL_INDEX REGISTER_OP("MonolithFusedSegmentSum") .Input("indices: N * Tindices") .Input("updates: N * T") .Input("shape: N * Tindices") .Output("outputs: N * T") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("N: int") .SetShapeFn([](shape_inference::InferenceContext* c) { int N; TF_RETURN_IF_ERROR(c->GetAttr("N", &N)); for (int i = N - 1; i >= 0; --i) { shape_inference::ShapeHandle indices_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(i), 1, 
&indices_shape)); shape_inference::ShapeHandle updates_shape; TF_RETURN_IF_ERROR( c->WithRankAtLeast(c->input(N + i), 1, &updates_shape)); shape_inference::ShapeHandle output_shape; TF_RETURN_IF_ERROR( c->MakeShapeFromShapeTensor(2 * N + i, &output_shape)); shape_inference::ShapeHandle expanded_indices_shape; // mimic expand_dims(indices, -1) TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, c->Vector(1), &expanded_indices_shape)); TF_RETURN_IF_ERROR(shape_inference::ScatterNdShapeHelper( c, expanded_indices_shape, updates_shape, output_shape)); // set shape to output 0 if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) { c->set_output(i, c->output(0)); } } return Status::OK(); }); } // namespace monolith_tf } // namespace tensorflow #endif // GOOGLE_CUDA ================================================ FILE: monolith/native_training/runtime/ops/remote_predict_op.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // forked from: // https://github.com/tensorflow/serving/blob/2.4.0/tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/remote_predict_op_kernel.h // https://github.com/tensorflow/serving/blob/2.4.0/tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/remote_predict_op_kernel.cc // https://github.com/tensorflow/serving/blob/2.4.0/tensorflow_serving/experimental/tensorflow/ops/remote_predict/ops/remote_predict_op.cc // with: // "#ifndef // TENSORFLOW_SERVING_EXPERIMENTAL_TENSORFLOW_OPS_REMOTE_PREDICT_KERNELS_REMOTE_PREDICT_OP_KERNEL_H_ // ..." 
removed #include "absl/status/status.h" #include "absl/time/time.h" #include "glog/logging.h" #include "google/protobuf/map.h" #include "google/protobuf/wrappers.pb.h" #include "monolith/native_training/runtime/common/metrics.h" #include "monolith/native_training/runtime/ops/agent_heartbeat.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/named_tensor.pb.h" #include "tensorflow_serving/apis/model.pb.h" #include "tensorflow_serving/apis/predict.pb.h" namespace tensorflow { namespace monolith_tf { typedef google::protobuf::Map AliasTensorMap; using ::tensorflow::serving::PredictRequest; using ::tensorflow::serving::PredictResponse; // Remote Predict Op kernel implementation class templated on different // PredictionServiceStubTypes. 
// Async kernel that packs its inputs into a PredictRequest, sends it through a
// prediction service obtained from AgentHeartbeat, and turns the
// PredictResponse into output tensors.
template class RemotePredictOp : public AsyncOpKernel {
 public:
  // Reads the op attributes and picks the model name placed in the request:
  // old_model_name when the agent reports api_version 0 and an old name was
  // given, model_name otherwise.
  explicit RemotePredictOp(OpKernelConstruction *context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("model_name", &model_name_));
    OP_REQUIRES_OK(context, context->GetAttr("model_version", &model_version_));
    OP_REQUIRES_OK(context, context->GetAttr("max_rpc_deadline_millis",
                                             &max_rpc_deadline_millis_));
    OP_REQUIRES_OK(context, context->GetAttr("fail_op_on_rpc_error",
                                             &fail_op_on_rpc_error_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("signature_name", &signature_name_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("old_model_name", &old_model_name_));
    OP_REQUIRES_OK(context, context->GetAttr("task", &task_));
    if (AgentHeartbeat::GetInstance().api_version() == 0 &&
        old_model_name_.size() > 0) {
      req_model_name_ = old_model_name_;
    } else {
      req_model_name_ = model_name_;
    }
  }

  // Builds the request, resolves a prediction service and issues the RPC.
  // `done` is invoked directly on the early-failure paths of the *_ASYNC
  // macros; otherwise it is handed to Predict() together with the callback.
  void ComputeAsync(OpKernelContext *context, DoneCallback done) override {
    // Held by the callback so the trace covers the whole RPC.
    auto activity =
        std::make_shared([this]() { return name(); });
    // NOTE(review): latencies are measured with system_clock, which is not
    // monotonic; steady_clock looks like the better fit — confirm.
    auto remote_predict_op_latency_start = std::chrono::system_clock::now();
    // Get the input tensor alias names.
    const auto &input_tensor_aliases = context->input(0).flat();
    // Get the input tensors.
    OpInputList input_tensors;
    OP_REQUIRES_OK_ASYNC(
        context, context->input_list("input_tensors", &input_tensors), done);
    // Get the output tensor alias names.
    // Directly index to output_tensor_aliases by moving past all the input
    // before it, including the input_tensor_aliases and input_tensors.
    auto output_tensor_aliases =
        context->input(1 + input_tensors.size()).flat();
    // Build the PredictRequest.
    std::shared_ptr request(new PredictRequest);
    request->mutable_model_spec()->set_name(req_model_name_);
    request->mutable_model_spec()->set_signature_name(signature_name_);
    // A negative model_version leaves the version field unset.
    if (model_version_ >= 0) {
      request->mutable_model_spec()->mutable_version()->set_value(
          model_version_);
    }
    AliasTensorMap &inputs = *request->mutable_inputs();
    // NOTE(review): indexes input_tensors by alias position with no length
    // check here — presumably guaranteed by the op's shape function; verify.
    for (int i = 0; i < input_tensor_aliases.size(); ++i) {
      tensorflow::TensorProto proto;
      input_tensors[i].AsProtoField(&proto);
      inputs[input_tensor_aliases(i)] = proto;
    }
    for (int i = 0; i < output_tensor_aliases.size(); ++i) {
      request->add_output_filter(tensorflow::string(output_tensor_aliases(i)));
    }
    std::shared_ptr response(new PredictResponse());
    std::shared_ptr prediction_service = nullptr;
    if (AgentHeartbeat::GetInstance().api_version() == 0) {
      if (old_model_name_.size() > 0) {
        // The service is looked up under the old name with its first '_'
        // replaced by ':'.
        size_t pos = old_model_name_.find("_");
        std::string real_model_name = old_model_name_;
        if (pos != std::string::npos) {
          real_model_name.replace(pos, 1, ":");
        }
        LOG_FIRST_N(INFO, 3) << "GetPredictionService by old_model_name_:"
                             << real_model_name;
        prediction_service =
            AgentHeartbeat::GetInstance().GetPredictionService(real_model_name);
      } else {
        LOG_FIRST_N(INFO, 3) << "GetPredictionService by task_:" << task_;
        prediction_service =
            AgentHeartbeat::GetInstance().GetPredictionServiceByIdx(task_);
      }
    } else {
      prediction_service =
          AgentHeartbeat::GetInstance().GetPredictionService(model_name_);
    }
    // NOTE(review): the message always prints model_name_, even when the
    // lookup above used old_model_name_ or task_.
    OP_REQUIRES_ASYNC(
        context, prediction_service != nullptr,
        errors::Unavailable("No available remote servers. model_name=",
                            model_name_, ",signature_name=", signature_name_),
        done);
    auto serving_latency_start = std::chrono::system_clock::now();
    // The callback captures request, response and prediction_service by value
    // (shared_ptr copies), which keeps them alive until the RPC completes.
    auto callback = [this, context, request, response, activity,
                     output_tensor_aliases, done, serving_latency_start,
                     remote_predict_op_latency_start, prediction_service](
        const absl::Status &status, DoneCallback &&rpc_done) {
      std::ostringstream tagkv;
      tagkv << "model_name=" << model_name_
            << "|signature_name=" << signature_name_;
      // serving_latency: from just before Predict() is called until the
      // callback runs.
      auto serving_latency_end = std::chrono::system_clock::now();
      std::chrono::duration serving_latency_diff =
          std::chrono::duration_cast(
              serving_latency_end - serving_latency_start);
      monolith::GetMetrics()->emit_timer(
          "serving_latency", serving_latency_diff.count(), tagkv.str());
      LOG_EVERY_N(INFO, 1000) << "emit_timer serving_latency " << tagkv.str();
      PostProcessResponse(context, response.get(), status,
                          fail_op_on_rpc_error_, output_tensor_aliases,
                          std::forward(rpc_done));
      // remote_predict_op_latency: from the start of ComputeAsync until the
      // response has been post-processed.
      auto remote_predict_op_latency_end = std::chrono::system_clock::now();
      std::chrono::duration remote_predict_op_latency_diff =
          std::chrono::duration_cast(
              remote_predict_op_latency_end - remote_predict_op_latency_start);
      monolith::GetMetrics()->emit_timer("remote_predict_op_latency",
                                         remote_predict_op_latency_diff.count(),
                                         tagkv.str());
      monolith::GetMetrics()->emit_counter("remote_predict_op_throughput", 1,
                                           tagkv.str());
      LOG_EVERY_N(INFO, 1000) << "emit_timer remote_predict_op_latency "
                              << tagkv.str();
    };
    // Make the RPC call.
    prediction_service->Predict(request.get(), response.get(), callback,
                                max_rpc_deadline_millis_, done);
  }

  // Writes the RPC status into outputs 0 and 1 and fills the
  // "output_tensors" list from the response. rpc_done runs exactly once:
  // through the cleanup object on normal exit, or through
  // rpc_cleaner.release() when one of the *_ASYNC macros bails out.
  void PostProcessResponse(OpKernelContext *context, PredictResponse *response,
                           const absl::Status &rpc_status,
                           bool fail_op_on_rpc_error,
                           TTypes::Flat output_tensor_aliases,
                           DoneCallback rpc_done) {
    auto rpc_cleaner = gtl::MakeCleanup([&] { rpc_done(); });
    Tensor *status_code;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, TensorShape({}), &status_code),
        rpc_cleaner.release());
    status_code->scalar()() = static_cast(rpc_status.code());
    Tensor *status_error_message;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(1, TensorShape({}), &status_error_message),
        rpc_cleaner.release());
    status_error_message->scalar()() = rpc_status.message();
    OpOutputList output_tensors_list;
    OP_REQUIRES_OK_ASYNC(
        context, context->output_list("output_tensors", &output_tensors_list),
        rpc_cleaner.release());
    // Process the response.
    if (!rpc_status.ok()) {
      if (fail_op_on_rpc_error) {
        // Fail the op with the RPC's code and message.
        OP_REQUIRES_OK_ASYNC(
            context, tensorflow::Status(static_cast(
                                            rpc_status.code()),
                                        rpc_status.message()),
            rpc_cleaner.release());
      } else {
        // Allocate some empty output for the output_tensors.
        for (int i = 0; i < output_tensors_list.size(); ++i) {
          Tensor *unused;
          OP_REQUIRES_OK_ASYNC(context, output_tensors_list.allocate(
                                            i, TensorShape({}), &unused),
                               rpc_cleaner.release());
        }
        return;
      }
    }
    OP_REQUIRES_ASYNC(
        context, output_tensors_list.size() == output_tensor_aliases.size(),
        errors::Internal(
            "Response doesn't have the right number of outputs; actual: ",
            output_tensors_list.size(), " expected: ",
            output_tensor_aliases.size()),
        rpc_cleaner.release());
    AliasTensorMap &outputs = *response->mutable_outputs();
    for (int i = 0; i < output_tensor_aliases.size(); i++) {
      Tensor output_tensor;
      OP_REQUIRES_ASYNC(
          context, output_tensor.FromProto(outputs[output_tensor_aliases(i)]),
          errors::Internal("Response tensor proto: ",
                           tensorflow::string(output_tensor_aliases(i)),
                           " cannot be converted back to a tensor."),
          rpc_cleaner.release());
      output_tensors_list.set(i, output_tensor);
    }
  }

 private:
  std::string model_name_;
  int64 model_version_;  // a negative value leaves the request version unset
  bool fail_op_on_rpc_error_;
  int64 max_rpc_deadline_millis_;
  std::string signature_name_;
  std::string old_model_name_;
  int task_;  // service index used when api_version is 0 and no old name is set
  std::string req_model_name_;  // model name actually written into the request
};

}  // namespace monolith_tf
}  // namespace tensorflow

================================================
FILE: monolith/native_training/runtime/ops/remote_predict_op_grpc.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "monolith/native_training/runtime/ops/remote_predict_op.h" #include "monolith/native_training/runtime/ops/prediction_service_grpc.h" namespace tensorflow { namespace monolith_tf { namespace { REGISTER_KERNEL_BUILDER( Name("TfServingRemotePredict").Device(DEVICE_CPU), RemotePredictOp<::tensorflow::monolith_tf::PredictionServiceGrpc, ::tensorflow::monolith_tf::AgentHeartbeat< ::tensorflow::monolith_tf::PredictionServiceGrpc>>); REGISTER_OP("TfServingRemotePredict") .Attr("T: list(type)") .Attr("model_name: string = ''") .Attr("model_version: int = -1") .Attr("fail_op_on_rpc_error: bool = true") .Attr("max_rpc_deadline_millis: int = 30") .Attr("signature_name: string = 'serving_default'") .Attr("old_model_name: string = ''") .Attr("task: int = -1") .Input("input_tensor_aliases: string") .Input("input_tensors: T") .Input("output_tensor_aliases: string") .Output("status_code: int32") .Output("status_error_message: string") .Output("output_tensors: output_types") .Attr("output_types: list(type)") .SetShapeFn([](shape_inference::InferenceContext *c) { shape_inference::ShapeHandle unused; // Checks the length of input_tensor_aliases with that of input_tensors. 
std::vector input_aliases_handle; TF_RETURN_IF_ERROR( c->input("input_tensor_aliases", &input_aliases_handle)); TF_RETURN_IF_ERROR(c->WithRank(input_aliases_handle[0], 1, &unused)); std::vector inputs_handle; TF_RETURN_IF_ERROR(c->input("input_tensors", &inputs_handle)); if (c->Value(c->NumElements(input_aliases_handle[0])) != inputs_handle.size()) { return errors::InvalidArgument( "'input_tensors' should be equal in length to " "'input_tensor_aliases'. Length of 'input_tensors': ", inputs_handle.size(), ", length of 'input_tensor_aliases': ", c->Value(c->NumElements(input_aliases_handle[0]))); } // Checks the length of output_tensor_aliases with that of output_types. DataTypeVector output_types; TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types)); std::vector output_aliases_handle; TF_RETURN_IF_ERROR( c->input("output_tensor_aliases", &output_aliases_handle)); if (c->Value(c->NumElements(output_aliases_handle[0])) != output_types.size()) { return errors::InvalidArgument( "'output_types' should be equal in length to " "'output_tensor_aliases'. Length of 'output_types': ", output_types.size(), ", length of 'output_tensor_aliases': ", c->Value(c->NumElements(output_aliases_handle[0]))); } // We know the shape of the first 2 outputs, but not the rest. TF_RETURN_IF_ERROR(c->set_output("status_code", {c->Scalar()})); TF_RETURN_IF_ERROR(c->set_output("status_error_message", {c->Scalar()})); for (int i = 2; i < c->num_outputs(); ++i) { c->set_output(i, c->UnknownShape()); } return Status::OK(); }) .SetIsStateful() .Doc(R"doc( Invokes Predict on a remote graph. fail_op_on_rpc_error: If set true, the Op fails if the rpc fails, and returns the status code as 0 and an empty status_message. Otherwise the Op returns the status of the rpc call, along with the output tensors, if any. Set true by default. max_rpc_deadline_millis: The rpc deadline for remote predict. The actual deadline is min(incoming_rpc_deadline, max_rpc_deadline_millis). 
task: The task id of ps/entry, if the server_type is entry, it would be 0. signature_name: the signature def for remote graph inference, defaulting to "serving_default". model_name: Model name of the remote TF graph. model_version: the target version for the Predict call. When unset, the default value (-1) implies the latest available version should be used. input_tensor_aliases: Tensor of strings for the input tensor alias names to supply to the RemotePredict call. input_tensors: List of tensors to provide as input. Should be equal in length to 'input_tensor_aliases'. output_tensor_aliases: Tensor of strings for the output tensor alias names to supply to the Predict call. status_code: Returns the status code of the rpc call; basically converting tensorflow::error::Code to it's int value, so 0 means OK. status_error_message: Returns the error message in the rpc status. output_tensors: Tensors returned by the Predict call on the remote graph, which are in the same order as output_tensor_aliases. output_types: A list of types of the output tensors. Length of this list should be equal to the length of 'output_tensor_aliases'. old_model_name: the model name used in previous monolith version (deprecated). task: the index of old_model_name (deprecated). )doc"); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/split_by_indices_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/threadpool.h" namespace tensorflow { namespace monolith_tf { namespace { // Given an int64 tensor, split it into multiple tensors based on the value. template class SplitByIndicesOp : public OpKernel { public: explicit SplitByIndicesOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_splits", &num_splits_)); } void Compute(OpKernelContext* ctx) override { const Tensor& indices = ctx->input(0); auto indices_vec = indices.vec(); const int64 num_elements = indices.NumElements(); const Tensor& input = ctx->input(1); const int64 element_size = num_elements == 0 ? 
0 : input.NumElements() / num_elements; auto input_mat = input.shaped({num_elements, element_size}); // Here we use a naive implementation (No parallel) std::vector split_sizes(num_splits_, 0); for (int i = 0; i < num_elements; ++i) { ++split_sizes[indices_vec(i)]; } std::vector splitted(num_splits_); for (int i = 0; i < num_splits_; ++i) { TensorShape output_shape; output_shape.AddDim(split_sizes[i]); const TensorShape& shape = input.shape(); for (int j = 1; j < shape.dims(); ++j) { output_shape.AddDim(shape.dim_size(j)); } OP_REQUIRES_OK(ctx, ctx->allocate_output(i, output_shape, &splitted[i])); } std::vector::Matrix> splitted_mat; splitted_mat.reserve(num_splits_); for (int i = 0; i < num_splits_; ++i) { splitted_mat.emplace_back( splitted[i]->shaped({split_sizes[i], element_size})); } for (int64 i = num_elements - 1; i >= 0; --i) { int64 index = indices_vec(i); splitted_mat[index].template chip<0>(--split_sizes[index]) = input_mat.template chip<0>(i); } } private: int num_splits_; }; class SplitByIndicesGradientOp : public OpKernel { public: explicit SplitByIndicesGradientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_splits", &num_splits_)); } void Compute(OpKernelContext* ctx) override { const Tensor& indices = ctx->input(0); auto indices_vec = indices.vec(); const int64 num_elements = indices.NumElements(); std::vector split_sizes(num_splits_, 0); for (int64 i = 0; i < num_elements; ++i) { ++split_sizes[indices_vec(i)]; } const Tensor& input = ctx->input(1); const int64 element_size = num_elements == 0 ? 
0 : input.NumElements() / num_elements; const int grads_offset = 2; std::vector::Matrix> grads_mat; grads_mat.reserve(num_splits_); for (int i = 0; i < num_splits_; ++i) { const Tensor& grad = ctx->input(grads_offset + i); grads_mat.emplace_back( grad.shaped({split_sizes[i], element_size})); } Tensor* input_grads; TensorShape input_grads_shape = ctx->input(grads_offset).shape(); input_grads_shape.set_dim(0, num_elements); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input_grads_shape, &input_grads)); auto input_grads_mat = input_grads->shaped({num_elements, element_size}); for (int64 i = num_elements - 1; i >= 0; --i) { int64 index = indices_vec(i); input_grads_mat.chip<0>(i) = grads_mat[index].chip<0>(--split_sizes[index]); } } private: int num_splits_; }; template class ReorderByIndicesOp : public OpKernel { public: explicit ReorderByIndicesOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_of_shards", &num_of_shards_)); } void Compute(OpKernelContext* ctx) override { const Tensor& input = ctx->input(0); const Tensor& shard_ids = ctx->input(1); auto shard_id_vec = shard_ids.vec(); const int64 num_elements = shard_ids.NumElements(); // NOTE: element_size should always be 1 except for the test where we use // assign-add for initial assignment. const int64 element_size = num_elements == 0 ? 0 : input.NumElements() / num_elements; auto input_mat = input.shaped({num_elements, element_size}); // We first count the number of unique FIDs, and also calculate the size for // each shard. typename absl::flat_hash_set id_set; std::vector splits_offsets(num_of_shards_, 0); std::vector> ids_for_splits(num_of_shards_, std::vector{}); int64 uniq_id_size = 0; for (int i = 0; i < num_elements; ++i) { // First insertion if never sees it before. if (id_set.insert(input_mat(i, 0)).second) { auto index = shard_id_vec(i); ids_for_splits[index].emplace_back(i); ++(splits_offsets[index]); ++uniq_id_size; } } // We allocate the buffer here. 
TensorShape output_shape; output_shape.AddDim(uniq_id_size); const TensorShape& shape = input.shape(); for (int j = 1; j < shape.dims(); ++j) { output_shape.AddDim(shape.dim_size(j)); } Tensor *outputs, *output_sizes; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &outputs)); OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape{num_of_shards_}, &output_sizes)); // We assign the split sizes here. auto output_shape_vec = output_sizes->vec(); for (int i = 0; i < num_of_shards_; i++) { output_shape_vec(i) = splits_offsets[i]; if (i > 0) splits_offsets[i] += splits_offsets[i - 1]; } // We assign the reordered IDs here. typename TTypes::Matrix splitted_mat = outputs->shaped({uniq_id_size, element_size}); for (int i = num_of_shards_ - 1; i >= 0; --i) { for (const int64& j : ids_for_splits[i]) { splitted_mat.template chip<0>(--splits_offsets[i]) = input_mat.template chip<0>(j); } } } private: int num_of_shards_; }; class RaggedSplitByIndicesOp : public OpKernel { public: explicit RaggedSplitByIndicesOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("num_splits", &num_splits_)); } void Compute(OpKernelContext* c) override { const auto indices = c->input(0).vec(); const auto num = c->input(1).vec(); const auto num_split = c->input(2).vec(); OP_REQUIRES(c, num.size() == indices.size(), errors::InvalidArgument( "Ragged tensor values size must match indices size. Got ", num.size(), " v.s. 
", indices.size())); std::vector split_sizes(num_splits_, 0); for (int64 i = 0; i < indices.size(); ++i) { ++split_sizes[indices(i)]; } std::vector splitted_offsets(num_splits_, 0); std::vector::Vec> splitted_nums; std::vector::Vec> splitted_num_splits; std::vector::Vec> splitted_pos; for (int i = 0; i < num_splits_; ++i) { Tensor* t; OP_REQUIRES_OK(c, c->allocate_output(i, {split_sizes[i]}, &t)); splitted_nums.push_back(t->vec()); OP_REQUIRES_OK( c, c->allocate_output(i + num_splits_, {num_split.size()}, &t)); auto splitted_num_split = t->vec(); splitted_num_split(0) = 0; splitted_num_splits.push_back(splitted_num_split); OP_REQUIRES_OK( c, c->allocate_output(i + 2 * num_splits_, {split_sizes[i]}, &t)); splitted_pos.push_back(t->vec()); } int split_offset = 1; for (int64 i = 0;; ++i) { while (split_offset < num_split.size() && i == num_split(split_offset)) { for (int index = 0; index < num_splits_; ++index) { splitted_num_splits[index](split_offset) = splitted_offsets[index]; } ++split_offset; } if (i >= indices.size()) break; int64 index = indices(i); splitted_pos[index](splitted_offsets[index]) = i; splitted_nums[index](splitted_offsets[index]) = num(i); ++splitted_offsets[index]; } OP_REQUIRES(c, split_offset == num_split.size(), errors::InvalidArgument("The input ragged tensor is invalid.")); } private: int num_splits_; }; REGISTER_OP("MonolithRaggedSplitByIndices") .Input("indices: int64") .Input("num: int64") .Input("num_split: int64") .Output("splitted_nums: num_splits * int64") .Output("splitted_num_splits: num_splits * int64") .Output("splitted_pos: num_splits * int64") .Attr("num_splits: int") .SetShapeFn([](shape_inference::InferenceContext* c) { int num_splits; TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits)); int offset = 0; for (int i = 0; i < num_splits; ++i) { c->set_output(i, c->Vector(c->UnknownDim())); } offset += num_splits; for (int i = 0; i < num_splits; ++i) { c->set_output(offset + i, c->input(2)); } offset += num_splits; for (int 
i = 0; i < num_splits; ++i) { c->set_output(offset + i, c->Vector(c->UnknownDim())); } return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithRaggedSplitByIndices").Device(DEVICE_CPU), RaggedSplitByIndicesOp); REGISTER_OP("MonolithSplitByIndices") .Input("indices: int64") .Input("input: T") .Output("splitted: num_splits * T") .Attr("num_splits: int") .Attr("T: type") .SetShapeFn([](shape_inference::InferenceContext* c) { int num_splits; TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits)); shape_inference::ShapeHandle shape_handle = c->input(1); int rank = c->Rank(shape_handle); std::vector dim_handles; dim_handles.push_back(c->UnknownDim()); for (int i = 1; i < rank; ++i) { dim_handles.push_back(c->Dim(shape_handle, i)); } for (int i = 0; i < num_splits; ++i) { c->set_output(i, c->MakeShape(dim_handles)); } return Status::OK(); }); REGISTER_OP("MonolithSplitByIndicesGradient") .Input("indices: int64") .Input("input: float") .Input("grads: num_splits * float") .Output("input_grads: float") .Attr("num_splits: int") .SetShapeFn([](shape_inference::InferenceContext* c) { int num_splits; TF_RETURN_IF_ERROR(c->GetAttr("num_splits", &num_splits)); shape_inference::DimensionHandle num_elements = c->Dim(c->input(0), 0); shape_inference::ShapeHandle shape_handle = c->input(1); int rank = c->Rank(shape_handle); std::vector dim_handles; dim_handles.push_back(num_elements); for (int i = 1; i < rank; ++i) { dim_handles.push_back(c->Dim(shape_handle, i)); } c->set_output(0, c->MakeShape(dim_handles)); return Status::OK(); }); REGISTER_OP("MonolithReorderByIndices") .Input("input: T") .Input("shard_ids: int32") .Output("reordered_tensor: T") .Output("shard_sizes: int32") .Attr("num_of_shards: int") .Attr("T: type") .SetDoNotOptimize() // Crash with grappler. 
.SetShapeFn([](shape_inference::InferenceContext* c) { int num_of_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_of_shards", &num_of_shards)); shape_inference::ShapeHandle shape_handle = c->input(0); int rank = c->Rank(shape_handle); std::vector dim_handles; dim_handles.push_back(c->UnknownDim()); for (int i = 1; i < rank; ++i) { dim_handles.push_back(c->Dim(shape_handle, i)); } // The first output is for the reshuffled first element. c->set_output(0, c->MakeShape(dim_handles)); // The second output is for the all2all sizes. c->set_output(1, c->MakeShape({num_of_shards})); return Status::OK(); }); #define REGISTER_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("MonolithSplitByIndices") \ .Device(DEVICE_CPU) \ .TypeConstraint("T"), \ SplitByIndicesOp) REGISTER_KERNEL(float); REGISTER_KERNEL(int64); #undef REGISTER_KERNEL REGISTER_KERNEL_BUILDER( Name("MonolithSplitByIndicesGradient").Device(DEVICE_CPU), SplitByIndicesGradientOp); #define REGISTER_KERNEL_REORDER_BY_INDICES(type) \ REGISTER_KERNEL_BUILDER(Name("MonolithReorderByIndices") \ .Device(DEVICE_CPU) \ .TypeConstraint("T"), \ ReorderByIndicesOp) REGISTER_KERNEL_REORDER_BY_INDICES(float); REGISTER_KERNEL_REORDER_BY_INDICES(int64); #undef REGISTER_KERNEL_REORDER_BY_INDICES } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/static_reshape_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace monolith_tf { class StaticReshapeNOp : public OpKernel { public: explicit StaticReshapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("enable_parallelism", &enable_parallelism_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("cost_per_tensor", &cost_per_tensor_)); } void Compute(OpKernelContext* ctx) override { int num_inputs = ctx->num_inputs(); OP_REQUIRES(ctx, shapes_.size() == num_inputs, errors::InvalidArgument( "`shapes` size must equal to `inputs`, got shapes (", shapes_.size(), ") vs `inputs` (", num_inputs, ")")); Tensor* tensor_sizes = nullptr; OP_REQUIRES_OK( ctx, ctx->allocate_output(num_inputs, {num_inputs}, &tensor_sizes)); auto sizes_vec = tensor_sizes->vec(); auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers; auto reshape = [&](int64_t begin, int64_t end) { for (int idx = static_cast(begin); idx < static_cast(end); idx++) { const Tensor& input = ctx->input(idx); int tensor_size = input.NumElements(); TensorShape shape; const PartialTensorShape& partial = shapes_.at(idx); // Maybe infer unk dim. 
int64_t product = 1; int unknown_index = -1; OP_REQUIRES(ctx, partial.dims() > 0, errors::InvalidArgument( "Shape cannot be unknown rank for input [", idx, "]!")); for (int d = 0; d < partial.dims(); d++) { int dim = partial.dim_size(d); if (dim == -1) { OP_REQUIRES( ctx, unknown_index == -1, errors::InvalidArgument( "Only one input size may be -1, not both ", unknown_index, " and ", d, "for input [", idx, "]!")); unknown_index = d; shape.AddDim(1); } else { shape.AddDim(dim); product *= dim; } } if (unknown_index != -1) { if (product == 0) { // In this case, tensor_size should be 0. // Check will perform later. shape.set_dim(unknown_index, 0); } else { OP_REQUIRES(ctx, tensor_size % product == 0, errors::InvalidArgument( "Input[", idx, "] of size ", tensor_size, " cannot be reshaped as ", shape.DebugString())); shape.set_dim(unknown_index, tensor_size / product); } } OP_REQUIRES( ctx, input.NumElements() == shape.num_elements(), errors::InvalidArgument( "Input[", idx, "] to reshape is a tensor with ", input.NumElements(), " values, but the requested shape has ", shape.num_elements())); Tensor output(input.dtype()); CHECK(output.CopyFrom(input, shape)); ctx->set_output(idx, output); sizes_vec(idx) = shape.num_elements(); } }; if (enable_parallelism_) { thread_pool->ParallelFor(num_inputs, cost_per_tensor_, reshape); } else { reshape(0, num_inputs); } } private: std::vector shapes_; bool enable_parallelism_; int64 cost_per_tensor_; }; REGISTER_OP("MonolithStaticReshapeN") .Input("inputs: dtypes") .Output("outputs: dtypes") .Output("sizes: int64") .Attr("dtypes: list(type)") .Attr("shapes: list(shape)") .Attr("enable_parallelism: bool = true") .Attr("cost_per_tensor: int = 10000000") .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector shapes; TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); for (int i = 0; i < shapes.size(); i++) { shape_inference::ShapeHandle shape; TF_RETURN_IF_ERROR( c->MakeShapeFromPartialTensorShape(shapes[i], &shape)); 
c->set_output(i, shape); } c->set_output(shapes.size(), c->Vector(shapes.size())); return Status::OK(); }); REGISTER_KERNEL_BUILDER(Name("MonolithStaticReshapeN").Device(DEVICE_CPU), StaticReshapeNOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/touched_key_set_insert_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "monolith/native_training/runtime/ops/touched_key_set_tf_bridge.h" namespace tensorflow { namespace monolith_tf { class TouchedKeySetInsertOp : public OpKernel { public: explicit TouchedKeySetInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { TouchedKeySetTfBridge* touched_key_set = nullptr; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &touched_key_set)); core::ScopedUnref unref(touched_key_set); const Tensor& tensor = ctx->input(1); const int64 num_elements = tensor.NumElements(); auto ids = tensor.vec(); int64 total_dropped_num = 0; for (int64 i = 0; i < num_elements; ++i) { total_dropped_num += touched_key_set->Insert(ids(i)); } Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {1}, &output)); auto output_vec = output->vec(); output_vec(0) = total_dropped_num; } }; REGISTER_OP("MonolithTouchedKeySetInsert") .Input("handle: resource") .Input("ids: int64") .Output("size: int64") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithTouchedKeySetInsert").Device(DEVICE_CPU), TouchedKeySetInsertOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/touched_key_set_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "monolith/native_training/runtime/ops/touched_key_set_tf_bridge.h" namespace tensorflow { namespace monolith_tf { class TouchedKeySetOp : public ResourceOpKernel { public: explicit TouchedKeySetOp(OpKernelConstruction* ctx) : ResourceOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("capacity", &capacity_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("concurrency_level", &concurrency_level_)); } ~TouchedKeySetOp() override = default; private: Status CreateResource(TouchedKeySetTfBridge** touched_key_set_bridge) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { auto touched_key_set = std::make_unique>( capacity_, concurrency_level_); *touched_key_set_bridge = new TouchedKeySetTfBridge(std::move(touched_key_set)); return Status::OK(); }; int64 capacity_; int concurrency_level_; }; REGISTER_OP("MonolithTouchedKeySet") .Output("handle: resource") .Attr("capacity: int = 2097152") .Attr("concurrency_level: int = 1024") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithTouchedKeySet").Device(DEVICE_CPU), TouchedKeySetOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/touched_key_set_steal_op.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "monolith/native_training/runtime/ops/touched_key_set_tf_bridge.h" namespace tensorflow { namespace monolith_tf { class TouchedKeySetStealOp : public OpKernel { public: explicit TouchedKeySetStealOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { TouchedKeySetTfBridge* touched_key_set = nullptr; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &touched_key_set)); core::ScopedUnref unref(touched_key_set); auto ids = touched_key_set->Steal(); Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {static_cast(ids.size())}, &output)); auto output_vec = output->vec(); for (size_t i = 0; i < ids.size(); ++i) { output_vec(i) = ids[i]; } } }; REGISTER_OP("MonolithTouchedKeySetSteal") .Input("handle: resource") .Output("ids: int64") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithTouchedKeySetSteal").Device(DEVICE_CPU), TouchedKeySetStealOp); } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/ops/touched_key_set_tf_bridge.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_TOUCHED_KEY_SET_TF_BRIDGE_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_TOUCHED_KEY_SET_TF_BRIDGE_H_ #include #include "absl/strings/str_format.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "monolith/native_training/runtime/hopscotch/hopscotch_hash_set.h" namespace tensorflow { namespace monolith_tf { class TouchedKeySetTfBridge : public ResourceBase { public: explicit TouchedKeySetTfBridge(std::unique_ptr> touched_key_set) : touched_key_set_(std::move(touched_key_set)) {} size_t Insert(int64_t key) { return touched_key_set_->insert(key); } std::vector Steal() { return touched_key_set_->GetAndClear(); } size_t Size() const { return touched_key_set_->size(); } std::string DebugString() const override { return absl::StrFormat("TouchedKeySet with capacity: %d", touched_key_set_->capacity()); } private: std::unique_ptr> touched_key_set_; }; } // namespace monolith_tf } // namespace tensorflow #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_TOUCHED_KEY_SET_TF_BRIDGE_H_ ================================================ FILE: monolith/native_training/runtime/ops/unique_mapping_ops.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/threadpool.h" namespace tensorflow { namespace monolith_tf { namespace { // We want to create a tensor, while its lifecycle going with the step. // So the idea here is to use step_container. // // The following code is borrowed from variable_ops.cc in Tensorflow. string SharedTensorName(const string& tensor_name, const FrameAndIter& control_frame) { if (control_frame.frame_id != kIllegalFrameId && control_frame.iter_id != kIllegalIterId) { return strings::StrCat(tensor_name, "/frame:", control_frame.frame_id, "/iter:", control_frame.iter_id); } return tensor_name; } struct SharedTensor : public ResourceBase { // Maybe we can add a mutex here if needed. 
std::string name; Tensor val; string DebugString() const override { return name; } int64 MemoryUsed() const override { return val.AllocatedBytes(); } }; class UniqueKeyWitValueAndOffsetOp : public OpKernel { public: explicit UniqueKeyWitValueAndOffsetOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dims", &dims_)); OP_REQUIRES_OK(c, c->GetAttr("generate_buffer", &generate_buffer_)); dims_size_ = dims_.size(); } void Compute(OpKernelContext* c) override { auto key = c->input(0).vec(); auto key_split = c->input(1).vec(); OP_REQUIRES(c, key_split.size() == dims_size_ + 1, errors::InvalidArgument("RaggedKey should have ", dims_size_, " but got ", key_split.size() - 1)); Tensor* t; std::vector unique_key; unique_key.reserve(key.size()); OP_REQUIRES_OK(c, c->allocate_output(1, {dims_size_ + 1}, &t)); auto unique_key_split_vec = t->vec(); unique_key_split_vec(0) = 0; OP_REQUIRES_OK(c, c->allocate_output(2, {key.size()}, &t)); auto value_offset_vec = t->vec(); int64 value_offset_vec_offset = 0; std::vector value_offset_split; value_offset_split.reserve(key.size()); value_offset_split.push_back(0); int64 value_offset = 0; absl::flat_hash_map> m; int j = 0; m.reserve(2 * (key_split(1) - key_split(0))); for (int i = 0;; ++i) { while (i == key_split(j + 1)) { unique_key_split_vec(j + 1) = unique_key.size(); for (int k = unique_key_split_vec(j); k < unique_key_split_vec(j + 1); ++k) { auto it = m.find(unique_key[k]); for (int64 value_offset_for_key : it->second) { value_offset_vec(value_offset_vec_offset++) = value_offset_for_key; } value_offset_split.push_back(value_offset_vec_offset); } ++j; if (j < dims_size_) { m.clear(); m.reserve(2 * (key_split(j + 1) - key_split(j))); } else { break; } } if (i == key.size()) break; auto it = m.find(key(i)); if (it == m.end()) { m.insert({key(i), {value_offset}}); unique_key.push_back(key(i)); } else { it->second.push_back(value_offset); } value_offset += dims_[j]; } OP_REQUIRES_OK(c, c->allocate_output(0, 
{unique_key.size()}, &t)); auto unique_key_vec = t->vec(); std::memcpy(unique_key_vec.data(), unique_key.data(), sizeof(int64) * unique_key.size()); OP_REQUIRES_OK(c, c->allocate_output(3, {value_offset_split.size()}, &t)); auto value_offset_split_vec = t->vec(); std::memcpy(value_offset_split_vec.data(), value_offset_split.data(), sizeof(int64) * value_offset_split.size()); OP_REQUIRES_OK(c, CreateSharedTensor(c, {value_offset})); } Status CreateSharedTensor(OpKernelContext* c, TensorShape shape) { if (!generate_buffer_) { Tensor* handle; TF_RETURN_IF_ERROR(c->allocate_output(4, TensorShape({}), &handle)); handle->scalar()() = ResourceHandle(); return Status::OK(); } const std::string unique_name = SharedTensorName(def().name(), c->frame_iter()); Tensor t; TF_RETURN_IF_ERROR(c->allocate_temp(DataType::DT_FLOAT, shape, &t)); SharedTensor* st = new SharedTensor(); st->name = unique_name; st->val = std::move(t); auto* container = c->step_container(); TF_RETURN_IF_ERROR( container->Create(c->resource_manager(), unique_name, st)); Tensor* handle; TF_RETURN_IF_ERROR(c->allocate_output(4, TensorShape({}), &handle)); handle->scalar()() = container->MakeResourceHandle(unique_name, *c->device()); return Status::OK(); } private: std::vector dims_; int dims_size_; bool generate_buffer_; }; REGISTER_OP("MonolithUniqueKeyWithValueAndOffset") .Input("key: int64") .Input("key_split: int64") .Output("unique_key: int64") .Output("unique_key_split: int64") .Output("value_offset: int64") .Output("value_offset_split: int64") .Output("value_buffer: resource") .Attr("dims: list(int)") .Attr("generate_buffer: bool") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(c->UnknownDim())); c->set_output(1, c->input(1)); c->set_output(2, c->input(0)); c->set_output(3, c->Vector(c->UnknownDim())); c->set_output(4, c->Scalar()); return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("MonolithUniqueKeyWithValueAndOffset").Device(DEVICE_CPU), 
UniqueKeyWitValueAndOffsetOp); class FinallizeSharedTensorOp : public OpKernel { public: explicit FinallizeSharedTensorOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { core::RefCountPtr st = nullptr; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &st)); c->set_output(0, st->val); OP_REQUIRES_OK(c, DeleteResource(c, HandleFromInput(c, 0))); } }; REGISTER_OP("MonolithFinalizeSharedTensor") .Input("handle: num_tensors * resource") .Output("t: dtype") .Attr("shape: shape") .Attr("dtype: type") .Attr("num_tensors: int") .SetIsStateful() .SetShapeFn(shape_inference::ExplicitShape); REGISTER_KERNEL_BUILDER(Name("MonolithFinalizeSharedTensor").Device(DEVICE_CPU), FinallizeSharedTensorOp); class FillWithOffsetMapOp : public OpKernel { public: explicit FillWithOffsetMapOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dims", &dims_)); } void Compute(OpKernelContext* c) override { auto pos = c->input(0).vec(); auto pos_split = c->input(1).vec(); auto value = c->input(2).vec(); auto value_offset_map = c->input(3).vec(); auto value_offset_map_split = c->input(4).vec(); core::RefCountPtr st = nullptr; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 5), &st)); auto value_buffer = st->val.vec(); int64 value_offset = 0; OP_REQUIRES( c, pos_split.size() == dims_.size() + 1, errors::InvalidArgument("Pos's first dim doesn't match dim size. ", pos_split.size() - 1, " v.s. ", dims_.size())); int j = 0; for (int i = 0; i < pos.size(); ++i) { while (i == pos_split(j + 1)) { ++j; } OP_REQUIRES( c, pos(i) < value_offset_map.size(), errors::InvalidArgument("pos is bigger than offset size. ", pos(i), " v.s. 
", value_offset_map.size())); const int64 value_offset_end = value_offset + dims_[j]; OP_REQUIRES(c, value_offset_end <= value.size(), errors::InvalidArgument(FormatValueError(pos_split, value))); for (int64 offset_pos = value_offset_map_split(pos(i)); offset_pos < value_offset_map_split(pos(i) + 1); ++offset_pos) { std::memcpy(value_buffer.data() + value_offset_map(offset_pos), value.data() + value_offset, dims_[j] * sizeof(float)); } value_offset = value_offset_end; } c->set_output(0, c->input(5)); } private: std::string FormatValueError(TTypes::Vec pos_split, TTypes::Vec value) { std::string s; std::vector pos_split_vec; for (int i = 0; i < pos_split.size(); ++i) { pos_split_vec.push_back(pos_split(i)); } int64 expected_size = 0; for (int i = 0; i < dims_.size(); ++i) { expected_size += (pos_split_vec[i + 1] - pos_split_vec[i]) * dims_[i]; } absl::StrAppend(&s, absl::StrFormat("Value size doesn't match expected size. " "expected: %d, actual: %d. \n", expected_size, value.size())); absl::StrAppend(&s, absl::StrFormat("Pos split: %s", absl::StrJoin(pos_split_vec, ","))); return s; } std::vector dims_; }; REGISTER_OP("MonolithFillWithOffsetMap") .Input("pos: int64") .Input("pos_split: int64") .Input("value : float") .Input("value_offset_map: int64") .Input("value_offset_map_split: int64") .Input("value_buffer: resource") .Output("out_value_buffer: resource") .Attr("dims: list(int)") .SetShapeFn(shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("MonolithFillWithOffsetMap").Device(DEVICE_CPU), FillWithOffsetMapOp); class FillWithOffsetMapGradientOp : public OpKernel { public: explicit FillWithOffsetMapGradientOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dims", &dims_)); } void Compute(OpKernelContext* c) override { auto pos = c->input(0).vec(); auto pos_split = c->input(1).vec(); auto grad = c->input(2).vec(); auto grad_offset_map = c->input(3).vec(); auto grad_offset_map_split = c->input(4).vec(); int64 bgrad_size = 0; for (int 
j = 0; j < dims_.size(); ++j) { bgrad_size += dims_[j] * (pos_split(j + 1) - pos_split(j)); } Tensor* t; OP_REQUIRES_OK(c, c->allocate_output(0, {bgrad_size}, &t)); auto bgrad = t->vec(); bgrad.setZero(); int64 bgrad_offset = 0; int j = 0; for (int i = 0; i < pos.size(); ++i) { while (i == pos_split(j + 1)) { ++j; } OP_REQUIRES( c, pos(i) < grad_offset_map.size(), errors::InvalidArgument("pos is bigger than offset size. ", pos(i), " v.s. ", grad_offset_map.size())); for (int64 offset_pos = grad_offset_map_split(pos(i)); offset_pos < grad_offset_map_split(pos(i) + 1); ++offset_pos) { const int64 grad_offset = grad_offset_map(offset_pos); for (int k = 0; k < dims_[j]; ++k) { bgrad(bgrad_offset + k) += grad(grad_offset + k); } } bgrad_offset += dims_[j]; } } private: std::vector dims_; }; REGISTER_OP("MonolithFillWithOffsetMapGradient") .Input("pos: int64") .Input("pos_split: int64") .Input("grad: float") .Input("grad_offset_map: int64") .Input("grad_offset_map_split: int64") .Output("backprop_grad: float") .Attr("dims: list(int)") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); }); REGISTER_KERNEL_BUILDER( Name("MonolithFillWithOffsetMapGradient").Device(DEVICE_CPU), FillWithOffsetMapGradientOp); class FusedValueRowidsOp : public OpKernel { public: explicit FusedValueRowidsOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { auto splits = c->input(0).vec(); Tensor* t; const int len = splits.size() - 1; OP_REQUIRES_OK(c, c->allocate_output(0, {splits(len)}, &t)); auto rowids = t->vec(); for (int64 i = 0; i < len; ++i) { for (int64 j = splits(i); j < splits(i + 1); ++j) { rowids(j) = i; } } } }; REGISTER_OP("MonolithFusedValueRowids") .Input("splits: int64") .Output("rowids: int64") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); }); 
REGISTER_KERNEL_BUILDER(Name("MonolithFusedValueRowids").Device(DEVICE_CPU), FusedValueRowidsOp); } // namespace } // namespace monolith_tf } // namespace tensorflow ================================================ FILE: monolith/native_training/runtime/parameter_sync/BUILD ================================================ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_test") load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") package(default_visibility = ["//monolith/native_training/runtime:__subpackages__"]) cc_library( name = "sync_client_interface", hdrs = ["sync_client_interface.h"], deps = [ "@com_github_grpc_grpc//:grpc++", ], ) cc_library( name = "dummy_sync_client", hdrs = ["dummy_sync_client.h"], deps = [ ":sync_client_interface", ], ) proto_library( name = "parameter_sync_proto", srcs = ["parameter_sync.proto"], ) cc_proto_library( name = "parameter_sync_cc_proto", deps = [":parameter_sync_proto"], ) py_proto_library( name = "parameter_sync_py_proto", srcs = ["parameter_sync.proto"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [], ) cc_grpc_library( name = "parameter_sync_cc_grpc", srcs = [":parameter_sync_proto"], grpc_only = True, deps = [":parameter_sync_cc_proto"], ) cc_library( name = "parameter_sync_client", srcs = ["parameter_sync_client.cc"], hdrs = ["parameter_sync_client.h"], deps = [ ":parameter_sync_cc_grpc", ":sync_client_interface", "@com_google_absl//absl/strings:str_format", "@com_google_glog//:glog", "@org_tensorflow//tensorflow/core/platform:logging", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto", ], ) tf_cc_test( name = "parameter_sync_client_test", srcs = ["parameter_sync_client_test.cc"], deps = [ ":parameter_sync_client", 
"@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/core:test", ], ) cc_library( name = "dummy_sync_server", srcs = ["dummy_sync_server.cc"], hdrs = ["dummy_sync_server.h"], deps = [ ":parameter_sync_cc_grpc", "@com_github_grpc_grpc//:grpc++", "@com_github_grpc_grpc//:grpc++_reflection", "@com_google_absl//absl/strings:str_format", "@com_google_glog//:glog", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto", ], ) cc_library( name = "request_splitter", srcs = ["request_splitter.cc"], hdrs = ["request_splitter.h"], deps = [ ":parameter_sync_cc_grpc", "@com_github_grpc_grpc//:grpc++", "@com_github_grpc_grpc//:grpc++_reflection", "@com_google_absl//absl/strings:str_format", "@com_google_glog//:glog", ], ) cc_test( name = "request_splitter_test", srcs = ["request_splitter_test.cc"], deps = [ ":request_splitter", "@com_google_googletest//:gtest_main", ], ) cc_library( name = "sync_client_manager", srcs = ["sync_client_manager.cc"], hdrs = ["sync_client_manager.h"], deps = [ ":parameter_sync_cc_grpc", ":sync_client_interface", "//monolith/native_training/runtime/common:metrics", "//monolith/native_training/runtime/parameter_sync:parameter_sync_client", "//monolith/native_training/runtime/parameter_sync:request_splitter", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_glog//:glog", "@org_tensorflow//tensorflow/core/platform:logging", "@org_tensorflow_serving//tensorflow_serving/apis:prediction_service_proto", ], ) ================================================ FILE: monolith/native_training/runtime/parameter_sync/dummy_sync_client.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_DUMMY_SYNC_CLIENT_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_DUMMY_SYNC_CLIENT_H_ #include "monolith/native_training/runtime/parameter_sync/sync_client_interface.h" namespace monolith { namespace parameter_sync { class DummySyncClient : public SyncClientInterface { public: explicit DummySyncClient(const std::string& target) {} grpc::Status Push(const PushRequest&, PushResponse*) const override { return grpc::Status::OK; } }; } // namespace parameter_sync } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_DUMMY_SYNC_CLIENT_H_ ================================================ FILE: monolith/native_training/runtime/parameter_sync/dummy_sync_server.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/parameter_sync/dummy_sync_server.h"

namespace monolith {
namespace parameter_sync {

using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;

// Builds and starts an insecure gRPC server on `target` that serves the
// test-only prediction service. The port actually bound is written to
// `selected_port_` by the builder. If the server cannot be started,
// `server_` stays null and an error is logged.
DummySyncServer::DummySyncServer(std::string target)
    : target_(std::move(target)), selected_port_(0) {
  grpc::EnableDefaultHealthCheckService(true);
  grpc::reflection::InitProtoReflectionServerBuilderPlugin();
  ServerBuilder builder;
  // Listen on the given address without any authentication mechanism.
  builder.AddListeningPort(target_, grpc::InsecureServerCredentials(),
                           &selected_port_);
  // Register "service" as the instance through which we'll communicate with
  // clients. In this case it corresponds to a *synchronous* service.
  builder.RegisterService(&service_);
  // Finally assemble the server. BuildAndStart() returns null on failure
  // (for example when the address cannot be bound), so do not claim to be
  // listening in that case.
  server_ = builder.BuildAndStart();
  if (server_ == nullptr) {
    LOG(ERROR) << "Failed to start server on " << target_;
    return;
  }
  // No std::endl: the logging macro already terminates the record.
  LOG(INFO) << "Server listening on " << target_ << ", selecting port "
            << selected_port_;
}

// Shuts the server down. A no-op when the server never started.
void DummySyncServer::Shutdown() const {
  if (server_ != nullptr) {
    server_->Shutdown();
  }
}

// Returns the address string passed to the constructor.
const std::string& DummySyncServer::GetTarget() const { return target_; }

// Returns the port reported by the builder (0 until one has been selected).
int DummySyncServer::GetSelectedPort() const { return selected_port_; }

}  // namespace parameter_sync
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/parameter_sync/dummy_sync_server.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_DUMMY_SYNC_SERVER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_DUMMY_SYNC_SERVER_H_ #include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "glog/logging.h" #include "grpcpp/ext/proto_server_reflection_plugin.h" #include "grpcpp/grpcpp.h" #include "grpcpp/health_check_service_interface.h" #include "tensorflow_serving/apis/prediction_service.grpc.pb.h" #include "monolith/native_training/runtime/parameter_sync/parameter_sync.grpc.pb.h" namespace monolith { namespace parameter_sync { // Test only class PredictionServiceImpl final : public tensorflow::serving::PredictionService::Service { grpc::Status Predict( grpc::ServerContext* context, const tensorflow::serving::PredictRequest* request, tensorflow::serving::PredictResponse* response) override { // TODO(zhangbiao.david): remove LOG(INFO) << "PredictionServiceImpl" << std::endl; for (const auto& kv : request->inputs()) { std::vector output; if (absl::EndsWith(kv.first, "_id")) { std::transform(kv.second.int64_val().begin(), kv.second.int64_val().end(), std::back_inserter(output), [](int64_t id) { return std::to_string(id); }); } else if (absl::EndsWith(kv.first, "_value")) { std::transform(kv.second.float_val().begin(), kv.second.float_val().end(), std::back_inserter(output), [](float value) { return std::to_string(value); }); } else { LOG(FATAL) << "Inputs' key should end with '_id' or '_value'"; } LOG(INFO) << absl::StrFormat("%s: %s", kv.first, absl::StrJoin(output, " ")); } response->mutable_model_spec()->CopyFrom(request->model_spec()); return grpc::Status::OK; } }; class DummySyncServer { public: explicit DummySyncServer(std::string target); void Shutdown() const; const std::string& GetTarget() const; int GetSelectedPort() const; private: std::string target_; int selected_port_; PredictionServiceImpl service_; std::unique_ptr server_; }; } // namespace parameter_sync } 
// namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_DUMMY_SYNC_SERVER_H_ ================================================ FILE: monolith/native_training/runtime/parameter_sync/parameter_sync.proto ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; package monolith.parameter_sync; // The request message containing delta fids and embeddings. message PushRequest { message DeltaEmbeddingHashTable { optional string unique_id = 1; optional int32 dim_size = 2; repeated int64 fids = 3; repeated float embeddings = 4; } optional string model_name = 1; optional string signature_name = 2; repeated DeltaEmbeddingHashTable delta_hash_tables = 3; // The embedding changes from multi hash tables // The size of this field should equal to number of hash tables included in // multi hash table. repeated DeltaEmbeddingHashTable delta_multi_hash_tables = 5; optional int64 timeout_in_ms = 4 [default = 1000]; } // The response message message PushResponse { // gRPC's server address optional string target = 3; // gRPC's status code, 0 means OK. 
optional int32 status_code = 1; // gRPC's error message optional string error_message = 2; // Number of fids successfully assigned optional int32 update_num = 4; } message PushResult { repeated PushResponse responses = 1; } message ClientConfig { optional string model_name = 1; optional string signature_name = 2; repeated string targets = 3; optional int64 timeout_in_ms = 4 [default = 1000]; message TargetExtraInfo { optional string idc = 1; optional string cluster = 2; optional int64 replica_id = 3 [default = -1]; } repeated TargetExtraInfo targets_extra_info = 5; } ================================================ FILE: monolith/native_training/runtime/parameter_sync/parameter_sync_client.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/parameter_sync/parameter_sync_client.h"

#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "grpc/impl/codegen/gpr_types.h"
#include "tensorflow/core/platform/default/logging.h"

namespace monolith {
namespace parameter_sync {

using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;
using tensorflow::serving::PredictionService;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;

// Translates a PushRequest into a TF-Serving PredictRequest.
//
// When the request carries `delta_multi_hash_tables`, all tables are
// flattened into three 1-D tensors: "id" (all fids), "id_split" (prefix sums
// of per-table fid counts, starting at 0) and "flat_value" (all embeddings).
// Otherwise each entry of `delta_hash_tables` yields two tensors keyed
// "<unique_id>_id" (shape [num_fids]) and "<unique_id>_value"
// (shape [num_fids, dim_size]).
//
// `ConvertResult::total` is the total number of fids converted.
ParameterSyncClient::ConvertResult ParameterSyncClient::Convert(
    const PushRequest& request) {
  ConvertResult result;
  PredictRequest& predict_request = result.req;
  predict_request.mutable_model_spec()->set_name(request.model_name());
  predict_request.mutable_model_spec()->set_signature_name(
      request.signature_name());
  auto& inputs = *predict_request.mutable_inputs();
  int& total = result.total;

  if (request.delta_multi_hash_tables_size() > 0) {
    tensorflow::TensorProto id_tensor, id_split_tensor, emb_tensor;
    id_tensor.set_dtype(tensorflow::DataType::DT_INT64);
    id_split_tensor.set_dtype(tensorflow::DataType::DT_INT64);
    emb_tensor.set_dtype(tensorflow::DataType::DT_FLOAT);

    int64_t split = 0;
    id_split_tensor.add_int64_val(split);
    for (const PushRequest::DeltaEmbeddingHashTable& table :
         request.delta_multi_hash_tables()) {
      total += table.fids_size();
      split += table.fids_size();
      id_split_tensor.add_int64_val(split);
      // Bulk-append instead of adding element by element.
      id_tensor.mutable_int64_val()->MergeFrom(table.fids());
      emb_tensor.mutable_float_val()->MergeFrom(table.embeddings());
    }

    id_tensor.mutable_tensor_shape()->add_dim()->set_size(
        id_tensor.int64_val_size());
    id_split_tensor.mutable_tensor_shape()->add_dim()->set_size(
        id_split_tensor.int64_val_size());
    emb_tensor.mutable_tensor_shape()->add_dim()->set_size(
        emb_tensor.float_val_size());

    // names here should match what we write in `saved_model_exporters`
    inputs["id"] = std::move(id_tensor);
    inputs["id_split"] = std::move(id_split_tensor);
    inputs["flat_value"] = std::move(emb_tensor);
  } else {
    for (const auto& delta : request.delta_hash_tables()) {
      const int num_update = delta.fids().size();
      total += num_update;

      tensorflow::TensorProto proto_fid, proto_emb;
      proto_fid.set_dtype(tensorflow::DataType::DT_INT64);
      proto_emb.set_dtype(tensorflow::DataType::DT_FLOAT);
      proto_fid.mutable_int64_val()->MergeFrom(delta.fids());
      proto_emb.mutable_float_val()->MergeFrom(delta.embeddings());

      const int dimension = delta.dim_size();
      proto_fid.mutable_tensor_shape()->add_dim()->set_size(num_update);
      proto_emb.mutable_tensor_shape()->add_dim()->set_size(num_update);
      proto_emb.mutable_tensor_shape()->add_dim()->set_size(dimension);

      // The tensors are not used again, so move them into the map rather
      // than deep-copying the whole payload.
      inputs[delta.unique_id() + "_id"] = std::move(proto_fid);
      inputs[delta.unique_id() + "_value"] = std::move(proto_emb);
    }
  }
  return result;
}

// Converts `request`, issues a Predict RPC with a deadline of
// `request.timeout_in_ms()` and records the outcome in `response`:
// status code, error message and `update_num` (number of fids on success,
// 0 on failure). Returns the RPC status.
grpc::Status ParameterSyncClient::Push(const PushRequest& request,
                                       PushResponse* response) const {
  ConvertResult convert_result = Convert(request);
  const PredictRequest& predict_request = convert_result.req;
  PredictResponse predict_response;

  // Context for the client. It could be used to convey extra information to
  // the server and/or tweak certain RPC behaviors.
  ClientContext context;
  // The deadline is expressed as a time span relative to now.
  gpr_timespec ts;
  ts.tv_sec = request.timeout_in_ms() / 1000;
  ts.tv_nsec = (request.timeout_in_ms() % 1000) * 1000 * 1000;
  ts.clock_type = GPR_TIMESPAN;
  context.set_deadline(ts);

  // TODO(zhangbiao.david): predict_request.DebugString() causes a segment
  // fault, but I have no idea about it.
  // LOG(INFO) << "PredictRequest\n" << predict_request.DebugString() <<
  // std::endl;

  // The actual RPC.
  const Status status =
      stub_->Predict(&context, predict_request, &predict_response);
  response->set_status_code(status.error_code());
  response->set_error_message(status.error_message());

  // Act upon its status.
  if (status.ok()) {
    response->set_update_num(convert_result.total);
  } else {
    response->set_update_num(0);
    LOG_EVERY_N_SEC(ERROR, 10)
        << status.error_code() << ": " << status.error_message();
  }
  return status;
}

}  // namespace parameter_sync
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/parameter_sync/parameter_sync_client.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_PARAMETER_SYNC_CLIENT_H_
#define MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_PARAMETER_SYNC_CLIENT_H_

#include "glog/logging.h"
#include "grpcpp/grpcpp.h"
#include "monolith/native_training/runtime/parameter_sync/parameter_sync.grpc.pb.h"
#include "monolith/native_training/runtime/parameter_sync/sync_client_interface.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"

namespace monolith {
namespace parameter_sync {

// SyncClientInterface implementation that forwards each PushRequest to a
// TF-Serving PredictionService as a Predict call.
class ParameterSyncClient final : public SyncClientInterface {
 public:
  // Creates a client with a stub connected to `target` over an insecure
  // channel.
  explicit ParameterSyncClient(std::string target)
      : ParameterSyncClient(CreateStub(target)) {
    target_ = std::move(target);
  }

  // Creates a client around an existing stub; `target_` stays empty.
  explicit ParameterSyncClient(
      std::unique_ptr stub)
      : stub_(std::move(stub)) {}

  // Assembles the client's payload, sends it and presents the response back
  // from the server.
  grpc::Status Push(const PushRequest& request,
                    PushResponse* response) const override;

  // Ideally we should mock stub to simulate the behavior of this class.
  // However, there are some problems to generate mock class.
  // We just verify request conversion here.
  struct ConvertResult {
    tensorflow::serving::PredictRequest req;
    // Total number of fids contained in `req`.
    int total = 0;
  };
  static ConvertResult Convert(const PushRequest& req);

 private:
  // Builds a PredictionService stub on an insecure channel to `target` with
  // send and receive message limits raised to 32 MiB.
  static std::unique_ptr
  CreateStub(const std::string& target) {
    // 32M
    const int MAX_MESSAGE_LENGTH = 32 * 1024 * 1024;
    grpc::ChannelArguments arguments;
    arguments.SetMaxSendMessageSize(MAX_MESSAGE_LENGTH);
    arguments.SetMaxReceiveMessageSize(MAX_MESSAGE_LENGTH);
    auto channel = grpc::CreateCustomChannel(
        target, grpc::InsecureChannelCredentials(), arguments);
    return tensorflow::serving::PredictionService::NewStub(channel);
  }

  std::string target_;
  std::unique_ptr stub_;
};

}  // namespace parameter_sync
}  // namespace monolith

#endif  // MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_PARAMETER_SYNC_CLIENT_H_

================================================
FILE: monolith/native_training/runtime/parameter_sync/parameter_sync_client_test.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "monolith/native_training/runtime/parameter_sync/parameter_sync_client.h"

#include "gmock/gmock.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/util/message_differencer.h"
#include "grpcpp/grpcpp.h"
#include "gtest/gtest.h"

namespace monolith {
namespace parameter_sync {
namespace {

using ::tensorflow::serving::PredictRequest;

// Multi-hash-table requests are flattened into "id", "id_split" and
// "flat_value" tensors.
TEST(MultiHashTableTest, Basic) {
  PushRequest req;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( model_name: "test_model" signature_name: "table/raw_assign" delta_multi_hash_tables: [ { fids: [1, 2] embeddings: [1.0, 2.0, 3.0, 4.0] }, { fids: [3] embeddings: [5.0] } ] )", &req));
  auto result = ParameterSyncClient::Convert(req);
  PredictRequest expected_predict_req;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(R"( model_spec { name: "test_model" signature_name: "table/raw_assign" } inputs { key: "flat_value" value { dtype: DT_FLOAT tensor_shape { dim { size: 5 } } float_val: [1.0, 2.0, 3.0, 4.0, 5.0] } } inputs { key: "id" value { dtype: DT_INT64 tensor_shape { dim { size: 3 } } int64_val: [1, 2, 3] } } inputs { key: "id_split" value { dtype: DT_INT64 tensor_shape { dim { size: 3 } } int64_val: [0, 2, 3] } } )", &expected_predict_req));
  EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(
      result.req, expected_predict_req));
}

// Plain hash-table requests yield one "<unique_id>_id" and one
// "<unique_id>_value" tensor per table.
TEST(HashTableTest, Basic) {
  PushRequest req;
  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( model_name: "test_model" signature_name: "hashtable_assign" delta_hash_tables: [ { unique_id: "table1" dim_size: 2 fids: [1, 2] embeddings: [1.0, 2.0, 3.0, 4.0] }, { unique_id: "table2" dim_size: 1 fids: [3] embeddings: [5.0] } ] )", &req));
  auto result = ParameterSyncClient::Convert(req);
  PredictRequest expected_predict_req;
  ASSERT_TRUE(
      google::protobuf::TextFormat::ParseFromString(R"( model_spec { name: "test_model" signature_name: "hashtable_assign" } inputs { key: "table1_id" value { dtype: DT_INT64 tensor_shape { dim { size: 2 } } int64_val: [1, 2] } } inputs { key: "table1_value" value { dtype: DT_FLOAT tensor_shape { dim { size: 2 } dim { size: 2 } } float_val: [1.0, 2.0, 3.0, 4.0] } } inputs { key: "table2_id" value { dtype: DT_INT64 tensor_shape { dim { size: 1 } } int64_val: [3] } } inputs { key: "table2_value" value { dtype: DT_FLOAT tensor_shape { dim { size: 1 } dim { size: 1 } } float_val: [5.0] } } )", &expected_predict_req));
  EXPECT_TRUE(google::protobuf::util::MessageDifferencer::Equals(
      result.req, expected_predict_req));
}

}  // namespace
}  // namespace parameter_sync
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/parameter_sync/request_splitter.cc
================================================
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "monolith/native_training/runtime/parameter_sync/request_splitter.h"

#include "glog/logging.h"

namespace monolith {
namespace parameter_sync {
namespace {

// Copies slice `i` (0-based) of `split_num` slices of `table` into
// `target_table`.
//
// Fids are divided evenly by count; the last slice additionally receives the
// remainder. For every fid, `dim_size` consecutive floats are taken from
// `table.embeddings()`, so the embeddings are assumed to hold at least
// fids_size() * dim_size() values.
// NOTE(review): a table whose dim_size is unset (0) ends up with no
// embeddings in any slice - confirm callers always set it before splitting.
void SplitTable(const PushRequest::DeltaEmbeddingHashTable& table,
                int split_num, int i,
                PushRequest::DeltaEmbeddingHashTable* target_table) {
  const size_t delta_size = table.fids_size();
  const int dim_size = table.dim_size();
  const size_t q = delta_size / split_num;
  const size_t part_size =
      (i + 1 == split_num) ? q + delta_size % split_num : q;
  // Offsets are kept in size_t so that neither the fid index nor the
  // embedding offset is narrowed to int.
  const size_t first = static_cast<size_t>(i) * q;

  target_table->set_unique_id(table.unique_id());
  target_table->set_dim_size(dim_size);

  // TODO(leqi.zou): This seems not very mem efficient.
  // The slice is contiguous in both arrays, so append each range in one call.
  const auto* fid_begin = table.fids().data() + first;
  target_table->mutable_fids()->Add(fid_begin, fid_begin + part_size);

  const float* embedding_begin = table.embeddings().data() + first * dim_size;
  target_table->mutable_embeddings()->Add(
      embedding_begin, embedding_begin + part_size * dim_size);
}

}  // namespace

// Returns `push_request` unchanged when its serialized size fits within
// `max_message_length`. Otherwise returns
// ceil(byte_size / max_message_length) requests, each carrying the same model
// name, signature name and timeout plus one slice of every table.
//
// Slicing is by fid count per table, so an individual piece is not checked
// against `max_message_length` again.
std::vector<PushRequest> RequestSplitter::Split(
    const PushRequest& push_request, int64_t max_message_length) const {
  DCHECK_GT(max_message_length, 0);
  size_t byte_size = push_request.ByteSizeLong();
  if (byte_size <= max_message_length) {
    return {push_request};
  }

  size_t split_num = (byte_size + max_message_length - 1) / max_message_length;
  const std::string& model_name = push_request.model_name();
  const std::string& signature_name = push_request.signature_name();
  int64_t timeout_in_ms = push_request.timeout_in_ms();

  std::vector<PushRequest> requests;
  requests.reserve(split_num);
  for (size_t i = 0; i < split_num; ++i) {
    requests.emplace_back();
    PushRequest* request = &requests.back();
    request->set_model_name(model_name);
    request->set_signature_name(signature_name);
    request->mutable_delta_hash_tables()->Reserve(
        push_request.delta_hash_tables_size());
    request->mutable_delta_multi_hash_tables()->Reserve(
        push_request.delta_multi_hash_tables_size());
    request->set_timeout_in_ms(timeout_in_ms);

    for (const auto& table : push_request.delta_hash_tables()) {
      auto* delta_hash_table = request->mutable_delta_hash_tables()->Add();
      SplitTable(table, split_num, i, delta_hash_table);
    }

    for (const auto& table : push_request.delta_multi_hash_tables()) {
      auto* delta_multi_hash_table =
          request->mutable_delta_multi_hash_tables()->Add();
      SplitTable(table, split_num, i, delta_multi_hash_table);
    }
  }

  return requests;
}

}  // namespace parameter_sync
}  // namespace
monolith ================================================ FILE: monolith/native_training/runtime/parameter_sync/request_splitter.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_REQUEST_SPLITTER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_REQUEST_SPLITTER_H_ #include #include "monolith/native_training/runtime/parameter_sync/parameter_sync.grpc.pb.h" namespace monolith { namespace parameter_sync { class RequestSplitter { public: std::vector Split(const PushRequest& push_request, int64_t max_message_length) const; }; } // namespace parameter_sync } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_REQUEST_SPLITTER_H_ ================================================ FILE: monolith/native_training/runtime/parameter_sync/request_splitter_test.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "monolith/native_training/runtime/parameter_sync/request_splitter.h" #include #include "glog/logging.h" #include "google/protobuf/util/message_differencer.h" #include "gtest/gtest.h" namespace monolith { namespace parameter_sync { namespace { using ::google::protobuf::util::MessageDifferencer; PushRequest_DeltaEmbeddingHashTable SetUpOneDeltaHashTable( const std::string& unique_id, size_t fid_num, size_t dim, int fid_value_offset = 0) { PushRequest_DeltaEmbeddingHashTable table; table.set_unique_id(unique_id); table.set_dim_size(dim); std::vector fids(fid_num); std::iota(fids.begin(), fids.end(), fid_value_offset); std::vector embeddings(fid_num * dim); for (size_t i = 0; i < fid_num; ++i) { for (size_t j = 0; j < dim; ++j) { embeddings[i * dim + j] = static_cast(i + fid_value_offset); } } table.mutable_fids()->Add(fids.begin(), fids.end()); table.mutable_embeddings()->Add(embeddings.begin(), embeddings.end()); return table; } TEST(RequestSplitter, NoSplit) { PushRequest request; request.set_model_name("hello"); request.set_signature_name("hashtable_assign"); auto t0 = SetUpOneDeltaHashTable("table0", 0, 1); auto t1 = SetUpOneDeltaHashTable("table1", 2, 1); auto t2 = SetUpOneDeltaHashTable("table2", 3, 2); request.mutable_delta_hash_tables()->Add(std::move(t0)); request.mutable_delta_hash_tables()->Add(std::move(t1)); request.mutable_delta_hash_tables()->Add(std::move(t2)); RequestSplitter splitter; std::vector requests = splitter.Split(request, 4 * 1024 * 1024); EXPECT_EQ(requests.size(), 1); EXPECT_TRUE(MessageDifferencer::Equals(request, 
requests.front())); } // Test case(byte size = 111) TEST(RequestSplitter, SplitIntoTwo) { PushRequest request, request1, request2; request.set_model_name("hello"); request.set_signature_name("hashtable_assign"); auto t0 = SetUpOneDeltaHashTable("table0", 0, 1); auto t1 = SetUpOneDeltaHashTable("table1", 2, 1); auto t2 = SetUpOneDeltaHashTable("table2", 3, 2); request.mutable_delta_hash_tables()->Add(std::move(t0)); request.mutable_delta_hash_tables()->Add(std::move(t1)); request.mutable_delta_hash_tables()->Add(std::move(t2)); request.set_timeout_in_ms(100); RequestSplitter splitter; std::vector requests = splitter.Split(request, 60); EXPECT_EQ(requests.size(), 2); // part 1 request1.set_model_name("hello"); request1.set_signature_name("hashtable_assign"); t0 = SetUpOneDeltaHashTable("table0", 0, 1); t1 = SetUpOneDeltaHashTable("table1", 1, 1); t2 = SetUpOneDeltaHashTable("table2", 1, 2); request1.mutable_delta_hash_tables()->Add(std::move(t0)); request1.mutable_delta_hash_tables()->Add(std::move(t1)); request1.mutable_delta_hash_tables()->Add(std::move(t2)); request1.set_timeout_in_ms(100); EXPECT_TRUE(MessageDifferencer::Equals(requests.front(), request1)); // part 2 request2.set_model_name("hello"); request2.set_signature_name("hashtable_assign"); t0 = SetUpOneDeltaHashTable("table0", 0, 1); t1 = SetUpOneDeltaHashTable("table1", 1, 1, 1.0f); t2 = SetUpOneDeltaHashTable("table2", 2, 2, 1.0f); request2.mutable_delta_hash_tables()->Add(std::move(t0)); request2.mutable_delta_hash_tables()->Add(std::move(t1)); request2.mutable_delta_hash_tables()->Add(std::move(t2)); request2.set_timeout_in_ms(100); EXPECT_TRUE(MessageDifferencer::Equals(requests.back(), request2)); } // Test case(byte size = 113) TEST(RequestSplitter, MultiHashTableSplitIntoTwo) { PushRequest request, request1, request2; request.set_model_name("hello"); request.set_signature_name("table/raw_assign"); auto t0 = SetUpOneDeltaHashTable("table0", 0, 1); auto t1 = SetUpOneDeltaHashTable("table1", 2, 1); 
auto t2 = SetUpOneDeltaHashTable("table2", 3, 2); request.mutable_delta_multi_hash_tables()->Add(std::move(t0)); request.mutable_delta_multi_hash_tables()->Add(std::move(t1)); request.mutable_delta_multi_hash_tables()->Add(std::move(t2)); request.set_timeout_in_ms(100); RequestSplitter splitter; std::vector requests = splitter.Split(request, 60); EXPECT_EQ(requests.size(), 2); // part 1 request1.set_model_name("hello"); request1.set_signature_name("table/raw_assign"); t0 = SetUpOneDeltaHashTable("table0", 0, 1); t1 = SetUpOneDeltaHashTable("table1", 1, 1); t2 = SetUpOneDeltaHashTable("table2", 1, 2); request1.mutable_delta_multi_hash_tables()->Add(std::move(t0)); request1.mutable_delta_multi_hash_tables()->Add(std::move(t1)); request1.mutable_delta_multi_hash_tables()->Add(std::move(t2)); request1.set_timeout_in_ms(100); EXPECT_TRUE(MessageDifferencer::Equals(requests[0], request1)); // part 2 request2.set_model_name("hello"); request2.set_signature_name("table/raw_assign"); t0 = SetUpOneDeltaHashTable("table0", 0, 1); t1 = SetUpOneDeltaHashTable("table1", 1, 1, 1.0f); t2 = SetUpOneDeltaHashTable("table2", 2, 2, 1.0f); request2.mutable_delta_multi_hash_tables()->Add(std::move(t0)); request2.mutable_delta_multi_hash_tables()->Add(std::move(t1)); request2.mutable_delta_multi_hash_tables()->Add(std::move(t2)); request2.set_timeout_in_ms(100); EXPECT_TRUE(MessageDifferencer::Equals(requests[1], request2)); } } // namespace } // namespace parameter_sync } // namespace monolith ================================================ FILE: monolith/native_training/runtime/parameter_sync/sync_client_interface.h ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_SYNC_CLIENT_INTERFACE_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_SYNC_CLIENT_INTERFACE_H_ #include "grpcpp/grpcpp.h" #include "monolith/native_training/runtime/parameter_sync/parameter_sync.pb.h" namespace monolith { namespace parameter_sync { class SyncClientInterface { public: virtual grpc::Status Push(const PushRequest&, PushResponse*) const = 0; }; } // namespace parameter_sync } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_SYNC_CLIENT_INTERFACE_H_ ================================================ FILE: monolith/native_training/runtime/parameter_sync/sync_client_manager.cc ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
#include "monolith/native_training/runtime/parameter_sync/sync_client_manager.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "glog/logging.h"
#include "tensorflow/core/platform/default/logging.h"

#include "monolith/native_training/runtime/common/metrics.h"

namespace monolith {
namespace parameter_sync {

// 4M
// Upper bound handed to the request splitter for every pushed request.
const int MAX_MESSAGE_LENGTH = 4 * 1024 * 1024;

// Stores the factory used by TryReplace() to create one client per target.
SyncClientManager::SyncClientManager(
    std::function<std::unique_ptr<SyncClientInterface>(const std::string&)>
        client_factory)
    : client_factory_(std::move(client_factory)) {}

// Builds a one-line summary of a request for logging: model name, signature
// name and, for every delta (multi) hash table, its id and number of fids.
// `index` and `total` are printed as given ("PushRequest[index/total]").
std::string SyncClientManager::PushRequestDebugString(
    const PushRequest& request, int index, int total) const {
  std::vector<std::string> delta_hash_table_info;
  // One entry is appended per table of either kind.
  delta_hash_table_info.reserve(request.delta_hash_tables_size() +
                                request.delta_multi_hash_tables_size());
  std::string prefix = "MonolithHashTable_";
  for (const auto& table : request.delta_hash_tables()) {
    std::string simple_id = table.unique_id();
    // Drop the common prefix to keep the log line short.
    if (absl::StartsWith(simple_id, prefix)) {
      simple_id = simple_id.substr(prefix.length());
    }
    delta_hash_table_info.push_back(absl::StrFormat(
        "(unique_id: %s, fid_num: %d)", simple_id, table.fids().size()));
  }
  for (const auto& table : request.delta_multi_hash_tables()) {
    std::string simple_id = table.unique_id();
    delta_hash_table_info.push_back(absl::StrFormat(
        "(unique_id: %s, fid_num: %d)", simple_id, table.fids().size()));
  }

  return absl::StrFormat(
      "PushRequest[%d/%d]: model_name = %s, signature_name = %s, "
      "delta_hash_table = [%s]",
      index, total, request.model_name(), request.signature_name(),
      absl::StrJoin(delta_hash_table_info, ", "));
}

// Splits `request` into pieces bounded by MAX_MESSAGE_LENGTH and pushes every
// piece to every registered client while holding the reader lock on mu_.
// One response (with its target set) is added to the result per
// (client, piece) pair; pieces without any table are not sent. Latency, fid
// count and byte size metrics are emitted per target, tagged with
// `model_name`, `signature_name` and the push status ("OK"/"KO").
PushResult SyncClientManager::Push(const PushRequest& request,
                                   const std::string& model_name,
                                   const std::string& signature_name) const {
  LOG_EVERY_N_SEC(INFO, 60) << PushRequestDebugString(request, -1, -1);
  std::vector<PushRequest> requests =
      request_splitter_.Split(request, MAX_MESSAGE_LENGTH);

  int total = static_cast<int>(requests.size());
  std::vector<std::string> debug_string(total);
  auto split_log = [&]() {
    for (int i = 0; i < total; ++i) {
      debug_string[i] = PushRequestDebugString(requests[i], i, total);
    }
    return absl::StrJoin(debug_string, "\n");
  };
  LOG_EVERY_N_SEC(INFO, 60) << split_log();

  // Number of fids carried by each split request, computed once up front.
  std::vector<int64_t> request_fid_count;
  request_fid_count.reserve(requests.size());
  std::transform(requests.begin(), requests.end(),
                 std::back_inserter(request_fid_count),
                 [](const PushRequest& req) {
                   int64_t count = 0;
                   for (const auto& t : req.delta_hash_tables()) {
                     count += t.fids_size();
                   }
                   for (const auto& t : req.delta_multi_hash_tables()) {
                     count += t.fids_size();
                   }
                   return count;
                 });

  // Builds the metric tag string for a target; only invoked below while the
  // reader lock is held, because it reads caddr_to_extra_info_map_.
  auto MakeTagKV = [&](const std::string& target, const std::string& status) {
    auto it = caddr_to_extra_info_map_.find(target);
    int replica_id = -1;
    std::string idc_cluster = "NA";
    if (it != caddr_to_extra_info_map_.end()) {
      replica_id = it->second.replica_id();
      idc_cluster =
          absl::StrFormat("%s:%s", it->second.idc(), it->second.cluster());
    }
    return absl::StrFormat(
        "model_name=%s|signature_name=%s|target=%s|status=%s|idc=%s|replica_id="
        "%d",
        model_name, signature_name, target, status, idc_cluster, replica_id);
  };

  PushResult result;
  {
    absl::ReaderMutexLock l(&mu_);
    for (const auto& kv : clients_) {
      std::unordered_map<std::string, int64_t> fid_count = {{"OK", 0},
                                                            {"KO", 0}};
      std::unordered_map<std::string, int64_t> byte_size = {{"OK", 0},
                                                            {"KO", 0}};
      for (size_t i = 0; i < requests.size(); ++i) {
        const auto& req = requests[i];
        auto response = result.add_responses();
        if (!req.delta_hash_tables().empty() ||
            !req.delta_multi_hash_tables().empty()) {
          int64_t start = absl::ToUnixMicros(absl::Now());
          auto status = kv.second->Push(req, response);
          int64_t end = absl::ToUnixMicros(absl::Now());
          std::string status_key = status.ok() ? "OK" : "KO";
          std::string tag_kv = MakeTagKV(kv.first, status_key);
          monolith::GetMetrics()->emit_timer("parameter_sync_latency",
                                             end - start, tag_kv);
          fid_count[status_key] += request_fid_count[i];
          byte_size[status_key] += static_cast<int64_t>(req.ByteSizeLong());
        }
        response->set_target(kv.first);
      }

      // Only non-zero totals are reported.
      for (const auto& p : fid_count) {
        if (p.second) {
          std::string tag_kv = MakeTagKV(kv.first, p.first);
          monolith::GetMetrics()->emit_counter("parameter_sync_fid_count",
                                               p.second, tag_kv);
        }
      }
      for (const auto& p : byte_size) {
        if (p.second) {
          std::string tag_kv = MakeTagKV(kv.first, p.first);
          monolith::GetMetrics()->emit_counter("parameter_sync_byte_size",
                                               p.second, tag_kv);
        }
      }
    }
  }

  return result;
}

// Replaces the set of targets under the writer lock on mu_: clients whose
// target is no longer listed are dropped, and a client is created through the
// factory for every new target. The per-target extra info is rebuilt from
// scratch and is only populated when it is given index-aligned with
// `targets` (same number of entries). Always returns true.
bool SyncClientManager::TryReplace(
    const google::protobuf::RepeatedPtrField<std::string>& targets,
    const google::protobuf::RepeatedPtrField<
        monolith::parameter_sync::ClientConfig_TargetExtraInfo>&
        targets_extra_info) {
  absl::WriterMutexLock l(&mu_);
  std::unordered_set<std::string> unique_targets;
  for (const auto& target : targets) {
    unique_targets.insert(target);
  }
  caddr_to_extra_info_map_.clear();
  if (targets.size() == targets_extra_info.size()) {
    for (int i = 0; i < targets.size(); i++) {
      caddr_to_extra_info_map_.emplace(targets[i], targets_extra_info[i]);
    }
  } else if (!targets_extra_info.empty()) {
    // The two lists cannot be paired up; make the dropped info visible
    // instead of silently tagging metrics with the "NA" defaults.
    LOG(WARNING) << "Ignoring targets_extra_info: got "
                 << targets_extra_info.size() << " entries for "
                 << targets.size() << " targets";
  }

  // Remove invalid targets
  std::vector<std::string> invalid_targets;
  for (const auto& kv : clients_) {
    if (!unique_targets.count(kv.first)) {
      invalid_targets.emplace_back(kv.first);
    }
  }
  for (const auto& target : invalid_targets) {
    clients_.erase(target);
  }

  // Add new targets
  for (const auto& target : unique_targets) {
    if (!clients_.count(target)) {
      auto client = client_factory_(target);
      clients_[target] = std::move(client);
    }
  }

  return true;
}

}  // namespace parameter_sync
}  // namespace monolith

================================================
FILE: monolith/native_training/runtime/parameter_sync/sync_client_manager.h
================================================
// Copyright 2022 ByteDance and/or its affiliates.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_SYNC_CLIENT_MANAGER_H_ #define MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_SYNC_CLIENT_MANAGER_H_ #include "absl/synchronization/mutex.h" #include "monolith/native_training/runtime/parameter_sync/request_splitter.h" #include "monolith/native_training/runtime/parameter_sync/sync_client_interface.h" namespace monolith { namespace parameter_sync { class SyncClientManager { public: SyncClientManager( std::function(const std::string&)> client_factory); PushResult Push(const PushRequest& request, const std::string& model_name, const std::string& signature_name) const ABSL_SHARED_LOCKS_REQUIRED(mu_); bool TryReplace( const google::protobuf::RepeatedPtrField& targets, const google::protobuf::RepeatedPtrField< monolith::parameter_sync::ClientConfig_TargetExtraInfo>& targets_extra_info) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); private: std::string PushRequestDebugString(const PushRequest& request, int index, int total) const; private: RequestSplitter request_splitter_; // We create an individual ParameterSyncClient for each target, which is an // online ps shard replica. Typically, each target has corresponding hash // tables like hash_tables_. 
std::map> clients_ ABSL_GUARDED_BY(mu_); std::map caddr_to_extra_info_map_ ABSL_GUARDED_BY(mu_); std::function(const std::string&)> client_factory_; mutable absl::Mutex mu_; }; } // namespace parameter_sync } // namespace monolith #endif // MONOLITH_MONOLITH_NATIVE_TRAINING_RUNTIME_PARAMETER_SYNC_SYNC_CLIENT_MANAGER_H_ ================================================ FILE: monolith/native_training/save_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from typing import Set import collections import threading import time import traceback import os, sys from datetime import datetime from absl import logging from google.protobuf import text_format import tensorflow as tf from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops, errors from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as tf_saver from tensorflow.python.training.py_checkpoint_reader import NewCheckpointReader, CheckpointReader from monolith.native_training.monolith_checkpoint_state_pb2 import MonolithCheckpointState from monolith.native_training import utils from monolith.native_training.session_run_hooks import tide_available_now from monolith.native_training.model_export.export_context import is_exporting from monolith.native_training.dense_reload_utils import CUSTOM_RESTORE_OP, calc_feed_dict from monolith.native_training.metric import cli from monolith.native_training import native_task_context _CkptStateCache = collections.namedtuple("_CkptStateCache", ["global_step_value", "ckpt_state"]) _ckpt_state_cache_map = {} MONOLITH_CKPT_STATE_FILE_NAME = "monolith_checkpoint" def get_latest_checkpoint_state(checkpoint_dir: str, global_step_value: int): """A function that helps to get ckpt state with cache. Args: global_step_value - used to decide if our cache is stale or not. 
""" cache = _ckpt_state_cache_map.get(checkpoint_dir, None) if cache is None or cache.global_step_value < global_step_value or cache.ckpt_state is None: cache = _CkptStateCache( global_step_value=global_step_value, ckpt_state=tf.train.get_checkpoint_state(checkpoint_dir)) _ckpt_state_cache_map[checkpoint_dir] = cache return _ckpt_state_cache_map.get(checkpoint_dir).ckpt_state def get_monolith_checkpoint_state(checkpoint_dir, filename=None, remove_invalid_path=False): """Returns MonolithCheckpointState proto from the "monolith_checkpoint" file. If the "monolith_checkpoint" file contains a valid MonolithCheckpointState proto, returns it. Args: checkpoint_dir: The directory of checkpoints. filename: Optional name of the monolith checkpoint file. Default to 'monolith_checkpoint'. Returns: A MonolithCheckpointState if the state was available, None otherwise. """ ckpt = None coord_checkpoint_filename = os.path.join( checkpoint_dir, filename if filename else MONOLITH_CKPT_STATE_FILE_NAME) try: # Check that the file exists before opening it to avoid # many lines of errors from colossus in the logs. if file_io.file_exists(coord_checkpoint_filename): file_content = file_io.read_file_to_string(coord_checkpoint_filename) ckpt = MonolithCheckpointState() text_format.Merge(file_content, ckpt) if remove_invalid_path: # For relative exempt_model_checkpoint_paths, prepend checkpoint_dir. 
for i, p in enumerate(ckpt.exempt_model_checkpoint_paths): if not os.path.isabs(p): ckpt.exempt_model_checkpoint_paths[i] = os.path.join( checkpoint_dir, p) # Remove ckpt paths which do not exist from exempt_model_checkpoint_paths ckpt_paths_not_exist = [] for i, p in enumerate(ckpt.exempt_model_checkpoint_paths): if not checkpoint_management.checkpoint_exists(p): ckpt_paths_not_exist.append(p) for p in ckpt_paths_not_exist: logging.warning( "%s not exists in file system, remove from monolith_checkpoint", p) ckpt.exempt_model_checkpoint_paths.remove(p) except errors.OpError as e: # It's ok if the file cannot be read logging.warning("%s: %s", type(e).__name__, e) logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) return None except text_format.ParseError as e: logging.warning("%s: %s", type(e).__name__, e) logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) return None return ckpt # TODO(leqi.zou): make this class more powerful. class SaveHelper: """A helper that provides some utils for saver listeners.""" def __init__(self, basename: str): self._basename = basename def get_ckpt_prefix(self, global_step_value: int) -> str: """Returns checkpoint prefix for given basename and global_step_value.""" return self._basename + "-" + str(global_step_value) @classmethod def get_ckpt_asset_dir(cls, ckpt_prefix: str) -> str: """Returns checkpoint asset directory for given basename and global_step_value. This is mainly to reduce the number files in the model_dir. 
""" return ckpt_prefix + ".assets/" def get_global_step_value(self, ckpt_prefix: str) -> int: """Returns global step value for given checkpoint prefix.""" if '-' in ckpt_prefix: return int(ckpt_prefix.split('-')[-1]) else: return 0 def get_existing_checkpoint_steps(self) -> Set[int]: ckpt_state = tf.train.get_checkpoint_state(os.path.dirname(self._basename)) checkpoint_steps = set() for path in ckpt_state.all_model_checkpoint_paths: checkpoint_steps.add(self.get_global_step_value(path)) return checkpoint_steps class SecondOrStepTimerWithTideSetting(tf.estimator.SecondOrStepTimer): """Timer that triggers at most once every N seconds or once every N steps. It'll trigger using a different setting when tide resources is not available. """ def __init__(self, every_secs=None, every_steps=None, tide_start_hour=None, tide_start_minute=None, tide_end_hour=None, tide_end_minute=None, tide_every_secs=None, save_helper: SaveHelper = None): super(SecondOrStepTimerWithTideSetting, self).__init__(every_secs=every_secs, every_steps=every_steps) self._tide_start_hour = tide_start_hour self._tide_start_minute = tide_start_minute self._tide_end_hour = tide_end_hour self._tide_end_minute = tide_end_minute self._tide_every_secs = tide_every_secs self._save_helper = save_helper self._enabled = True def enable(self): self._enabled = True def disable(self): self._enabled = False def should_trigger_for_step(self, step): """Return true if the timer should trigger for the specified step. Args: step: Training step to trigger on. Returns: True if the difference between the current time and the time of the last trigger exceeds `every_secs`, or if the difference between the current step and the last triggered step exceeds `every_steps`. False otherwise. 
""" if not self._enabled: return False if self._last_triggered_step is None: return True if self._last_triggered_step == step: return False if (self._tide_start_hour is not None and self._tide_end_hour is not None and self._tide_every_secs is not None) and not tide_available_now( self._tide_start_hour, self._tide_start_minute, self._tide_end_hour, self._tide_end_minute): if time.time() >= self._last_triggered_time + self._tide_every_secs: logging.info("Current UTC time: {} : {}".format( datetime.utcnow().hour, datetime.utcnow().minute)) logging.info( "Tide not available. Using tide checkpoint saving time interval.") logging.info("Now: {} Last: {} Interval: {}".format( time.time(), self._last_triggered_time, self._tide_every_secs)) return True else: if self._every_secs is not None: if time.time() >= self._last_triggered_time + self._every_secs: return True if self._every_steps is not None: if step >= self._last_triggered_step + self._every_steps: return True return False class NoFirstSaveCheckpointSaverHook(tf.estimator.CheckpointSaverHook): """A saver hook which won't perform the first save (which happened on after_create_session).""" _has_dense_only: bool = False _in_model_dump_mode: bool = False _last_triggered_step: int = 0 def __init__(self, checkpoint_dir, save_secs=None, save_steps=None, saver=None, checkpoint_basename="model.ckpt", scaffold=None, listeners=None, save_graph_def=True, tide_start_hour=None, tide_start_minute=None, tide_end_hour=None, tide_end_minute=None, tide_save_secs=None, ignore_save_errors=False, is_dense_only: bool = False, use_native_multi_hash_table: bool = False, no_first_save: bool = True, guard_saver_listeners=None): """ Args: guard_saver_listeners - listeners which are not related to saving, and will always call even there is an error. 
""" super().__init__(checkpoint_dir=checkpoint_dir, save_secs=save_secs, save_steps=save_steps, saver=saver, checkpoint_basename=checkpoint_basename, scaffold=scaffold, listeners=listeners, save_graph_def=save_graph_def) self._helper = SaveHelper(self._save_path) self._timer = SecondOrStepTimerWithTideSetting( every_secs=save_secs, every_steps=save_steps, tide_start_hour=tide_start_hour, tide_start_minute=tide_start_minute, tide_end_hour=tide_end_hour, tide_end_minute=tide_end_minute, tide_every_secs=tide_save_secs, save_helper=self._helper) self._no_first_save = no_first_save self._save_graph_def = save_graph_def self._ignore_save_errors = ignore_save_errors self._is_dense_only = is_dense_only self._use_native_multi_hash_table = use_native_multi_hash_table self._guard_saver_listeners = guard_saver_listeners or () self._mcli = cli.get_cli(utils.get_metric_prefix()) # Used to protect after_run, which may be called concurrently self._l = threading.Lock() @property def timer(self): return self._timer # Make sure this hook run after restore hook. def after_create_session(self, session, coord): super().after_create_session(session, coord) if self._save_graph_def: self._get_saver().export_meta_graph( utils.get_meta_graph_file_name(self._checkpoint_dir)) if isinstance(self._saver, PartialRecoverySaver): self._saver.setup_ps_initialized_state(session) self._create_or_update_monolith_ckpt_state(do_update=False) def trigger_save(self, session, ignore_save_errors=False): # There might be some concurency issue but should be fine # Since our goal is only to do the save. with self._l: self._timer.reset() run_context = tf.estimator.SessionRunContext((), session) run_values = tf.estimator.SessionRunValues(0, None, None) # this function may be called in the different thread. with session.graph.as_default(): with self._l: old_value = self._ignore_save_errors try: # Always throw error in this case. 
self._ignore_save_errors = ignore_save_errors super().after_run(run_context, run_values) finally: self._ignore_save_errors = old_value def after_run(self, run_context, run_values): with self._l: super().after_run(run_context, run_values) def _save(self, session, step: int) -> bool: if self._no_first_save: self._no_first_save = False return False # skip if full ckpt has happened in this step. # [todo] maybe bug when there are more saver hooks if self._is_dense_only: mode = 'dense_only' if step == self._last_triggered_step: return False else: mode = 'full' tags = {"mode": mode} try: for l in self._guard_saver_listeners: try: l.before_save(session, step) except: logging.error(traceback.format_exc()) for retries in range(2): try: start_time = time.time() should_stop = super()._save(session, step) self._last_triggered_step = step self._create_or_update_monolith_ckpt_state(do_update=True) end_time = time.time() logging.info("saving checkpoint took %f seconds", end_time - start_time) self._mcli.emit_counter("save_checkpoint", 1, tags) self._mcli.emit_timer("save_checkpoint_time", end_time - start_time, tags) return should_stop except tf.errors.OpError as op_error: self._mcli.emit_counter("save_checkpoint_failed", 1, tags) logging.error("Failed to save, retrying ...\n%s", traceback.format_exc()) catched_error = op_error continue finally: for l in reversed(self._guard_saver_listeners): try: l.after_save(session, step) except: logging.error(traceback.format_exc()) if self._ignore_save_errors: return False raise catched_error def _create_or_update_monolith_ckpt_state(self, do_update=False): # only save ckpt state if save_graph_def if not self._save_graph_def: return ckpt_state = get_monolith_checkpoint_state(self._checkpoint_dir, remove_invalid_path=True) if ckpt_state is None: logging.info("Create new monolith ckpt state") ckpt_state = MonolithCheckpointState() if self._use_native_multi_hash_table: ckpt_state.builtin_hash_table_type = 
MonolithCheckpointState.HashTableType.MULTI_CUCKOO_HASH_MAP else: ckpt_state.builtin_hash_table_type = MonolithCheckpointState.HashTableType.CUCKOO_HASH_MAP elif do_update is False: return else: logging.info("Update new monolith ckpt state") ckpt_state.last_checkpoint_save_timestamp = int(time.time()) monolith_checkpoint_filename = os.path.join(self._checkpoint_dir, MONOLITH_CKPT_STATE_FILE_NAME) file_io.atomic_write_string_to_file(monolith_checkpoint_filename, text_format.MessageToString(ckpt_state)) logging.info("monolith ckpt state saved") def end(self, session): last_step = session.run(self._global_step_tensor) if self._is_dense_only: pass elif self._has_dense_only or self._in_model_dump_mode: # force save self._timer.update_last_triggered_step(last_step) super()._save(session, last_step) for l in self._listeners: l.end(session, last_step) else: super().end(session) class PsMonitor(): """A monitor that use to detect ps status.""" def __init__(self, ps_num): self._queues = {} self._enqueue_ops = {} self._qsize_ops = {} for i in range(ps_num): device = utils.ps_device(i) with tf.device(device): queue = tf.queue.FIFOQueue(1, tf.int32, shared_name="ps_monitor_" + str(i)) self._queues[device] = queue self._enqueue_ops[device] = queue.enqueue(1) self._qsize_ops[device] = queue.size() def is_ps_uninitialized(self, sess, device): if device in self._qsize_ops: return sess.run(self._qsize_ops[device]) == 0 return True def setup_ps_initialized_state(self, sess): for device in self._queues.keys(): if sess.run(self._qsize_ops[device]) == 0: sess.run(self._enqueue_ops[device]) class SaverBuilder(tf_saver.BulkSaverBuilder): """SaverBuilder with support for partial recovery. Collect restore ops for each device. """ def _AddShardedRestoreOps(self, filename_tensor, per_device, restore_sequentially, reshape): """Add Ops to restore variables from multiple devices. Args: filename_tensor: Tensor for the path of the file to load. 
per_device: A list of (device, SaveableObject) pairs, as returned by _GroupByDevices(). restore_sequentially: True if we want to restore variables sequentially within a shard. reshape: True if we want to reshape loaded tensors to the shape of the corresponding variable. Returns: An Operation that restores the variables. """ sharded_restores = [] self._restore_ops_per_device = collections.defaultdict(list) for shard, (device, saveables) in enumerate(per_device): with tf.device(device): restore_op = self._AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape, preferred_shard=shard, name="restore_shard") sharded_restores.append(restore_op) self._restore_ops_per_device[device].append(restore_op) for device, restore_ops in self._restore_ops_per_device.items(): self._restore_ops_per_device[device] = tf.group(*restore_ops, name="restore_per_device") return tf.group(*sharded_restores, name="restore_all") @property def restore_ops_per_device(self): """Return restore ops per device.""" if hasattr(self, '_restore_ops_per_device'): return self._restore_ops_per_device return {} # Copy from tensorflow/python/training/saver.py. The major change is in restore function. # Apply partial recovery of dense part and hash table when ps_monitor is enabled. # TODO(xujinghao): Implement partial recovery of hash filter if needed. class PartialRecoverySaver(): """Saves and restores variables. See [Variables](https://tensorflow.org/guide/variables) for an overview of variables, saving and restoring. The `Saver` class adds ops to save and restore variables to and from *checkpoints*. It also provides convenience methods to run these ops. Checkpoints are binary files in a proprietary format which map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a `Saver`. Savers can automatically number checkpoint filenames with a provided counter. This lets you keep multiple checkpoints at different steps while training a model. 
For example you can number the checkpoint filenames with the training step number. To avoid filling up disks, savers manage checkpoint files automatically. For example, they can keep only the N most recent files, or one checkpoint for every N hours of training. You number checkpoint filenames by passing a value to the optional `global_step` argument to `save()`: ```python saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0' ... saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000' ``` Additionally, optional arguments to the `Saver()` constructor let you control the proliferation of checkpoint files on disk: * `max_to_keep` indicates the maximum number of recent checkpoint files to keep. As new files are created, older files are deleted. If None or 0, no checkpoints are deleted from the filesystem but only the last one is kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent checkpoint files are kept.) * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent `max_to_keep` checkpoint files, you might want to keep one checkpoint file for every N hours of training. This can be useful if you want to later analyze how a model progressed during a long training session. For example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep one checkpoint file for every 2 hours of training. The default value of 10,000 hours effectively disables the feature. Note that you still have to call the `save()` method to save the model. Passing these arguments to the constructor will not save variables automatically for you. A training program that saves regularly looks like: ```python ... # Create a saver. saver = tf.compat.v1.train.Saver(...variables...) # Launch the graph and train, saving the model every 1,000 steps. sess = tf.compat.v1.Session() for step in xrange(1000000): sess.run(..training_op..) 
if step % 1000 == 0: # Append the step number to the checkpoint name: saver.save(sess, 'my-model', global_step=step) ``` In addition to checkpoint files, savers keep a protocol buffer on disk with the list of recent checkpoints. This is used to manage numbered checkpoint files and by `latest_checkpoint()`, which makes it easy to discover the path to the most recent checkpoint. That protocol buffer is stored in a file named 'checkpoint' next to the checkpoint files. If you create several savers, you can specify a different filename for the protocol buffer file in the call to `save()`. """ def __init__(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, pad_step_number=False, save_relative_paths=False, filename=None, ps_monitor=None, exempt_checkpoint_paths=None, skip_save=False, model_dir=None): """Creates a `Saver`. The constructor adds ops to save and restore variables. `var_list` specifies the variables that will be saved and restored. It can be passed as a `dict` or a list: * A `dict` of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files. * A list of variables: The variables will be keyed with their op name in the checkpoint files. For example: ```python v1 = tf.Variable(..., name='v1') v2 = tf.Variable(..., name='v2') # Pass the variables as a dict: saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2}) # Or pass them as a list. saver = tf.compat.v1.train.Saver([v1, v2]) # Passing a list is equivalent to passing a dict with the variable op names # as keys: saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]}) ``` Note: the newer `AutoTrackable` API is not supported by `Saver`. In this case, the `tf.train.Checkpoint` class should be used. 
The optional `reshape` argument, if `True`, allows restoring a variable from a save file where the variable had a different shape, but the same number of elements and type. This is useful if you have reshaped a variable and want to reload it from an older checkpoint. The optional `sharded` argument, if `True`, instructs the saver to shard checkpoints per device. Args: var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping names to `SaveableObject`s. If `None`, defaults to the list of all saveable objects. reshape: If `True`, allows restoring parameters from a checkpoint where the variables have a different shape. sharded: If `True`, shard the checkpoints, one per device. max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5. keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to 10,000 hours. name: String. Optional name to use as a prefix when adding operations. restore_sequentially: A `Bool`, which if true, causes restore of different variables to happen sequentially within each device. This can lower memory usage when restoring very large models. saver_def: Optional `SaverDef` proto to use instead of running the builder. This is only useful for specialty code that wants to recreate a `Saver` object for a previously built `Graph` that had a `Saver`. The `saver_def` proto should be the one returned by the `as_saver_def()` call of the `Saver` that was created for that `Graph`. builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. Defaults to `BulkSaverBuilder()`. defer_build: If `True`, defer adding the save and restore ops to the `build()` call. In that case `build()` should be called before finalizing the graph or using the saver. allow_empty: If `False` (default) raise an error if there are no variables in the graph. Otherwise, construct the saver anyway and make it a no-op. 
pad_step_number: if True, pads the global step number in the checkpoint filepaths to some fixed width (8 by default). This is turned off by default. save_relative_paths: If `True`, will write relative paths to the checkpoint state file. This is needed if the user wants to copy the checkpoint directory and reload from the copied directory. filename: If known at graph construction time, filename used for variable loading/saving. Raises: TypeError: If `var_list` is invalid. ValueError: If any of the keys or values in `var_list` are not unique. RuntimeError: If eager execution is enabled and`var_list` does not specify a list of variables to save. @compatibility(eager) When eager execution is enabled, `var_list` must specify a `list` or `dict` of variables to save. Otherwise, a `RuntimeError` will be raised. Although Saver works in some cases when executing eagerly, it is fragile. Please switch to `tf.train.Checkpoint` or `tf.keras.Model.save_weights`, which perform a more robust object-based saving. These APIs will load checkpoints written by `Saver`. @end_compatibility """ if defer_build and var_list: raise ValueError( "If `var_list` is provided then build cannot be deferred. " "Either set defer_build=False or var_list=None.") if tf.executing_eagerly(): logging.warning( "Saver is deprecated, please switch to tf.train.Checkpoint or " "tf.keras.Model.save_weights for training checkpoints. 
When " "executing eagerly variables do not necessarily have unique names, " "and so the variable.name-based lookups Saver performs are " "error-prone.") if var_list is None: raise RuntimeError( "When eager execution is enabled, `var_list` must specify a list " "or dict of variables to save") self._var_list = var_list self._reshape = reshape self._sharded = sharded self._max_to_keep = max_to_keep self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours self._name = name self._restore_sequentially = restore_sequentially self.saver_def = saver_def self._builder = builder self._is_built = False self._allow_empty = allow_empty self._is_empty = None self._write_version = saver_pb2.SaverDef.V2 self._pad_step_number = pad_step_number self._filename = filename self._last_checkpoints = [] self._checkpoints_to_be_deleted = [] self._exempt_checkpoint_paths = set(exempt_checkpoint_paths or []) self._model_dir = model_dir self._skip_save = skip_save if tf.executing_eagerly(): self._next_checkpoint_time = (time.time() + self._keep_checkpoint_every_n_hours * 3600) elif not defer_build: self.build() if self.saver_def: self._check_saver_def() self._write_version = self.saver_def.version self._save_relative_paths = save_relative_paths # For compatibility with object-based checkpoints, we may build a second # Saver to read the renamed keys. 
self._object_restore_saver = None self._ps_monitor = ps_monitor def build(self): if tf.executing_eagerly(): raise RuntimeError("Use save/restore instead of build in eager mode.") self._build(self._filename, build_save=True, build_restore=True) def _build_eager(self, checkpoint_path, build_save, build_restore): self._build(checkpoint_path, build_save=build_save, build_restore=build_restore) def _build(self, checkpoint_path, build_save, build_restore): """Builds saver_def.""" if not tf.executing_eagerly(): if self._is_built: return self._is_built = True if not self.saver_def or tf.executing_eagerly(): if self._builder is None: self._builder = SaverBuilder(self._write_version) if self._var_list is None: # pylint: disable=protected-access self._var_list = variables._all_saveable_objects() if not self._var_list: if self._allow_empty: self._is_empty = True return else: raise ValueError("No variables to save") self._is_empty = False self.saver_def = self._builder._build_internal( # pylint: disable=protected-access self._var_list, reshape=self._reshape, sharded=self._sharded, max_to_keep=self._max_to_keep, keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours, name=self._name, restore_sequentially=self._restore_sequentially, filename=checkpoint_path, build_save=build_save, build_restore=build_restore) elif self.saver_def and self._name: # Since self._name is used as a name_scope by builder(), we are # overloading the use of this field to represent the "import_scope" as # well. self.saver_def.filename_tensor_name = ops.prepend_name_scope( self.saver_def.filename_tensor_name, self._name) self.saver_def.save_tensor_name = ops.prepend_name_scope( self.saver_def.save_tensor_name, self._name) self.saver_def.restore_op_name = ops.prepend_name_scope( self.saver_def.restore_op_name, self._name) self._check_saver_def() if not tf.executing_eagerly(): # Updates next checkpoint time. # Set in __init__ when executing eagerly. 
self._next_checkpoint_time = ( time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) def _check_saver_def(self): if not isinstance(self.saver_def, saver_pb2.SaverDef): raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" % self.saver_def) if not tf.executing_eagerly(): if not self.saver_def.save_tensor_name: raise ValueError("saver_def must specify the save_tensor_name: %s" % str(self.saver_def)) if not self.saver_def.restore_op_name: raise ValueError("saver_def must specify the restore_op_name: %s" % str(self.saver_def)) def _CheckpointFilename(self, p): """Returns the checkpoint filename given a `(filename, time)` pair. Args: p: (filename, time) pair. Returns: Checkpoint file name. """ name, _ = p return name def _RecordLastCheckpoint(self, latest_save_path): """Manages the list of the latest checkpoints.""" if not self.saver_def.max_to_keep: return # Remove first from list if the same name was used before. for p in self._last_checkpoints: if latest_save_path == self._CheckpointFilename(p): self._last_checkpoints.remove(p) # Append new path to list self._last_checkpoints.append((latest_save_path, time.time())) # If more than max_to_keep, remove oldest but exempt checkpoint. last_checkpoint_paths = set([ os.path.basename(self._CheckpointFilename(p)) for p in self._last_checkpoints ]) exempt_checkpoint_paths = last_checkpoint_paths & self.exempt_checkpoint_paths if len(self._last_checkpoints) - len( exempt_checkpoint_paths) > self.saver_def.max_to_keep: for p in self._last_checkpoints: filename = os.path.basename(self._CheckpointFilename(p)) if filename not in self.exempt_checkpoint_paths: self._checkpoints_to_be_deleted.append(p) self._last_checkpoints.remove(p) break def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"): """Deletes old checkpoints if necessary. `self._checkpoints_to_be_deleted` is going to contain checkpoints that are over `max_to_keep`. They are going to be deleted. 
If `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint every `N` hours. For example, if `N` is 0.5, an additional checkpoint is kept for every 0.5 hours of training; if `N` is 10, an additional checkpoint is kept for every 10 hours of training. Args: meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. """ if self._checkpoints_to_be_deleted: p = self._checkpoints_to_be_deleted.pop(0) # Do not delete the file if we keep_checkpoint_every_n_hours is set and we # have reached N hours of training. should_keep = p[1] > self._next_checkpoint_time if should_keep: self._next_checkpoint_time += ( self.saver_def.keep_checkpoint_every_n_hours * 3600) return # Otherwise delete the files. logging.info("Deleted checkpoint: %s.", self._CheckpointFilename(p)) for pathname in tf.io.gfile.glob(self._CheckpointFilename(p) + ".*"): try: tf.io.gfile.rmtree(pathname) except tf.errors.NotFoundError: logging.warning( "Hit NotFoundError when deleting '%s', possibly because another " "process/thread is also deleting/moving the same file", pathname) def as_saver_def(self): """Generates a `SaverDef` representation of this saver. Returns: A `SaverDef` proto. """ return self.saver_def @property def exempt_checkpoint_paths(self): if self._model_dir: monolith_ckpt_state = get_monolith_checkpoint_state( self._model_dir, remove_invalid_path=True) if monolith_ckpt_state and monolith_ckpt_state.exempt_model_checkpoint_paths: exempt_checkpoint_paths = [ os.path.basename(p) for p in monolith_ckpt_state.exempt_model_checkpoint_paths ] logging.info( 'New exempt checkpoint paths: {}'.format(exempt_checkpoint_paths)) self._exempt_checkpoint_paths = set(exempt_checkpoint_paths or []) else: logging.info("Get exempt checkpoint paths null") return self._exempt_checkpoint_paths @property def last_checkpoints(self): """List of not-yet-deleted checkpoint filenames. You can pass any of the returned values to `restore()`. 
Returns: A list of checkpoint filenames, sorted from oldest to newest. """ return list(self._CheckpointFilename(p) for p in self._last_checkpoints) def set_last_checkpoints_with_time(self, last_checkpoints_with_time): """Sets the list of old checkpoint filenames and timestamps. Args: last_checkpoints_with_time: A list of tuples of checkpoint filenames and timestamps. Raises: AssertionError: If last_checkpoints_with_time is not a list. """ assert isinstance(last_checkpoints_with_time, list) self._last_checkpoints = last_checkpoints_with_time def recover_last_checkpoints(self, checkpoint_paths): """Recovers the internal saver state after a crash. This method is useful for recovering the "self._last_checkpoints" state. Globs for the checkpoints pointed to by `checkpoint_paths`. If the files exist, use their mtime as the checkpoint timestamp. Args: checkpoint_paths: a list of checkpoint paths. """ checkpoints_with_mtimes = [] for checkpoint_path in checkpoint_paths: try: mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path]) except tf.errors.NotFoundError: # It's fine if some other thread/process is deleting some older # checkpoint concurrently. continue if mtime: checkpoints_with_mtimes.append((checkpoint_path, mtime[0])) self.set_last_checkpoints_with_time(checkpoints_with_mtimes) logging.info("Recover last checkpoints result: {}".format( self.last_checkpoints)) def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False): # pylint: disable=line-too-long """Saves variables. This method runs the ops added by the constructor for saving variables. It requires a session in which the graph was launched. The variables to save must also have been initialized. The method returns the path prefix of the newly created checkpoint files. This string can be passed directly to a call to `restore()`. 
Args: sess: A Session to use to save the variables. save_path: String. Prefix of filenames created for the checkpoint. global_step: If provided the global step number is appended to `save_path` to create the checkpoint filenames. The optional argument can be a `Tensor`, a `Tensor` name or an integer. latest_filename: Optional name for the protocol buffer file that will contains the list of most recent checkpoints. That file, kept in the same directory as the checkpoint files, is automatically managed by the saver to keep track of recent checkpoints. Defaults to 'checkpoint'. meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. write_meta_graph: `Boolean` indicating whether or not to write the meta graph file. write_state: `Boolean` indicating whether or not to write the `CheckpointStateProto`. strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). save_debug_info: If `True`, save the GraphDebugInfo to a separate file, which in the same directory of save_path and with `_debug` added before the file extension. This is only enabled when `write_meta_graph` is `True` Returns: A string: path prefix used for the checkpoint files. If the saver is sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn' is the number of shards created. If the saver is empty, returns None. Raises: TypeError: If `sess` is not a `Session`. ValueError: If `latest_filename` contains path components, or if it collides with `save_path`. RuntimeError: If save and restore ops weren't built. 
""" # pylint: enable=line-too-long if not self._is_built and not tf.executing_eagerly(): raise RuntimeError( "`build()` should be called before save if defer_build==True") if self._skip_save: return None if latest_filename is None: latest_filename = "checkpoint" if os.path.split(latest_filename)[0]: raise ValueError("'latest_filename' must not contain path components") save_path = tf.compat.as_str(save_path) if global_step is not None: if not isinstance(global_step, tf.compat.integral_types): global_step = tf.compat.v1.train.global_step(sess, global_step) checkpoint_file = "%s-%d" % (save_path, global_step) if self._pad_step_number: # Zero-pads the step numbers, so that they are sorted when listed. checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step)) else: checkpoint_file = save_path if os.path.basename(save_path) == latest_filename and not self._sharded: # Guard against collision between data file and checkpoint state file. raise ValueError( "'latest_filename' collides with 'save_path': '%s' and '%s'" % (latest_filename, save_path)) if (not tf.executing_eagerly() and not isinstance(sess, session.SessionInterface)): raise TypeError("'sess' must be a Session; %s" % sess) save_path_parent = os.path.dirname(save_path) if not self._is_empty: try: if tf.executing_eagerly(): self._build_eager(checkpoint_file, build_save=True, build_restore=False) model_checkpoint_path = self.saver_def.save_tensor_name else: model_checkpoint_path = sess.run( self.saver_def.save_tensor_name, {self.saver_def.filename_tensor_name: checkpoint_file}) model_checkpoint_path = tf.compat.as_str(model_checkpoint_path) if write_state: self._RecordLastCheckpoint(model_checkpoint_path) checkpoint_management.update_checkpoint_state_internal( save_dir=save_path_parent, model_checkpoint_path=model_checkpoint_path, all_model_checkpoint_paths=self.last_checkpoints, latest_filename=latest_filename, save_relative_paths=self._save_relative_paths) 
self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix) except (tf.errors.FailedPreconditionError, tf.errors.NotFoundError) as exc: if not tf.io.gfile.isdir(save_path_parent): exc = ValueError( "Parent directory of {} doesn't exist, can't save.".format( save_path)) raise exc if write_meta_graph: meta_graph_filename = checkpoint_management.meta_graph_filename( checkpoint_file, meta_graph_suffix=meta_graph_suffix) if not tf.executing_eagerly(): with sess.graph.as_default(): self.export_meta_graph(meta_graph_filename, strip_default_attrs=strip_default_attrs, save_debug_info=save_debug_info) if self._is_empty: return None else: return model_checkpoint_path def export_meta_graph(self, filename=None, collection_list=None, as_text=False, export_scope=None, clear_devices=False, clear_extraneous_savers=False, strip_default_attrs=False, save_debug_info=False): # pylint: disable=line-too-long """Writes `MetaGraphDef` to save_path/filename. Args: filename: Optional meta_graph filename including the path. collection_list: List of string keys to collect. as_text: If `True`, writes the meta_graph as an ASCII proto. export_scope: Optional `string`. Name scope to remove. clear_devices: Whether or not to clear the device field for an `Operation` or `Tensor` during export. clear_extraneous_savers: Remove any Saver-related information from the graph (both Save/Restore ops and SaverDefs) that are not associated with this Saver. strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). save_debug_info: If `True`, save the GraphDebugInfo to a separate file, which in the same directory of filename and with `_debug` added before the file extension. Returns: A `MetaGraphDef` proto. 
""" # pylint: enable=line-too-long return tf_saver.export_meta_graph( filename=filename, graph_def=tf.compat.v1.get_default_graph().as_graph_def( add_shapes=True), saver_def=self.saver_def, collection_list=collection_list, as_text=as_text, export_scope=export_scope, clear_devices=clear_devices, clear_extraneous_savers=clear_extraneous_savers, strip_default_attrs=strip_default_attrs, save_debug_info=save_debug_info) def _origin_restore(self, sess, save_path): """Restores previously saved variables. This method runs the ops added by the constructor for restoring variables. It requires a session in which the graph was launched. The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables. The `save_path` argument is typically a value previously returned from a `save()` call, or a call to `latest_checkpoint()`. Args: sess: A `Session` to use to restore the parameters. None in eager mode. save_path: Path where parameters were previously saved. Raises: ValueError: If save_path is None or not a valid checkpoint. """ if self._is_empty: return if save_path is None: raise ValueError("Can't load save_path when it is None.") checkpoint_prefix = tf.compat.as_text(save_path) if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix): raise ValueError("The passed save_path is not a valid checkpoint: " + checkpoint_prefix) logging.info("Restoring parameters from %s", checkpoint_prefix) try: if tf.executing_eagerly(): self._build_eager(save_path, build_save=False, build_restore=True) else: # At some local case, restore_ops_per_device is empty. 
if not self._builder.restore_ops_per_device: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) else: restore_ops = [] for device, restore_op in self._builder.restore_ops_per_device.items( ): if not self._ps_monitor or self._ps_monitor.is_ps_uninitialized( sess, device): restore_ops.append(restore_op) sess.run(restore_ops, {self.saver_def.filename_tensor_name: save_path}) except tf.errors.NotFoundError as err: # There are three common conditions that might cause this error: # 0. The file is missing. We ignore here, as this is checked above. # 1. This is an object-based checkpoint trying name-based loading. # 2. The graph has been altered and a variable or other name is missing. # 1. The checkpoint would not be loaded successfully as is. Try to parse # it as an object-based checkpoint. try: names_to_keys = tf_saver.object_graph_key_mapping(save_path) except tf.errors.NotFoundError: # 2. This is not an object-based checkpoint, which likely means there # is a graph mismatch. Re-raise the original error with # a helpful message (b/110263146) raise tf_saver._wrap_restore_error_with_msg( err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do # the restore. logging.warning( "Restoring an object-based checkpoint using a name-based saver. This " "may be somewhat fragile, and will re-build the Saver. Instead, " "consider loading object-based checkpoints using " "tf.train.Checkpoint().") self._object_restore_saver = tf_saver.saver_from_object_based_checkpoint( checkpoint_path=save_path, var_list=self._var_list, builder=self._builder, names_to_keys=names_to_keys, cached_saver=self._object_restore_saver) self._object_restore_saver.restore(sess=sess, save_path=save_path) except tf.errors.InvalidArgumentError as err: # There is a mismatch between the graph and the checkpoint being loaded. 
# We add a more reasonable error message here to help users (b/110263146) raise tf_saver._wrap_restore_error_with_msg( err, "a mismatch between the current graph and the graph") def restore(self, sess, save_path): if is_exporting() or sess is None: logging.info('is_exporting or sess is None, fall back to origin_restore') return self._origin_restore(sess, save_path) checkpoint_state = None logging.info(f"save_path is {save_path}") model_dir = os.path.dirname(save_path) try: checkpoint_state = tf.train.get_checkpoint_state(checkpoint_dir=model_dir) except Exception as e: logging.info( f'no checkpoint file in {model_dir}, fall back to origin_restore') return self._origin_restore(sess, save_path) if checkpoint_state is None: logging.info(f'checkpoint_state is None, fall back to origin_restore') return self._origin_restore(sess, save_path) graph: tf.Graph = None try: graph: tf.Graph = sess.graph except Exception as e: logging.info("the eager mode has no attribute graph") return self._origin_restore(sess, save_path) if not graph: logging.info("graph is None, pls. check! 
fall back to origin_restore") return self._origin_restore(sess, save_path) init_objs = graph.get_collection(CUSTOM_RESTORE_OP) if init_objs: init_ops, placeholders, alias_map = init_objs[0] if alias_map: # alias_map: new -> old ckpt: CheckpointReader = NewCheckpointReader(save_path) feed_dict = calc_feed_dict(ckpt, alias_map, placeholders) if feed_dict: sess.run(init_ops, feed_dict=feed_dict) else: self._origin_restore(sess, save_path) elif init_ops: assert alias_map is None or len(alias_map) == 0 restore_dir = os.path.dirname(save_path) model_dir = restore_dir if hasattr(init_ops[0], 'model_dir'): model_dir_tmp = getattr(init_ops[0], 'model_dir') if model_dir_tmp: model_dir = model_dir_tmp flag_file = os.path.join(model_dir, 'clear_nn') logging.info(f'the clear nn flag_file is {flag_file}') if tf.io.gfile.exists(flag_file): logging.info( 'clear nn flag_file exists, restore from ckpt, do not clear nn') self._origin_restore(sess, save_path) else: if len(init_ops) == 1: sess.run(init_ops) else: init_op, update_gs_op = init_ops sess.run(init_op) # update global_step to continue training ckpt: CheckpointReader = NewCheckpointReader(save_path) gs_tensor = ckpt.get_tensor('global_step') sess.run(update_gs_op, feed_dict={placeholders[0]: gs_tensor}) logging.info( f'update global_step to continue training, {gs_tensor}') logging.info('clear nn by calling init_op other than restore ...') with tf.io.gfile.GFile(flag_file, 'w') as ostream: ostream.write(file_content='clean nn done!') else: self._origin_restore(sess, save_path) with graph._lock: if CUSTOM_RESTORE_OP in graph._collections: del graph._collections[CUSTOM_RESTORE_OP] else: self._origin_restore(sess, save_path) def setup_ps_initialized_state(self, sess): if self._ps_monitor: self._ps_monitor.setup_ps_initialized_state(sess) ================================================ FILE: monolith/native_training/save_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its 
affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from freezegun import freeze_time from unittest import mock import numpy as np import os import six import time import tensorflow as tf from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import saver_test_utils from tensorflow.python.training import checkpoint_management from monolith.native_training import save_utils class SaveUtilsTest(tf.test.TestCase): def setUp(self): super().setUp() self._global_step = tf.compat.v1.train.get_or_create_global_step() self._saver = save_utils.PartialRecoverySaver( exempt_checkpoint_paths=['ckpt-10', 'ckpt-20'], max_to_keep=3) self._savepath = os.path.join( os.environ["TEST_TMPDIR"], type(self).__name__ + "_" + self._testMethodName, "ckpt") def create_test_ckpt(self, global_step_value: int): with self.session() as sess: self._global_step = tf.compat.v1.assign(self._global_step, global_step_value) sess.run(self._global_step) self._saver.save(sess, self._savepath, self._global_step) def test_get_ckpt_steps(self): helper = save_utils.SaveHelper(self._savepath) self.create_test_ckpt(10) self.create_test_ckpt(20) self.create_test_ckpt(300) ckpt_steps = helper.get_existing_checkpoint_steps() self.assertSetEqual(ckpt_steps, {10, 20, 300}) def test_exempt_checkpoints(self): helper = save_utils.SaveHelper(self._savepath) self.create_test_ckpt(10) self.create_test_ckpt(20) self.create_test_ckpt(30) self.create_test_ckpt(40) self.create_test_ckpt(50) ckpt_steps = helper.get_existing_checkpoint_steps() self.assertSetEqual(ckpt_steps, {10, 20, 30, 40, 50}) self.create_test_ckpt(60) ckpt_steps = helper.get_existing_checkpoint_steps() self.assertSetEqual(ckpt_steps, {10, 20, 40, 50, 60}) class 
SaverHookTest(tf.test.TestCase): def get_ckpt_dir(self): return os.path.join(os.environ["TEST_TMPDIR"], type(self).__name__ + "_" + self._testMethodName) def test_basic(self): ckpt_dir = self.get_ckpt_dir() global_step = tf.compat.v1.train.get_or_create_global_step() global_step = tf.compat.v1.assign_add(global_step, 1) with tf.compat.v1.train.SingularMonitoredSession([ save_utils.NoFirstSaveCheckpointSaverHook(checkpoint_dir=ckpt_dir, save_steps=100) ]) as sess: sess.run(global_step) print(tf.io.gfile.glob(os.path.join(ckpt_dir, "*"))) # Will not save at beginning. self.assertAllEqual( tf.io.gfile.glob(os.path.join(ckpt_dir, "model.ckpt-0\\.*")), []) # Will save after when session is closed. self.assertGreater( len(tf.io.gfile.glob(os.path.join(ckpt_dir, "model.ckpt-1\\.*"))), 0) def test_op_error(self): class AssertFailListener(tf.estimator.CheckpointSaverListener): def __init__(self): self._assert_op = tf.debugging.Assert(False, [False]) def before_save(self, session, global_step_value): session.run(self._assert_op) class FinallyAfterSaveListener(tf.estimator.CheckpointSaverListener): def __init__(self): self._called = False @property def called(self): return self._called def after_save(self, session, global_step_value): self._called = True ckpt_dir = self.get_ckpt_dir() global_step = tf.compat.v1.train.get_or_create_global_step() l = FinallyAfterSaveListener() with tf.compat.v1.train.SingularMonitoredSession([ save_utils.NoFirstSaveCheckpointSaverHook( checkpoint_dir=ckpt_dir, save_steps=100, listeners=[AssertFailListener()], guard_saver_listeners=[l], ignore_save_errors=True, no_first_save=False) ]) as sess: pass self.assertTrue(l.called) def test_trigger_save(self): ckpt_dir = self.get_ckpt_dir() global_step = tf.compat.v1.train.get_or_create_global_step() global_step = tf.compat.v1.assign_add(global_step, 1) h = save_utils.NoFirstSaveCheckpointSaverHook(checkpoint_dir=ckpt_dir, save_steps=100) with tf.compat.v1.train.SingularMonitoredSession([h]) as sess: 
sess.run(global_step) h.trigger_save(sess.raw_session()) self.assertGreater( len(tf.io.gfile.glob(os.path.join(ckpt_dir, "model.ckpt-1\\.*"))), 0) class SaverTest(tf.test.TestCase): def basicSaveRestore(self, variable_op): save_path = os.path.join(self.get_temp_dir(), "basic_save_restore") with self.session(graph=tf.Graph()) as sess: # Build a graph with 2 parameter nodes, and Save and # Restore nodes for them. v0 = variable_op(10.0, name="v0") v1 = variable_op(20.0, name="v1") v2 = saver_test_utils.CheckpointedOp(name="v2") v2_init = v2.insert("k1", 30.0) # Initialize all variables if not tf.executing_eagerly(): self.evaluate([tf.compat.v1.global_variables_initializer(), v2_init]) # Check that the parameter nodes have been initialized. self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) self.assertEqual(b"k1", self.evaluate(v2.keys())) self.assertEqual(30.0, self.evaluate(v2.values())) # Save the initialized values in the file at "save_path" save = save_utils.PartialRecoverySaver( { "v0": v0, "v1": v1, "v2": v2.saveable }, restore_sequentially=True) val = save.save(sess, save_path) self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path, val) # Start a second session. In that session the parameter nodes # have not been initialized either. with self.session(graph=tf.Graph()) as sess: v0 = variable_op(-1.0, name="v0") v1 = variable_op(-1.0, name="v1") v2 = saver_test_utils.CheckpointedOp(name="v2") # Assert that the variables are not initialized. if not tf.executing_eagerly(): self.assertEqual( len(tf.compat.v1.report_uninitialized_variables().eval()), 2) self.assertEqual(0, len(self.evaluate(v2.keys()))) self.assertEqual(0, len(self.evaluate(v2.values()))) # Restore the saved values in the parameter nodes. save = save_utils.PartialRecoverySaver({ "v0": v0, "v1": v1, "v2": v2.saveable }) save.restore(sess, save_path) # Check that the parameter nodes have been restored. 
self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) self.assertEqual(b"k1", self.evaluate(v2.keys())) self.assertEqual(30.0, self.evaluate(v2.values())) # Build another graph with 2 nodes, initialized # differently, and a Restore node for them. with self.session(graph=tf.Graph()) as sess: v0_2 = variable_op(1000.0, name="v0") v1_2 = variable_op(2000.0, name="v1") v2_2 = saver_test_utils.CheckpointedOp(name="v2") v2_init = v2_2.insert("k1000", 3000.0) # Check that the parameter nodes have been initialized. if not tf.executing_eagerly(): init_all_op = [tf.compat.v1.global_variables_initializer(), v2_init] self.evaluate(init_all_op) # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty # table as it claims in eager mode? self.assertEqual(b"k1000", self.evaluate(v2_2.keys())) self.assertEqual(3000.0, self.evaluate(v2_2.values())) self.assertEqual(1000.0, self.evaluate(v0_2)) self.assertEqual(2000.0, self.evaluate(v1_2)) # Restore the values saved earlier in the parameter nodes. save2 = save_utils.PartialRecoverySaver({ "v0": v0_2, "v1": v1_2, "v2": v2_2.saveable }) save2.restore(sess, save_path) # Check that the parameter nodes have been restored. self.assertEqual(10.0, self.evaluate(v0_2)) self.assertEqual(20.0, self.evaluate(v1_2)) self.assertEqual(b"k1", self.evaluate(v2_2.keys())) self.assertEqual(30.0, self.evaluate(v2_2.values())) def testBasic(self): self.basicSaveRestore(tf.Variable) def testSaveMaxToKeep(self): save_path = os.path.join(self.get_temp_dir(), "test_save_max_to_keep", "model.ckpt") with self.session(graph=tf.Graph()) as sess: # Build a graph with 2 parameter nodes, and Save and # Restore nodes for them. 
# Continues SaverTest.testSaveMaxToKeep (body of its session block).
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
      v0 = tf.Variable(10.0, name="v0")
      v1 = tf.Variable(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)

      # Initialize all variables
      if not tf.executing_eagerly():
        self.evaluate([tf.compat.v1.global_variables_initializer(), v2_init])

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      save = save_utils.PartialRecoverySaver(
          {
              "v0": v0,
              "v1": v1,
              "v2": v2.saveable
          },
          restore_sequentially=True,
          max_to_keep=2,
          exempt_checkpoint_paths=['model.ckpt-2'])
      # Step 1: checkpoint 1 must exist.
      val = save.save(sess, save_path, global_step=1)
      self.assertEqual(save_path + '-1', val)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-1\\.*")), 0)

      # Step 2: checkpoints 1 and 2 must exist.
      val = save.save(sess, save_path, global_step=2)
      self.assertEqual(save_path + '-2', val)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-1\\.*")), 0)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-2\\.*")), 0)

      # Step 3: the test still expects checkpoints 1, 2 and 3 on disk even
      # though max_to_keep=2 (presumably because 'model.ckpt-2' is exempt —
      # confirm against PartialRecoverySaver).
      val = save.save(sess, save_path, global_step=3)
      self.assertEqual(save_path + '-3', val)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-1\\.*")), 0)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-2\\.*")), 0)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-3\\.*")), 0)

      # Step 4: checkpoint 1 must be gone while 2, 3 and 4 remain.
      val = save.save(sess, save_path, global_step=4)
      self.assertEqual(save_path + '-4', val)
      self.assertEqual(len(tf.io.gfile.glob(save_path + "-1\\.*")), 0)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-2\\.*")), 0)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-3\\.*")), 0)
      self.assertGreater(len(tf.io.gfile.glob(save_path + "-4\\.*")), 0)

      save.restore(sess, val)
      # Check that the parameter nodes have been restored.
# Continues SaverTest.testSaveMaxToKeep: values after restoring the last
# checkpoint.
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

  @test_util.run_in_graph_and_eager_modes
  def testResourceBasic(self):
    """Runs basicSaveRestore with ResourceVariable as the constructor."""
    self.basicSaveRestore(resource_variable_ops.ResourceVariable)

  def testResourceColocation(self):
    """Checks the device of SaveV2 inputs for a partitioned variable."""
    # train.Saver is V1 only API.
    with tf.Graph().as_default():
      partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=2)
      with tf.device("/job:ps/device:GPU:0"):
        v = tf.compat.v1.get_variable("v0",
                                      shape=[10, 2],
                                      partitioner=partitioner,
                                      use_resource=True)
      save_utils.PartialRecoverySaver({"v0": v}).build()
      # Locate the first SaveV2 op that the saver added to the graph.
      save_op = None
      for op in tf.compat.v1.get_default_graph().get_operations():
        if op.type == "SaveV2":
          save_op = op
          break
      assert save_op is not None
      # Inputs from index 3 onward are the ones checked for placement.
      for save_inp in save_op.inputs[3:]:
        # Input to SaveV2 op is placed on CPU of the same device as
        # the Variable.
        self.assertEqual("/job:ps/device:CPU:0", save_inp.device)

  def testResourceVariableReadOpsAddedDeterministically(self):
    """Builds the same graph ten times and expects identical GraphDefs."""
    graph_defs = []
    num_graphs = 10
    for _ in range(num_graphs):
      with tf.Graph().as_default() as g:
        for i in range(20):
          resource_variable_ops.ResourceVariable(i, name="var%s" % i)
        save_utils.PartialRecoverySaver()
        graph_defs.append(g.as_graph_def())
    # Compare each graph definition with its successor.
    for i in range(num_graphs - 1):
      self.assertEqual(graph_defs[i], graph_defs[i + 1])

  def testEagerBasic(self):
    """Saves two variables in eager mode, zeroes them, then restores."""
    with context.eager_mode():
      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

      v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")
      v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")
      save = save_utils.PartialRecoverySaver([v1, v2])
      # No session is passed in eager mode.
      save.save(None, ckpt_prefix)

      v1.assign(0.0)
      v2.assign([0, 0])
      self.assertNear(0.0, self.evaluate(v1), 1e-5)
      self.assertAllEqual([0, 0], self.evaluate(v2))

      save.restore(None, ckpt_prefix)
      self.assertNear(3.14, self.evaluate(v1), 1e-5)
      self.assertAllEqual([1, 2], self.evaluate(v2))

  def testEagerGraphCompatibility(self):
    """Round-trips checkpoints between graph mode and eager mode."""
    # Save from graph mode and
# Continues SaverTest.testEagerGraphCompatibility; the first line below is the
# tail of a comment that starts on the preceding physical line.
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
    # restore from eager mode.
    graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
    with context.graph_mode():
      with self.session(graph=tf.Graph()) as sess:
        # Create a graph model and save the checkpoint.
        w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
        w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
        graph_saver = save_utils.PartialRecoverySaver([w1, w2])
        self.evaluate(tf.compat.v1.global_variables_initializer())
        graph_saver.save(sess, graph_ckpt_prefix)

    with context.eager_mode():
      tf.compat.v1.reset_default_graph()

      w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")
      w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")

      graph_saver = save_utils.PartialRecoverySaver([w1, w2])
      graph_saver.restore(None, graph_ckpt_prefix)

      self.assertAllEqual(self.evaluate(w1), 1.0)
      self.assertAllEqual(self.evaluate(w2), 2.0)

    # Save from eager mode and restore from graph mode.
    eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")
    with context.eager_mode():
      tf.compat.v1.reset_default_graph()

      w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")
      w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")

      graph_saver = save_utils.PartialRecoverySaver([w3, w4])
      graph_saver.save(None, eager_ckpt_prefix)

    with context.graph_mode():
      with self.session(graph=tf.Graph()) as sess:
        w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
        w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
        graph_saver = save_utils.PartialRecoverySaver([w3, w4])
        self.evaluate(tf.compat.v1.global_variables_initializer())
        graph_saver.restore(sess, eager_ckpt_prefix)
        # The variables are passed directly here, not through self.evaluate.
        self.assertAllEqual(w3, 3.0)
        self.assertAllEqual(w4, 4.0)

  @test_util.run_in_graph_and_eager_modes
  def testResourceSaveRestoreCachingDevice(self):
    """Saves and restores a ResourceVariable created with a caching device."""
    save_path = os.path.join(self.get_temp_dir(), "resource_cache")
    with self.session(graph=tf.Graph()) as sess:
      v = resource_variable_ops.ResourceVariable([1],
                                                 caching_device="/cpu:0",
                                                 name="v")
      # Eager mode passes no session to the saver.
      if tf.executing_eagerly():
        sess = None
      else:
# Continues SaverTest.testResourceSaveRestoreCachingDevice; the first statement
# is the body of the `else:` that ends the preceding physical line.
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
        self.evaluate(tf.compat.v1.global_variables_initializer())
      save = save_utils.PartialRecoverySaver([v])
      save.save(sess, save_path)

      # A second saver over the same variable performs the restore.
      save2 = save_utils.PartialRecoverySaver([v])
      save2.restore(sess, save_path)
      self.assertEqual(self.evaluate(v), [1])

  def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self):
    """Checks that each saver only adds ops under its own `save/` scope."""
    with tf.Graph().as_default() as g:
      v = resource_variable_ops.ResourceVariable(1.0, name="v")
      with tf.name_scope("saver1"):
        save_utils.PartialRecoverySaver()
      with tf.name_scope("saver2"):
        save_utils.PartialRecoverySaver({"name": v})
    # Ops named "saver1/..." but not "saver1/save/..." would be stray ops.
    ops_in_saver1_scope_but_not_save_scope = [
        op for op in g.get_operations()
        if (op.name.startswith("saver1/") and
            not op.name.startswith("saver1/save/"))
    ]
    self.assertEqual(ops_in_saver1_scope_but_not_save_scope, [])
    ops_in_saver2_scope_but_not_save_scope = [
        op for op in g.get_operations()
        if (op.name.startswith("saver2/") and
            not op.name.startswith("saver2/save/"))
    ]
    self.assertEqual(ops_in_saver2_scope_but_not_save_scope, [])

  def testSaveCopyRestoreWithSaveRelativePaths(self):
    """Save, copy checkpoint dir and restore from copied dir.

    This only works for save_relative_paths=True.
    """
    save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")
    os.mkdir(save_dir1)
    save_path1 = os.path.join(save_dir1, "save_copy_restore")

    # train.Saver is V1 only API.
    with tf.Graph().as_default():
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = tf.compat.v1.Variable(10.0, name="v0")
      v1 = tf.compat.v1.Variable(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)
      save = save_utils.PartialRecoverySaver(var_list={
          "v0": v0,
          "v1": v1,
          "v2": v2.saveable
      },
                                             restore_sequentially=True,
                                             save_relative_paths=True)
      init_all_op = [tf.compat.v1.global_variables_initializer(), v2_init]

      with self.cached_session() as sess:
        # Initialize all variables
        self.evaluate(init_all_op)

        # Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) self.assertEqual(b"k1", self.evaluate(v2.keys())) self.assertEqual(30.0, self.evaluate(v2.values())) # Save the initialized values in the file at "save_path" val = save.save(sess, save_path1) self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path1, val) self.assertEqual(tf.train.latest_checkpoint(save_dir1), save_path1) save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2") os.renames(save_dir1, save_dir2) save_path2 = os.path.join(save_dir2, "save_copy_restore") self.assertEqual(tf.train.latest_checkpoint(save_dir2), save_path2) # Start a second session. In that session the parameter nodes # have not been initialized either. with self.cached_session() as sess: v0 = tf.compat.v1.Variable(-1.0, name="v0") v1 = tf.compat.v1.Variable(-1.0, name="v1") v2 = saver_test_utils.CheckpointedOp(name="v2") save = save_utils.PartialRecoverySaver({ "v0": v0, "v1": v1, "v2": v2.saveable }) # Assert that the variables are not initialized. self.assertEqual( len(tf.compat.v1.report_uninitialized_variables().eval()), 2) self.assertEqual(0, len(self.evaluate(v2.keys()))) self.assertEqual(0, len(self.evaluate(v2.values()))) # Restore the saved values in the parameter nodes. save.restore(sess, save_path2) # Check that the parameter nodes have been restored. self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) self.assertEqual(b"k1", self.evaluate(v2.keys())) self.assertEqual(30.0, self.evaluate(v2.values())) def testFilenameTensor(self): # train.Saver is V1 only API. 
with tf.Graph().as_default(): v0 = tf.compat.v1.Variable(0, name="v0") filename = b"somerandomfilename" save = save_utils.PartialRecoverySaver({"v0": v0}, filename=filename) with self.cached_session() as sess: tensor = sess.graph.get_tensor_by_name( save.saver_def.filename_tensor_name) self.assertEqual(self.evaluate(tensor), filename) def testInvalidPath(self): v0 = tf.compat.v1.Variable(0, name="v0") with self.cached_session() as sess: save = save_utils.PartialRecoverySaver({"v0": v0}) with self.assertRaisesRegex( ValueError, "The passed save_path is not a valid checkpoint:"): save.restore(sess, "invalid path") @test_util.run_v1_only("train.Saver is V1 only API.") def testInt64(self): save_path = os.path.join(self.get_temp_dir(), "int64") with self.cached_session() as sess: # Build a graph with 1 node, and save and restore for them. v = tf.compat.v1.Variable(np.int64(15), name="v") save = save_utils.PartialRecoverySaver({"v": v}, restore_sequentially=True) self.evaluate(tf.compat.v1.global_variables_initializer()) # Save the initialized values in the file at "save_path" val = save.save(sess, save_path) self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path, val) with self.cached_session() as sess: v = tf.compat.v1.Variable(np.int64(-1), name="v") save = save_utils.PartialRecoverySaver({"v": v}) with self.assertRaisesWithPredicateMatch( tf.errors.OpError, lambda e: "uninitialized value v" in e.message): self.evaluate(v) # Restore the saved values in the parameter nodes. save.restore(sess, save_path) # Check that the parameter nodes have been restored. self.assertEqual(np.int64(15), self.evaluate(v)) def testSomeErrors(self): with tf.Graph().as_default(): v0 = tf.compat.v1.Variable([10.0], name="v0") v1 = tf.compat.v1.Variable([20.0], name="v1") v2 = tf.compat.v1.Variable([20.0], name="v2") v2._set_save_slice_info(tf.Variable.SaveSliceInfo("v1", [1], [0], [1])) # By default the name used for "v2" will be "v1" and raise an error. 
with self.assertRaisesRegex(ValueError, "same name: v1"): save_utils.PartialRecoverySaver([v0, v1, v2]) # The names are different and will work. save_utils.PartialRecoverySaver({"vee1": v1, "other": [v2]}) # Partitioned variables also cause name conflicts. p_v1 = tf.compat.v1.get_variable( "p_v1", shape=[4, 5], partitioner=tf.compat.v1.fixed_size_partitioner(num_shards=2)) p_v2 = tf.compat.v1.get_variable( "p_v2", shape=[4, 5], partitioner=tf.compat.v1.fixed_size_partitioner(num_shards=2)) p_v2._name = "p_v1" with self.assertRaisesRegex(ValueError, "same name: p_v1"): save_utils.PartialRecoverySaver([p_v1, p_v2]) def testSameName(self): with tf.Graph().as_default(): v0 = tf.compat.v1.Variable([10.0], name="v0") v2 = saver_test_utils.CheckpointedOp(name="v2") # Saving one variable under two names raises an error. with self.assertRaisesRegex( ValueError, "The same saveable will be restored with two names: v0"): save_utils.PartialRecoverySaver({"v0": v0, "v0too": v0}) # Ditto for custom saveables. with self.assertRaisesRegex( ValueError, "The same saveable will be restored with two names: v2"): save_utils.PartialRecoverySaver({ "v2": v2.saveable, "v2too": v2.saveable }) # Verify non-duplicate names work. save_utils.PartialRecoverySaver({"v0": v0, "v2": v2.saveable}) @test_util.run_v1_only("train.Saver and VariableV1 are V1 only APIs.") def testBasicsWithListOfVariables(self): save_path = os.path.join(self.get_temp_dir(), "basics_with_list") with self.session(graph=tf.Graph()) as sess: # Build a graph with 2 parameter nodes, and Save and # Restore nodes for them. v0 = tf.compat.v1.Variable(10.0, name="v0") v1 = tf.compat.v1.Variable(20.0, name="v1") v2 = saver_test_utils.CheckpointedOp(name="v2") v2_init = v2.insert("k1", 30.0) save = save_utils.PartialRecoverySaver([v0, v1, v2.saveable]) self.evaluate(tf.compat.v1.global_variables_initializer()) v2_init.run() # Check that the parameter nodes have been initialized. 
self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) self.assertEqual(b"k1", self.evaluate(v2.keys())) self.assertEqual(30.0, self.evaluate(v2.values())) # Save the initialized values in the file at "save_path" val = save.save(sess, save_path) self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path, val) # Start a second session. In that session the variables # have not been initialized either. with self.session(graph=tf.Graph()) as sess: v0 = tf.compat.v1.Variable(-1.0, name="v0") v1 = tf.compat.v1.Variable(-1.0, name="v1") v2 = saver_test_utils.CheckpointedOp(name="v2") save = save_utils.PartialRecoverySaver([v0, v1, v2.saveable]) with self.assertRaisesWithPredicateMatch( tf.errors.OpError, lambda e: "uninitialized value v0" in e.message): self.evaluate(v0) with self.assertRaisesWithPredicateMatch( tf.errors.OpError, lambda e: "uninitialized value v1" in e.message): self.evaluate(v1) self.assertEqual(0, len(self.evaluate(v2.keys()))) self.assertEqual(0, len(self.evaluate(v2.values()))) # Restore the saved values in the parameter nodes. save.restore(sess, save_path) # Check that the parameter nodes have been restored. self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) self.assertEqual(b"k1", self.evaluate(v2.keys())) self.assertEqual(30.0, self.evaluate(v2.values())) # Build another graph with 2 nodes, initialized # differently, and a Restore node for them. with self.session(graph=tf.Graph()) as sess: v0_2 = tf.compat.v1.Variable(1000.0, name="v0") v1_2 = tf.compat.v1.Variable(2000.0, name="v1") v2_2 = saver_test_utils.CheckpointedOp(name="v2") save2 = save_utils.PartialRecoverySaver([v0_2, v1_2, v2_2.saveable]) v2_2.insert("k1000", 3000.0).run() self.evaluate(tf.compat.v1.global_variables_initializer()) # Check that the parameter nodes have been initialized. 
# Continues SaverTest.testBasicsWithListOfVariables (inside its third session).
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
      self.assertEqual(1000.0, self.evaluate(v0_2))
      self.assertEqual(2000.0, self.evaluate(v1_2))
      self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
      self.assertEqual(3000.0, self.evaluate(v2_2.values()))
      # Restore the values saved earlier in the parameter nodes.
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0_2))
      self.assertEqual(20.0, self.evaluate(v1_2))
      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
      self.assertEqual(30.0, self.evaluate(v2_2.values()))

  def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
    """Saves one variable, then restores it into a fresh graph.

    Args:
      var_name: Name of the variable and its key in the saver's var dict.
      var_value: Value that is saved and expected back after the restore.
      other_value: Initial value of the variable built in the second graph.
      save_path: Checkpoint path used for both save and restore.
    """
    with self.session(graph=tf.Graph()) as sess:
      var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
      save = save_utils.PartialRecoverySaver({var_name: var})
      if not tf.executing_eagerly():
        self.evaluate(var.initializer)
      val = save.save(sess, save_path)
      self.assertEqual(save_path, val)
    with self.session(graph=tf.Graph()) as sess:
      var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
      save = save_utils.PartialRecoverySaver({var_name: var})
      save.restore(sess, save_path)
      self.assertAllClose(var_value, self.evaluate(var))

  def testCacheRereadsFile(self):
    """Saves two different variables to the same path in turn."""
    save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
    # Save and reload one Variable named "var0".
    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
    # Save and reload one Variable named "var1" in the same file.
    # The cached readers should know to re-read the file.
    self._SaveAndLoad("var1", 1.1, 2.2, save_path)

  def testAllowEmpty(self):
    """With allow_empty=True and no variables, save() returns None."""
    save_path = os.path.join(self.get_temp_dir(), "allow_empty")
    # train.Saver is V1 only API.
# Continues SaverTest.testAllowEmpty (its header is on the preceding physical
# line).
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
    with tf.Graph().as_default(), self.cached_session() as sess:
      # The graph holds only a constant, so there is nothing to save.
      _ = tf.constant(1)
      save = save_utils.PartialRecoverySaver(allow_empty=True)
      val = save.save(sess, save_path)
      self.assertIsNone(val)
    with tf.Graph().as_default(), self.cached_session() as sess:
      save = save_utils.PartialRecoverySaver(allow_empty=True)
      save.restore(sess, save_path)

  def testGPU(self):
    """Saves a GPU-placed variable; returns early when no GPU is available.

    The second session only builds a saver and runs the initializer: no
    restore is issued and no value is asserted.
    """
    if not tf.test.is_gpu_available():
      return
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      with sess.graph.device(tf.test.gpu_device_name()):
        v0_1 = tf.compat.v1.Variable(123.45)
      save = save_utils.PartialRecoverySaver({"v0": v0_1})
      self.evaluate(tf.compat.v1.global_variables_initializer())
      save.save(sess, save_path)

    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      with sess.graph.device(tf.test.gpu_device_name()):
        v0_2 = tf.compat.v1.Variable(543.21)
      save = save_utils.PartialRecoverySaver({"v0": v0_2})
      self.evaluate(tf.compat.v1.global_variables_initializer())

  def testSharedServerOnGPU(self):
    """Same flow as testGPU with sharded=True and allow_empty=True savers.

    As in testGPU, the second session never calls restore.
    """
    if not tf.test.is_gpu_available():
      return
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      with sess.graph.device(tf.test.gpu_device_name()):
        v0_1 = tf.compat.v1.Variable(123.45)
      save = save_utils.PartialRecoverySaver({"v0": v0_1},
                                             sharded=True,
                                             allow_empty=True)
      self.evaluate(tf.compat.v1.global_variables_initializer())
      save.save(sess, save_path)

    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      with sess.graph.device(tf.test.gpu_device_name()):
        v0_2 = tf.compat.v1.Variable(543.21)
      save = save_utils.PartialRecoverySaver({"v0": v0_2},
                                             sharded=True,
                                             allow_empty=True)
      self.evaluate(tf.compat.v1.global_variables_initializer())

  def testVariables(self):
    """A saver built with no arguments saves and restores all variables."""
    save_path = os.path.join(self.get_temp_dir(), "variables")
    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      one = tf.compat.v1.Variable(1.0)
      twos = tf.compat.v1.Variable([2.0, 2.0, 2.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
# Continues SaverTest.testVariables (inside its first session).
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
      init = tf.compat.v1.global_variables_initializer()
      save = save_utils.PartialRecoverySaver()
      init.run()
      v2.insert("k1", 3.0).run()
      save.save(sess, save_path)

    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      one = tf.compat.v1.Variable(0.0)
      twos = tf.compat.v1.Variable([0.0, 0.0, 0.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      # Saver with no arg, defaults to 'all variables'.
      save = save_utils.PartialRecoverySaver()
      save.restore(sess, save_path)
      self.assertAllClose(1.0, self.evaluate(one))
      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(3.0, self.evaluate(v2.values()))

  def testVarListShouldBeEmptyInDeferredBuild(self):
    """Passing a var list together with defer_build=True raises ValueError."""
    with tf.Graph().as_default():
      v = tf.compat.v1.Variable(1.0)
      with self.assertRaisesRegex(ValueError, "defer_build"):
        save_utils.PartialRecoverySaver([v], defer_build=True)

  def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
    """save() on a deferred saver that was never built raises RuntimeError."""
    save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.Variable(1.0)
      saver = save_utils.PartialRecoverySaver(defer_build=True)
      with self.assertRaisesRegex(RuntimeError, "build"):
        saver.save(sess, save_path)

  def testDeferredBuild(self):
    """A deferred saver also covers variables created before build()."""
    save_path = os.path.join(self.get_temp_dir(), "deferred_build")
    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      one = tf.compat.v1.Variable(1.0)
      save = save_utils.PartialRecoverySaver(defer_build=True)
      # if build is not deferred, saver cannot save the `twos`.
      twos = tf.compat.v1.Variable([2.0, 2.0, 2.0])
      init = tf.compat.v1.global_variables_initializer()
      save.build()
      init.run()
      save.save(sess, save_path)

    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      one = tf.compat.v1.Variable(0.0)
      twos = tf.compat.v1.Variable([0.0, 0.0, 0.0])
      # Saver with no arg, defaults to 'all variables'.
# Continues SaverTest.testDeferredBuild (inside its second session).
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
      save = save_utils.PartialRecoverySaver()
      save.restore(sess, save_path)
      self.assertAllClose(1.0, self.evaluate(one))
      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))

  @test_util.run_v1_only("train.Saver is V1 only API.")
  def testReshape(self):
    """Restores a 2x3 checkpoint into a 3x2 variable, without and with reshape."""
    save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      var = tf.compat.v1.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
      init = tf.compat.v1.global_variables_initializer()
      save = save_utils.PartialRecoverySaver()
      init.run()
      save.save(sess, save_path)

    # Error when restoring with default reshape=False
    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      var = tf.compat.v1.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = save_utils.PartialRecoverySaver()
      with self.assertRaisesRegex(
          tf.errors.InvalidArgumentError,
          "Assign requires shapes of both tensors to match."):
        save.restore(sess, save_path)

    # Restored to new shape with reshape=True
    with tf.compat.v1.Session("", graph=tf.Graph()) as sess:
      var = tf.compat.v1.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = save_utils.PartialRecoverySaver(reshape=True)
      save.restore(sess, save_path)
      self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                          self.evaluate(var))

  @test_util.run_in_graph_and_eager_modes
  def testSaveWithGlobalStep(self, pad_step_number=False):
    """Checks the checkpoint path returned when a global step is supplied.

    Args:
      pad_step_number: Forwarded to the saver; when true the expected step
        suffix is zero-padded to eight digits.
    """
    save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
    global_step_int = 5
    # Save and reload one Variable named "var0".
# Continues SaverTest.testSaveWithGlobalStep (its header is on the preceding
# physical line).
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
    # Exercise both a tensor and a plain int as the global step.
    for use_tensor in [True, False]:
      with self.session(graph=tf.Graph()):
        var = resource_variable_ops.ResourceVariable(1.0, name="var0")
        save = save_utils.PartialRecoverySaver({var._shared_name: var},
                                               pad_step_number=pad_step_number)
        if tf.executing_eagerly():
          sess = None
        else:
          self.evaluate(var.initializer)
          sess = tf.compat.v1.get_default_session()
        if use_tensor:
          global_step = tf.constant(global_step_int)
          val = save.save(sess, save_path, global_step=global_step)
        else:
          val = save.save(sess, save_path, global_step=global_step_int)
        if pad_step_number:
          expected_save_path = "%s-%s" % (save_path,
                                          "{:08d}".format(global_step_int))
        else:
          expected_save_path = "%s-%d" % (save_path, global_step_int)
        self.assertEqual(expected_save_path, val)

  def testSaveWithGlobalStepWithPadding(self):
    """Re-runs testSaveWithGlobalStep with pad_step_number=True."""
    self.testSaveWithGlobalStep(pad_step_number=True)

  def testSaveToNonexistingPath(self):
    """Saving below a missing parent either works or raises ValueError."""
    # Create a regular file so the third path below has a file as its parent.
    file_io.write_string_to_file(
        os.path.join(self.get_temp_dir(), "actually_a_file"), "")
    paths = [
        os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),
        os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),
        os.path.join(self.get_temp_dir(), "actually_a_file/path"),
    ]

    for save_path in paths:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = tf.compat.v1.Variable(10.0, name="v0")
      v1 = tf.compat.v1.Variable(20.0, name="v1")
      save = save_utils.PartialRecoverySaver({
          "v0": v0,
          "v1": v1
      },
                                             restore_sequentially=True)
      init_all_op = tf.compat.v1.global_variables_initializer()

      # In the case where the parent directory doesn't exist, whether or not the
      # save succeeds or fails is implementation dependent. Therefore we allow
      # both cases.
      try:
        with self.cached_session() as sess:
          # Initialize all variables
          self.evaluate(init_all_op)

          # Check that the parameter nodes have been initialized.
          self.assertEqual(10.0, self.evaluate(v0))
          self.assertEqual(20.0, self.evaluate(v1))

          # Save the graph.
# Continues SaverTest.testSaveToNonexistingPath (inside the `try` of its loop).
# NOTE(review): indentation rebuilt from whitespace-collapsed text.
          save.save(sess, save_path)

        with self.cached_session() as sess:
          # Restore the saved values in the parameter nodes.
          save.restore(sess, save_path)
          # Check that the parameter nodes have been restored.
          self.assertEqual(10.0, self.evaluate(v0))
          self.assertEqual(20.0, self.evaluate(v1))
      except ValueError as exc:
        # When the save is rejected, the message must name the save path.
        error_msg_template = "Parent directory of {} doesn't exist, can't save."
        self.assertEqual(error_msg_template.format(save_path), str(exc))

  def testSaveToURI(self):
    """Saves to a `file://` URI; skipped on Windows."""
    # ParseURI functions don't work on Windows yet.
    # TODO(jhseu): Remove this check when it works.
    if os.name == "nt":
      self.skipTest("Local URI support doesn't work on Windows")
    save_path = "file://" + os.path.join(self.get_temp_dir(), "uri")

    # Build a graph with 2 parameter nodes, and Save and
    # Restore nodes for them.
    v0 = tf.compat.v1.Variable(10.0, name="v0")
    v1 = tf.compat.v1.Variable(20.0, name="v1")
    save = save_utils.PartialRecoverySaver({
        "v0": v0,
        "v1": v1
    },
                                           restore_sequentially=True)
    init_all_op = tf.compat.v1.global_variables_initializer()

    with self.cached_session() as sess:
      # Initialize all variables
      self.evaluate(init_all_op)

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      save.save(sess, save_path)

  def testSaveRestoreAndValidateVariableDtype(self):
    """Restoring a float32 checkpoint into an int32 variable must fail."""
    for variable_op in [tf.Variable, resource_variable_ops.ResourceVariable]:
      save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")

      # Build the first session.
      with self.session(graph=tf.Graph()) as sess:
        v0 = variable_op(10.0, name="v0", dtype=tf.dtypes.float32)

        if not tf.executing_eagerly():
          self.evaluate([tf.compat.v1.global_variables_initializer()])

        save = save_utils.PartialRecoverySaver({"v0": v0})
        save.save(sess, save_path)

      # Start a second session.
      with self.session(graph=tf.Graph()) as sess:
        v0_wrong_dtype = variable_op(1, name="v0", dtype=tf.dtypes.int32)
        # Restore the saved value with different dtype
        # in the parameter nodes.
save = save_utils.PartialRecoverySaver({"v0": v0_wrong_dtype}) with self.assertRaisesRegex(tf.errors.InvalidArgumentError, "original dtype"): save.restore(sess, save_path) # Test restoring large tensors (triggers a thread pool) def testRestoreLargeTensors(self): save_dir = self.get_temp_dir() def _model(): small_v = [ tf.compat.v1.get_variable("small%d" % i, shape=[10, 2], use_resource=True) for i in range(5) ] large_v = [ tf.compat.v1.get_variable("large%d" % i, shape=[32000, 1000], use_resource=True) for i in range(3) ] return small_v + large_v save_graph = tf.Graph() with save_graph.as_default(), self.session(graph=save_graph) as sess: orig_vars = _model() self.evaluate(tf.compat.v1.global_variables_initializer()) save = save_utils.PartialRecoverySaver(max_to_keep=1) self.evaluate(tf.compat.v1.global_variables_initializer()) save.save(sess, save_dir) orig_vals = self.evaluate(orig_vars) restore_graph = tf.Graph() with restore_graph.as_default(), self.session(graph=restore_graph) as sess: restored_vars = _model() save = save_utils.PartialRecoverySaver(max_to_keep=1) save.restore(sess, save_dir) restored_vals = self.evaluate(restored_vars) for orig, restored in zip(orig_vals, restored_vals): self.assertAllEqual(orig, restored) class SaveRestoreShardedTest(tf.test.TestCase): def testIterators(self): save_path = os.path.join(self.get_temp_dir(), "sharded_iterators") # Build a graph with 2 parameter nodes on different devices and save. 
# Continues SaveRestoreShardedTest.testIterators (its header is on the
# preceding physical line).
# NOTE(review): indentation rebuilt from whitespace-collapsed text; the
# `_IteratorSaveable` constructions are assumed to sit outside the
# `with sess.graph.device(...)` blocks — confirm.
    with tf.compat.v1.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = tf.data.Dataset.range(10)
        it0 = tf.compat.v1.data.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(it0._iterator_resource,
                                                 name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = tf.data.Dataset.range(20)
        it1 = tf.compat.v1.data.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(it1._iterator_resource,
                                                 name="saveable_it1")
      saver = save_utils.PartialRecoverySaver(
          {
              "it0": saveable0,
              "it1": saveable1
          }, sharded=True)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      # Advance it0 twice and it1 once before saving.
      self.assertEqual(0, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next0))
      self.assertEqual(0, self.evaluate(get_next1))
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
      # Two data files are expected from the sharded save.
      data_files = tf.io.gfile.glob(save_path + ".data*")
      self.assertEqual(2, len(data_files))

    # Restore
    with tf.compat.v1.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = tf.data.Dataset.range(10)
        it0 = tf.compat.v1.data.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(it0._iterator_resource,
                                                 name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = tf.data.Dataset.range(20)
        it1 = tf.compat.v1.data.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(it1._iterator_resource,
                                                 name="saveable_it1")
      saver = save_utils.PartialRecoverySaver(
          {
              "it0": saveable0,
              "it1": saveable1
          }, sharded=True)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      saver.restore(sess, save_path)
      # Each iterator resumes right after the elements consumed before saving.
      self.assertEqual(2, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next1))

  def testIteratorsUnshardedRestore(self):
    """Saves two iterators with sharded=True, restores with sharded=False."""
    save_path = os.path.join(self.get_temp_dir(),
                             "restore_unsharded_iterators")

    # Build a graph with 2 parameter nodes on different devices and save.
    with tf.compat.v1.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = tf.data.Dataset.range(10)
        it0 = tf.compat.v1.data.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(it0._iterator_resource,
                                                 name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = tf.data.Dataset.range(20)
        it1 = tf.compat.v1.data.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(it1._iterator_resource,
                                                 name="saveable_it1")
      saver = save_utils.PartialRecoverySaver(
          {
              "it0": saveable0,
              "it1": saveable1
          }, sharded=True)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      # Advance it0 twice and it1 once before saving.
      self.assertEqual(0, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next0))
      self.assertEqual(0, self.evaluate(get_next1))
      val = saver.save(sess, save_path)
      self.assertEqual(save_path, val)
      data_files = tf.io.gfile.glob(save_path + ".data*")
      self.assertEqual(2, len(data_files))

    # Restore
    with tf.compat.v1.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        ds0 = tf.data.Dataset.range(10)
        it0 = tf.compat.v1.data.make_initializable_iterator(ds0)
        get_next0 = it0.get_next()
      saveable0 = iterator_ops._IteratorSaveable(it0._iterator_resource,
                                                 name="saveable_it0")

      with sess.graph.device("/cpu:1"):
        ds1 = tf.data.Dataset.range(20)
        it1 = tf.compat.v1.data.make_initializable_iterator(ds1)
        get_next1 = it1.get_next()
      saveable1 = iterator_ops._IteratorSaveable(it1._iterator_resource,
                                                 name="saveable_it1")
      # The restoring saver is unsharded although the checkpoint was sharded.
      saver = save_utils.PartialRecoverySaver(
          {
              "it0": saveable0,
              "it1": saveable1
          }, sharded=False)
      self.evaluate(it0.initializer)
      self.evaluate(it1.initializer)
      saver.restore(sess, save_path)
      self.assertEqual(2, self.evaluate(get_next0))
      self.assertEqual(1, self.evaluate(get_next1))


class MaxToKeepTest(tf.test.TestCase):
  """Tests of checkpoint retention with the saver's max_to_keep option."""

  def _get_test_dir(self, dirname):
    """Creates `dirname` under the test temp dir and returns its path."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    tf.io.gfile.makedirs(test_dir)
    return test_dir

  def assertCheckpointState(self, model_checkpoint_path,
                            all_model_checkpoint_paths, save_dir):
    """Asserts the checkpoint state read from `save_dir`.

    Args:
      model_checkpoint_path: Expected `model_checkpoint_path` of the state.
      all_model_checkpoint_paths: Expected `all_model_checkpoint_paths`.
      save_dir: Directory passed to `tf.train.get_checkpoint_state`.
    """
    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
    self.assertEqual(checkpoint_state.model_checkpoint_path,
                     model_checkpoint_path)
    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
                     all_model_checkpoint_paths)

  def testMaxToKeepEager(self):
    """Exercises max_to_keep=2 retention in eager mode."""
    with context.eager_mode():
      save_dir = self._get_test_dir("max_to_keep_eager")

      v = tf.compat.v1.Variable(10.0, name="v")
      save = save_utils.PartialRecoverySaver({"v": v}, max_to_keep=2)
      self.evaluate(tf.compat.v1.global_variables_initializer())
      # NOTE(review): this guard sits inside context.eager_mode(), so the
      # check below looks unreachable — confirm.
      if not tf.executing_eagerly():
        self.assertEqual([], save.last_checkpoints)

      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1))
      self.assertCheckpointState(model_checkpoint_path=s1,
                                 all_model_checkpoint_paths=[s1],
                                 save_dir=save_dir)

      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1))
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2))
      self.assertCheckpointState(model_checkpoint_path=s2,
                                 all_model_checkpoint_paths=[s1, s2],
                                 save_dir=save_dir)

      # The third save must evict s1, the oldest of the two kept checkpoints.
      s3 = save.save(None, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(tf.compat.v1.train.checkpoint_exists(s1))
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2))
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s3))
      self.assertCheckpointState(model_checkpoint_path=s3,
                                 all_model_checkpoint_paths=[s2, s3],
                                 save_dir=save_dir)

      # Create a second helper, identical to the first.
save2 = save_utils.PartialRecoverySaver({"v": v}, max_to_keep=2) save2.set_last_checkpoints_with_time([ (s, np.inf) for s in save.last_checkpoints ]) # Exercise the first helper. # Adding s2 again (old s2 is removed first, then new s2 appended) s2 = save.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save.last_checkpoints) self.assertFalse(tf.compat.v1.train.checkpoint_exists(s1)) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s3)) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertCheckpointState(model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], save_dir=save_dir) # Adding s1 (s3 should now be deleted as oldest in list) s1 = save.save(None, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save.last_checkpoints) self.assertFalse(tf.compat.v1.train.checkpoint_exists(s3)) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertCheckpointState(model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], save_dir=save_dir) s2 = save2.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save2.last_checkpoints) # Created by the first helper. self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1)) # Deleted by the first helper. self.assertFalse(tf.compat.v1.train.checkpoint_exists(s3)) def testNonSharded(self): save_dir = self._get_test_dir("max_to_keep_non_sharded") # train.Saver is V1 only API. 
with tf.Graph().as_default(), self.cached_session() as sess: v = tf.compat.v1.Variable(10.0, name="v") save = save_utils.PartialRecoverySaver({"v": v}, max_to_keep=2) self.evaluate(tf.compat.v1.global_variables_initializer()) self.assertEqual([], save.last_checkpoints) s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s1], save.last_checkpoints) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1)) self.assertCheckpointState(model_checkpoint_path=s1, all_model_checkpoint_paths=[s1], save_dir=save_dir) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1)) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertCheckpointState(model_checkpoint_path=s2, all_model_checkpoint_paths=[s1, s2], save_dir=save_dir) s3 = save.save(sess, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) self.assertFalse(tf.compat.v1.train.checkpoint_exists(s1)) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s3)) self.assertCheckpointState(model_checkpoint_path=s3, all_model_checkpoint_paths=[s2, s3], save_dir=save_dir) # Create a second helper, identical to the first. save2 = save_utils.PartialRecoverySaver(saver_def=save.as_saver_def()) save2.set_last_checkpoints_with_time([ (s, np.inf) for s in save.last_checkpoints ]) # Create a third helper, with the same configuration but no knowledge of # previous checkpoints. save3 = save_utils.PartialRecoverySaver(saver_def=save.as_saver_def()) # Exercise the first helper. 
# Adding s2 again (old s2 is removed first, then new s2 appended) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save.last_checkpoints) self.assertFalse(tf.compat.v1.train.checkpoint_exists(s1)) self.assertFalse( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s1))) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s3)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s3))) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s2))) self.assertCheckpointState(model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], save_dir=save_dir) # Adding s1 (s3 should now be deleted as oldest in list) s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save.last_checkpoints) self.assertFalse(tf.compat.v1.train.checkpoint_exists(s3)) self.assertFalse( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s3))) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s2))) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s1))) self.assertCheckpointState(model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], save_dir=save_dir) # Exercise the second helper. # Adding s2 again (old s2 is removed first, then new s2 appended) s2 = save2.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save2.last_checkpoints) # Created by the first helper. self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s1))) # Deleted by the first helper. 
self.assertFalse(tf.compat.v1.train.checkpoint_exists(s3)) self.assertFalse( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s3))) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s2))) self.assertCheckpointState(model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], save_dir=save_dir) # Adding s1 (s3 should now be deleted as oldest in list) s1 = save2.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save2.last_checkpoints) self.assertFalse(tf.compat.v1.train.checkpoint_exists(s3)) self.assertFalse( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s3))) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s2)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s2))) self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s1))) self.assertCheckpointState(model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], save_dir=save_dir) # Exercise the third helper. # Adding s2 again (but helper is unaware of previous s2) s2 = save3.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s2], save3.last_checkpoints) # Created by the first helper. self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1)) self.assertTrue( tf.compat.v1.train.checkpoint_exists( checkpoint_management.meta_graph_filename(s1))) # Deleted by the first helper. 
def testSharded(self):
  """Checks max_to_keep=2 rotation for a sharded saver over two CPU devices.

  One variable is placed on each of "/cpu:0" and "/cpu:1" and the saver is
  built with `sharded=True`. For every checkpoint prefix the test counts the
  files matching `prefix + "*"` and checks for the meta graph file.
  """
  save_dir = self._get_test_dir("max_to_keep_sharded")

  # The session is configured with two CPU devices so that each variable can
  # live on its own device.
  with tf.compat.v1.Session(
      target="",
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    with sess.graph.device("/cpu:0"):
      v0 = tf.compat.v1.Variable(111, name="v0")
    with sess.graph.device("/cpu:1"):
      v1 = tf.compat.v1.Variable(222, name="v1")
    save = save_utils.PartialRecoverySaver({
        "v0": v0,
        "v1": v1
    },
                                           sharded=True,
                                           max_to_keep=2)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertEqual([], save.last_checkpoints)

    # A live checkpoint is expected to consist of exactly 4 files matching the
    # prefix glob -- presumably two data shards, the index and the meta graph
    # (TODO confirm against the checkpoint format in use).
    s1 = save.save(sess, os.path.join(save_dir, "s1"))
    self.assertEqual([s1], save.last_checkpoints)
    self.assertEqual(4, len(tf.io.gfile.glob(s1 + "*")))

    self.assertTrue(
        tf.io.gfile.exists(checkpoint_management.meta_graph_filename(s1)))

    # Second save: both checkpoints fit in the budget.
    s2 = save.save(sess, os.path.join(save_dir, "s2"))
    self.assertEqual([s1, s2], save.last_checkpoints)
    self.assertEqual(4, len(tf.io.gfile.glob(s1 + "*")))
    self.assertTrue(
        tf.io.gfile.exists(checkpoint_management.meta_graph_filename(s1)))
    self.assertEqual(4, len(tf.io.gfile.glob(s2 + "*")))
    self.assertTrue(
        tf.io.gfile.exists(checkpoint_management.meta_graph_filename(s2)))

    # Third save: every file belonging to s1 (including its meta graph) must
    # have been removed, while s2 and s3 stay complete.
    s3 = save.save(sess, os.path.join(save_dir, "s3"))
    self.assertEqual([s2, s3], save.last_checkpoints)
    self.assertEqual(0, len(tf.io.gfile.glob(s1 + "*")))
    self.assertFalse(
        tf.io.gfile.exists(checkpoint_management.meta_graph_filename(s1)))
    self.assertEqual(4, len(tf.io.gfile.glob(s2 + "*")))
    self.assertTrue(
        tf.io.gfile.exists(checkpoint_management.meta_graph_filename(s2)))
    self.assertEqual(4, len(tf.io.gfile.glob(s3 + "*")))
    self.assertTrue(
        tf.io.gfile.exists(checkpoint_management.meta_graph_filename(s3)))
def testNoMetaGraph(self):
  """Saving with write_meta_graph=False writes a checkpoint but no meta graph."""
  ckpt_dir = self._get_test_dir("no_meta_graph")

  with self.cached_session() as session:
    variable = tf.compat.v1.Variable(10.0, name="v")
    saver = save_utils.PartialRecoverySaver({"v": variable})
    init_op = tf.compat.v1.global_variables_initializer()
    self.evaluate(init_op)

    requested_prefix = os.path.join(ckpt_dir, "s1")
    saved_prefix = saver.save(session,
                              requested_prefix,
                              write_meta_graph=False)

    # The checkpoint itself must be present ...
    checkpoint_present = tf.compat.v1.train.checkpoint_exists(saved_prefix)
    self.assertTrue(checkpoint_present)

    # ... while the companion meta graph file must not have been written.
    meta_graph_path = checkpoint_management.meta_graph_filename(saved_prefix)
    self.assertFalse(tf.io.gfile.exists(meta_graph_path))
class KeepCheckpointEveryNHoursTest(tf.test.TestCase):
  """Tests the interaction of max_to_keep with keep_checkpoint_every_n_hours."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns `dirname` under the test temp dir."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    tf.io.gfile.makedirs(test_dir)
    return test_dir

  @test_util.run_in_graph_and_eager_modes
  @mock.patch.object(save_utils, "time")
  def testNonSharded(self, mock_time):
    """A checkpoint older than the keep interval survives max_to_keep rotation.

    `save_utils.time` is replaced by a mock, so the clock observed by the
    saver is whatever `mock_time.time.return_value` is set to below.

    Args:
      mock_time: the mock object patched in place of `save_utils.time`.
    """
    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")

    with self.cached_session() as sess:
      v = tf.compat.v1.Variable([10.0], name="v")
      # Run the initializer NOW to avoid the 0.5s overhead of the first Run()
      # call, which throws the test timing off in fastbuild mode.
      self.evaluate(tf.compat.v1.global_variables_initializer())
      # Create a saver that will keep the last 2 checkpoints plus one every 0.7
      # seconds (0.7 / 3600 hours).
      # The real wall clock is read once here; from then on the saver only
      # sees the mocked value.
      start_time = time.time()
      mock_time.time.return_value = start_time
      save = save_utils.PartialRecoverySaver({"v": v},
                                             max_to_keep=2,
                                             keep_checkpoint_every_n_hours=0.7 /
                                             3600)
      self.assertEqual([], save.last_checkpoints)

      # Advance the mocked clock by 1 second (instead of sleeping) so s1 will
      # be old enough to keep.
      mock_time.time.return_value = start_time + 1.0
      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)

      # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save(),
      # would normally delete s1, because max_to_keep is 2. However, s1 is
      # older than 0.7s so we must keep it.
      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)

      # s1 should still be on disk at this point; its existence is asserted
      # once, at the end of the test.

      # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next
      # call to Save(), will delete s2, because max_to_keep is 2, and because
      # we already kept the old s1. s2 is very close in time to s1 so it gets
      # deleted.
      s4 = save.save(sess, os.path.join(save_dir, "s4"))
      self.assertEqual([s3, s4], save.last_checkpoints)

      # Check that s1 is still here, but s2 is gone.
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s1))
      self.assertFalse(tf.compat.v1.train.checkpoint_exists(s2))
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s3))
      self.assertTrue(tf.compat.v1.train.checkpoint_exists(s4))
v0 = variable_op(10.0, name="v0") v1 = variable_op(20.0, name="v1") save = save_utils.PartialRecoverySaver({ "save_prefix/v0": v0, "save_prefix/v1": v1 }) self.evaluate(tf.compat.v1.global_variables_initializer()) # Check that the parameter nodes have been initialized. self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) # Save the initialized values in the file at "save_path" # Use a variable name map to set the saved tensor names val = save.save(sess, save_path) self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path, val) # Verify that the original names are not in the Saved file save = save_utils.PartialRecoverySaver({"v0": v0, "v1": v1}) with self.assertRaisesOpError("not found in checkpoint"): save.restore(sess, save_path) # Verify that the mapped names are present in the Saved file and can be # Restored using remapped names. with self.session(graph=tf.Graph()) as sess: v0 = variable_op(-1.0, name="v0") v1 = variable_op(-1.0, name="v1") if not tf.executing_eagerly(): with self.assertRaisesOpError("uninitialized"): self.evaluate(v0) with self.assertRaisesOpError("uninitialized"): self.evaluate(v1) save = save_utils.PartialRecoverySaver({ "save_prefix/v0": v0, "save_prefix/v1": v1 }) save.restore(sess, save_path) # Check that the parameter nodes have been restored. if not tf.executing_eagerly(): self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) # Add a prefix to the node names in the current graph and Restore using # remapped names. with self.session(graph=tf.Graph()) as sess: v0 = variable_op(-1.0, name="restore_prefix/v0") v1 = variable_op(-1.0, name="restore_prefix/v1") if not tf.executing_eagerly(): with self.assertRaisesOpError("uninitialized"): self.evaluate(v0) with self.assertRaisesOpError("uninitialized"): self.evaluate(v1) # Restore the saved values in the parameter nodes. 
class SecondOrStepTimerWithTideSettingTest(tf.test.TestCase):
  """Tests `save_utils.SecondOrStepTimerWithTideSetting` under a frozen clock."""

  @staticmethod
  def _new_tide_timer():
    """Builds a timer with every_secs=10 and a 01:00-03:00 tide window (5s)."""
    return save_utils.SecondOrStepTimerWithTideSetting(every_secs=10,
                                                       tide_start_hour=1,
                                                       tide_start_minute=0,
                                                       tide_end_hour=3,
                                                       tide_end_minute=0,
                                                       tide_every_secs=5)

  def _expect_trigger_on_second_tick(self, timer, frozen_at, first_tick,
                                     second_tick):
    """Shared scenario used by every test in this class.

    Freezes the clock at `frozen_at`, marks step 5 as the last triggered
    step, then asserts that step 10 does not trigger after `first_tick`
    seconds but does trigger after a further `second_tick` seconds.
    """
    with freeze_time(frozen_at) as clock:
      timer.update_last_triggered_step(5)

      clock.tick(first_tick)
      self.assertEqual(False, timer.should_trigger_for_step(10))

      clock.tick(second_tick)
      self.assertEqual(True, timer.should_trigger_for_step(10))

  def testNoTideSetting(self):
    """Timer configured with `every_secs` only."""
    plain_timer = save_utils.SecondOrStepTimerWithTideSetting(every_secs=10)
    self._expect_trigger_on_second_tick(plain_timer, "2012-01-14 02:00:00",
                                        5.0, 10.0)

  def testTideAvailable(self):
    """Clock frozen at 02:00, i.e. between the configured tide start and end."""
    tide_timer = self._new_tide_timer()
    self._expect_trigger_on_second_tick(tide_timer, "2012-01-14 02:00:00", 5.0,
                                        10.0)

  def testTideNotAvailable(self):
    """Clock frozen at 04:00, i.e. after the configured tide end."""
    tide_timer = self._new_tide_timer()
    self._expect_trigger_on_second_tick(tide_timer, "2012-01-14 04:00:00", 2.0,
                                        7.0)
def retry_with_socket_error(http_call):
  """Calls `http_call()` and retries it when it raises `socket.error`.

  At most five attempts are made. After each failed attempt except the last,
  the function sleeps for a random duration below `_RETRY_MAX_BACKOFF_SECS`
  seconds. The `socket.error` of the final attempt is re-raised; any other
  exception type propagates immediately.

  Args:
    http_call: zero-argument callable.

  Returns:
    Whatever the first successful `http_call()` returns.
  """
  max_attempts = 5
  attempt = 0
  while attempt < max_attempts:
    attempt += 1
    try:
      return http_call()
    except socket.error:
      if attempt == max_attempts:
        # Out of attempts: surface the last error to the caller.
        raise
      time.sleep(np.random.rand() * _RETRY_MAX_BACKOFF_SECS)
class TfConfigServiceDiscovery(ServiceDiscovery):
  """Read-only service discovery answered from a TF_CONFIG-style dict.

  The dict is expected to provide `tf_config['cluster']` (role name -> list of
  addresses) and `tf_config['task']` (with `'type'` and `'index'`). When the
  cluster has a 'chief' entry, the chief addresses are listed before the
  workers, so worker indices are shifted by one.
  """

  def __init__(self, tf_config):
    self._tf_config = tf_config

  def register(self, name: str, index: int, addr: str):
    """No-op: membership is fixed by the config."""
    pass

  def deregister(self, name: str, index: int, addr: str):
    """No-op: membership is fixed by the config."""
    pass

  def query(self, name: str):
    """Returns {index: addr} for 'ps' or 'worker'.

    Raises:
      ValueError: if `name` is neither 'ps' nor 'worker'.
    """
    if name not in ('ps', 'worker'):
      raise ValueError('name must be ps/worker')

    cluster = self._tf_config['cluster']
    addresses = cluster[name]
    if name == 'worker' and 'chief' in cluster:
      # The chief entries take the leading worker slots.
      addresses = cluster['chief'] + addresses
    return dict(enumerate(addresses))

  @property
  def server_type(self):
    """Role of the current task; 'chief' is reported as 'worker'."""
    role = self._tf_config['task']['type']
    if role == 'chief':
      return 'worker'
    return role

  @property
  def addr(self):
    """Address of the current task, looked up under its own role."""
    current = self._tf_config['task']
    role_addresses = self._tf_config['cluster'][current['type']]
    return role_addresses[current['index']]

  @property
  def index(self):
    """Index of the current task in the numbering used by `query`."""
    current = self._tf_config['task']
    behind_chief = ('chief' in self._tf_config['cluster'] and
                    current['type'] == 'worker')
    if behind_chief:
      return current['index'] + 1
    return current['index']
class ZKServiceDiscovery(ServiceDiscovery):
  """Service discovery backed by ZooKeeper.

  Each registered server is stored as an ephemeral node at
  `/monolith/{job_name}/{server_type}.{index}` whose payload is the UTF-8
  encoded address. A children watch on the job path installs a data watch on
  every node it has not seen yet; those data watches maintain the in-memory
  `_cluster` cache that `query` reads.

  `register` also starts a daemon thread per (name, index) that re-creates
  the node every `_ZK_REGISTRATION_PERIOD` seconds, or as soon as
  `do_all_registrations` wakes it up.
  """

  class ThreadSet(NamedTuple):
    """A periodic registration thread and the events that control it."""
    stop: threading.Event
    wakeup: threading.Event
    thread: threading.Thread

    def stop_and_join(self):
      # `stop` must be set before `wakeup` so that the thread observes it as
      # soon as it is woken.
      self.stop.set()
      self.wakeup.set()
      self.thread.join()

  def __init__(self,
               job_name: str,
               zk_server: str = None,
               max_tries: int = 3,
               delay: int = 5):
    """Connects to ZooKeeper and starts watching the job path.

    Args:
      job_name: used to build the path prefix `/monolith/{job_name}`.
      zk_server: server string for the client; falls back to
        `default_zk_servers()` when empty.
      max_tries: `max_tries` passed to `KazooRetry` for create/delete.
      delay: `delay` passed to `KazooRetry` for create/delete.

    Raises:
      Exception: whatever creating/starting the client raised (after logging).
    """
    self._max_tries = max_tries
    self._delay = delay
    # /monolith/{job_name}/server_type.index -> host:port
    self._path_prefix = '/monolith/{}'.format(job_name)
    self._lock = threading.Lock()
    self._cluster = {}
    # Maps (name, index) to thread set.
    self._threads: Dict[tuple, "ZKServiceDiscovery.ThreadSet"] = {}
    # Defined before connecting so that close() / __del__ stay safe when the
    # connection attempt below raises.
    self._client = None

    try:
      zk_server = zk_server or default_zk_servers()
      self._client = MonolithKazooClient(zk_server)
      self._client.add_listener(ZKListener(self))
      self._client.start()
      self._client.ensure_path(self._path_prefix)
    except Exception as e:
      logging.error("cannot create zk client, {}".format(e))
      raise e

    self._watch_data()

  def do_all_registrations(self):
    """Wakes every periodic registration thread so it re-registers now."""
    # Iterate over a snapshot: this is called from the connection listener
    # while register()/deregister() may be mutating the table.
    for ts in list(self._threads.values()):
      ts.wakeup.set()

  def _get_node_name(self, server_type: str, index: int):
    return '{}.{}'.format(server_type, index)

  def _get_path(self, server_type: str, index: int):
    return "{}/{}".format(self._path_prefix,
                          self._get_node_name(server_type, index))

  def _try_create(self, path: str, value: str):
    """Creates the ephemeral node, or overwrites its value if it exists."""
    value_bytes = bytes(value, 'utf-8')
    try:
      self._client.create(path,
                          value=value_bytes,
                          makepath=True,
                          ephemeral=True)
    except NodeExistsError:
      self._client.set(path, value_bytes)

  def _try_delete(self, path):
    """Deletes the node; a missing node is only logged."""
    try:
      self._client.delete(path=path, recursive=True)
    except NoNodeError:
      logging.info("{path} is not exist, no need to delete".format(path=path))

  def _children_watch(self, children):
    """Children-watch callback: installs a data watch on every new node."""
    with self._lock:
      old_children = set(
          self._get_node_name(serve_type, index)
          for serve_type in self._cluster
          for index in self._cluster[serve_type])
    new_children = set(child for child in children if child)
    added = new_children - old_children
    # Watches are installed after releasing the lock: the data-watch callback
    # takes the same non-reentrant lock, and a watch may invoke its callback
    # synchronously when it is created (kazoo documents this for DataWatch).
    for node in added:
      path = '{}/{}'.format(self._path_prefix, node)
      self._client.DataWatch(path, func=self._get_data_watch(path))

  def _get_data_watch(self, path):
    """Returns the data-watch callback that mirrors `path` into the cache."""

    def data_watch(data, stat):
      # Split on the last dot only, so a server type that itself contains a
      # dot is still parsed correctly.
      server_type, index = os.path.basename(path).rsplit('.', 1)
      index = int(index)
      with self._lock:
        if data is not None and len(data) > 0:
          addr = data.decode("utf-8")
          if server_type in self._cluster:
            self._cluster[server_type][index] = addr
          else:
            self._cluster[server_type] = {index: addr}
        else:
          # No payload (e.g. the node is gone): drop the cached entry.
          if server_type in self._cluster:
            if index in self._cluster[server_type]:
              del self._cluster[server_type][index]

    return data_watch

  def _watch_data(self):
    self._client.ChildrenWatch(self._path_prefix, self._children_watch)

  def register(self, name: str, index: int, addr: str):
    """Registers `addr` for (name, index) and keeps it refreshed."""
    self._internal_register(name, index, addr, register_periodically=True)

  def _periodically_register(self, name: str, index: int, addr: str,
                             stop: threading.Event, wakeup: threading.Event):
    """Thread body: re-registers on every wakeup or period expiry until stopped."""
    while True:
      if wakeup.wait(_ZK_REGISTRATION_PERIOD):
        wakeup.clear()
      if stop.is_set():
        break
      try:
        self._internal_register(name, index, addr, register_periodically=False)
      except Exception:
        # This is the best effort. _internal_register has already logged the
        # failure; the next wakeup/period will retry.
        pass

  def _internal_register(self, name: str, index: int, addr: str,
                         register_periodically: bool):
    """Creates the node (with retries) and optionally starts the refresher.

    Raises:
      Exception: whatever the retried create raised (after logging).
    """
    path = self._get_path(name, index)
    retry = KazooRetry(max_tries=self._max_tries, delay=self._delay)
    try:
      retry(self._try_create, path, addr)
    except Exception as e:
      logging.error("server_type: {} , index:{} register fail".format(
          name, index))
      raise e
    if register_periodically:
      key = (name, index)
      # Registering the same key again must not leak the previous refresher
      # (which would also keep re-creating the node with the old address).
      previous = self._threads.pop(key, None)
      if previous is not None:
        previous.stop_and_join()
      stop = threading.Event()
      wakeup = threading.Event()
      thread = threading.Thread(target=self._periodically_register,
                                args=(name, index, addr, stop, wakeup),
                                daemon=True)
      thread.start()
      self._threads[key] = ZKServiceDiscovery.ThreadSet(stop=stop,
                                                        wakeup=wakeup,
                                                        thread=thread)

  def deregister(self, name: str, index: int, addr: str):
    """Deletes the node for (name, index) and stops its refresher, if any.

    Raises:
      Exception: whatever the retried delete raised (after logging).
    """
    path = self._get_path(name, index)
    retry = KazooRetry(max_tries=self._max_tries, delay=self._delay)
    try:
      retry(self._try_delete, path)
    except Exception as e:
      logging.error("server_type: {} , index:{} deregister fail".format(
          name, index))
      raise e
    # Tolerate a key without a refresher (never registered through this
    # object, or already deregistered) instead of raising KeyError.
    ts = self._threads.pop((name, index), None)
    if ts is not None:
      ts.stop_and_join()

  def query(self, name) -> Dict[int, str]:
    """Returns a snapshot {index: addr} of the cached entries for `name`."""
    with self._lock:
      # Copy under the lock: the cached dict is mutated by watcher callbacks,
      # so handing out the live object would let callers race with them.
      return dict(self._cluster.get(name, {}))

  def close(self):
    """Stops the client and every refresher thread. Safe to call repeatedly."""
    # getattr: __del__ may run on an instance whose __init__ did not finish.
    client = getattr(self, '_client', None)
    if client is not None:
      client.stop()
      client.close()
      self._client = None
    threads = getattr(self, '_threads', None)
    if threads:
      for ts in list(threads.values()):
        ts.stop_and_join()
      threads.clear()

  def __del__(self):
    self.close()
def deregister_all(consul_id: str):
  """Removes every record currently registered under `consul_id`.

  A `ConsulServiceDiscovery` is created for `consul_id`, all of its records
  are fetched once with `query_all()`, and each (name, index, addr) record is
  passed to `deregister`.
  """
  consul_discovery = ConsulServiceDiscovery(consul_id)
  records = ((service, position, endpoint)
             for service, by_index in consul_discovery.query_all().items()
             for position, endpoint in by_index.items())
  for service, position, endpoint in records:
    consul_discovery.deregister(service, position, endpoint)
================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json from collections import defaultdict from functools import partial import os import socket import threading import time import unittest from unittest import mock from kazoo.client import KazooState from kazoo.exceptions import NoNodeError, NodeExistsError, NotEmptyError from typing import List, Callable from monolith.native_training import service_discovery class FakeConsul: def __init__(self, blacklist=[]): """The host in blacklist will not be registered to consul""" self._name_to_dict = defaultdict(dict) self._blacklist = blacklist def lookup(self, name: str, **kwargs): return list(self._name_to_dict[name].values()) def register(self, name: str, port: int, tags={}, host: str = None, check_script: str = None): del check_script if tags["ip"] in self._blacklist: return key = str(host) + ":" + str(port) d = self._name_to_dict[name] d[key] = {"Host": host, "Port": port, "Tags": tags} def deregister(self, name: str, port: int, host: str = None): key = str(host) + ":" + str(port) d = self._name_to_dict[name] del d[key] class FakeKazooClient: def __init__(self, zk_server: str): self._lock = threading.RLock() self._zk_server = zk_server self._data = None self._children_watches = [] self._data_watches = [] self.DataWatch = partial(DataWatch, self) self.ChildrenWatch = partial(ChildrenWatch, self) self._listeners = [] def 
ensure_path(self, path: str): with self._lock: if path not in self._data: self._data[path] = None def start(self): with self._lock: self._data = {} def create(self, path: str, value: bytes = b'', makepath: bool = False, ephemeral: bool = False): with self._lock: if path in self._data: raise NodeExistsError('node {} exists'.format(path)) else: prefix = os.path.dirname(path) if prefix in self._data or makepath: self._data[path] = value for dw in self._data_watches: if dw.path == path: dw(value, None) for cw in self._children_watches: dirname = os.path.dirname(path) if dirname == cw.path: children = [ os.path.basename(key) for key in self._data if os.path.dirname(key) == cw.path ] cw(children) else: raise NoNodeError('No Node {}'.format(prefix)) def delete(self, path: str, recursive: bool = True): with self._lock: if path in self._data: del self._data[path] for dw in self._data_watches: if dw.path == path: dw(None, None) for cw in self._children_watches: dirname = os.path.dirname(path) if dirname == cw.path: children = [ os.path.basename(key) for key in self._data if os.path.dirname(key) == cw.path ] cw(children) else: collected = [] for key in self._data: if key.startswith(path): collected.append(key) if collected: if recursive: for key in collected: del self._data[key] for dw in self._data_watches: if dw.path == key: dw(None, None) for cw in self._children_watches: dirname = os.path.dirname(key) if dirname == cw.path: children = [ os.path.basename(key) for key in self._data if os.path.dirname(key) == cw.path ] cw(children) else: raise NotEmptyError('node {} has children'.format(path)) else: raise NoNodeError('node {} not found'.format(path)) def set(self, path: str, value: bytes): with self._lock: if path in self._data: self._data[path] = value for dw in self._data_watches: if dw.path == path: dw(value, None) else: raise NoNodeError('node {} is not exist'.format(path)) def get(self, path: str): with self._lock: if path in self._data: return self._data[path], None 
else: raise NoNodeError('node {} is not exist'.format(path)) def get_children(self, path: str): with self._lock: if path in self._data: return [] else: collected = [] for key in self._data: if key.startswith(path): child = key[len(path) + 1:].split('/')[0] collected.append(child) if collected: return collected else: raise NoNodeError('node {} is not exist'.format(path)) def stop(self): with self._lock: self._data = None def close(self): with self._lock: if self._data is not None: self.stop() def add_listener(self, listener): self._listeners.append(listener) @property def listeners(self): return self._listeners class ChildrenWatch: def __init__(self, client: FakeKazooClient, path: str, func: Callable[[List[str]], None]): self._client = client client._children_watches.append(self) self.path = path self._func = func children = [] for key in self._client._data: dirname = os.path.dirname(key) if dirname == path: children.append(os.path.basename(key)) self._func(children) def __call__(self, children: List[str]): self._func(children) class DataWatch: def __init__(self, client: FakeKazooClient, path: str, func: str): self._client = client client._data_watches.append(self) self.path = path self._func = func for key, value in self._client._data.items(): if key == path: self._func(value, None) def __call__(self, data: str, stat=None): self._func(data, stat) _CONSUL_CLIENT = "monolith.native_training.service_discovery.consul.Client" _ZK_CLIENT = "monolith.native_training.service_discovery.MonolithKazooClient" class ConsultServiceDiscovery(unittest.TestCase): def test_basic(self): with mock.patch(_CONSUL_CLIENT) as MockClient: MockClient.return_value = FakeConsul() discovery = service_discovery.ConsulServiceDiscovery("unique_id") discovery.register("server", 0, "192.168.0.1:1001") discovery.register("server", 1, "192.168.0.2:1002") self.assertDictEqual(discovery.query("server"), { 0: "192.168.0.1:1001", 1: "192.168.0.2:1002" }) discovery.deregister("server", 0, 
"192.168.0.1:1001") discovery.deregister("server", 1, "192.168.0.2:1002") self.assertDictEqual(discovery.query("server"), {}) def test_duplicate_registration(self): with mock.patch(_CONSUL_CLIENT) as MockClient: MockClient.return_value = FakeConsul() discovery = service_discovery.ConsulServiceDiscovery("unique_id", retry_time_secs=0.0) discovery.register("server", 0, "192.168.0.1:1001") discovery.register("server", 0, "192.168.0.1:1002") self.assertDictEqual(discovery.query("server"), {0: "192.168.0.1:1002"}) def test_multi_names(self): with mock.patch(_CONSUL_CLIENT) as MockClient: MockClient.return_value = FakeConsul() discovery = service_discovery.ConsulServiceDiscovery("unique_id") discovery.register("ps", 0, "192.168.0.1:1001") discovery.register("worker", 0, "192.168.0.1:1002") self.assertDictEqual(discovery.query("worker"), {0: "192.168.0.1:1002"}) def test_retry(self): with mock.patch( "monolith.native_training.service_discovery._RETRY_MAX_BACKOFF_SECS", 0): with mock.patch(_CONSUL_CLIENT) as MockClient: mock_client = mock.MagicMock() def raise_timeout(*args, **kwargs): raise socket.timeout() mock_client.register.side_effect = raise_timeout MockClient.return_value = mock_client discovery = service_discovery.ConsulServiceDiscovery("unique_id") self.assertRaises(socket.timeout, discovery.register, "ps", 0, "192.168.0.1:1001") def test_registeration_failed(self): with mock.patch( "monolith.native_training.service_discovery._RETRY_MAX_BACKOFF_SECS", 0), mock.patch(_CONSUL_CLIENT) as MockClient: MockClient.return_value = FakeConsul(blacklist=["192.168.0.1"]) discovery = service_discovery.ConsulServiceDiscovery("unique_id") with self.assertRaises(OSError): discovery.register("ps", 0, "192.168.0.1:1001") class TfConfigServiceDiscoveryTest(unittest.TestCase): def test_tf_conf_sd(self): cluster = { 'chief': ['host0:2222'], 'ps': ['host1:2222', 'host2:2222'], 'worker': ['host3:2222', 'host4:2222', 'host5:2222'] } task = {'type': 'worker', 'index': 1} tf_conf = 
{'cluster': cluster, 'task': task} ps_list = cluster['ps'] discovery = service_discovery.TfConfigServiceDiscovery(tf_conf) self.assertEqual(discovery.query('ps'), {i: addr for i, addr in enumerate(ps_list)}, "['host1:2222', 'host2:2222']") worker_list = cluster['chief'] + cluster['worker'] self.assertEqual(discovery.query('worker'), {i: addr for i, addr in enumerate(worker_list)}, "[host0:2222, 'host1:2222', 'host2:2222']") self.assertEqual(discovery.addr, 'host4:2222', 'host4:2222') self.assertEqual(discovery.server_type, 'worker', 'worker') self.assertEqual(discovery.index, 2, 2) class ZKServiceDiscoveryTest(unittest.TestCase): def test_basic(self): with mock.patch(_ZK_CLIENT) as MockClient: MockClient.return_value = FakeKazooClient("test_model") discovery = service_discovery.ZKServiceDiscovery("test_model", "fask") discovery.register("server", 0, "192.168.0.1:1001") discovery.register("server", 1, "192.168.0.2:1002") self.assertDictEqual(discovery.query("server"), { 0: "192.168.0.1:1001", 1: "192.168.0.2:1002" }) discovery.deregister("server", 0, "192.168.0.1:1001") discovery.deregister("server", 1, "192.168.0.2:1002") self.assertDictEqual(discovery.query("server"), {}) discovery.close() def test_duplicate_registration(self): with mock.patch(_ZK_CLIENT) as MockClient: MockClient.return_value = FakeKazooClient("test_model") discovery = service_discovery.ZKServiceDiscovery("test_model", "fask") discovery.register("server", 0, "192.168.0.1:1001") discovery.register("server", 0, "192.168.0.1:1002") self.assertDictEqual(discovery.query("server"), {0: "192.168.0.1:1002"}) discovery.close() def test_multi_names(self): with mock.patch(_ZK_CLIENT) as MockClient: MockClient.return_value = FakeKazooClient("test_model") discovery = service_discovery.ZKServiceDiscovery("test_model", "fask") discovery.register("ps", 0, "192.168.0.1:1001") discovery.register("worker", 0, "192.168.0.1:1002") self.assertDictEqual(discovery.query("worker"), {0: "192.168.0.1:1002"}) del discovery 
@mock.patch( "monolith.native_training.service_discovery._ZK_REGISTRATION_PERIOD", 0.01) def test_periodic_registration(self): with mock.patch(_ZK_CLIENT) as MockClient: c = FakeKazooClient("test_model") MockClient.return_value = c discovery = service_discovery.ZKServiceDiscovery("test_model") discovery.register("ps", 0, "192.168.0.1:1001") c.set("/monolith/test_model/ps.0", "wrongdata".encode()) time.sleep(1) # Periodic registration should register it again self.assertDictEqual(discovery.query("ps"), {0: "192.168.0.1:1001"}) def test_listener(self): with mock.patch(_ZK_CLIENT) as MockClient: c = FakeKazooClient("test_model") MockClient.return_value = c discovery = service_discovery.ZKServiceDiscovery("test_model") discovery.register("ps", 0, "192.168.0.1:1001") listener = c.listeners[0] listener(KazooState.LOST) listener(KazooState.CONNECTED) self.assertDictEqual(discovery.query("ps"), {0: "192.168.0.1:1001"}) class UtilsTest(unittest.TestCase): def test_deregister_all(self): with mock.patch(_CONSUL_CLIENT) as MockClient: MockClient.return_value = FakeConsul() discovery = service_discovery.ConsulServiceDiscovery("unique_id") discovery.register("server", 0, "192.168.0.1:1001") service_discovery.deregister_all("unique_id") self.assertDictEqual(discovery.query("server"), {}) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/native_training/serving_ps_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List from random import randint import tensorflow as tf from monolith.native_training.distribution_ops import * from idl.matrix.proto.example_pb2 import ExampleBatch, FeatureListType, Feature, \ FeatureConfigs, FeatureConfig, OutConfig, SliceConfig, PoolingType, OutType batch_size = 10 @dataclass class FeatMeta: name: str = None slot: int = None max_sequence_length: int = 1 fid_version: int = 0 slice_dims: List[int] = None table: str = None pool_type: int = 1 # 1 sum, 2 mean, 3 fristn fl_type: int = 0 # 0 INDIVIDUAL, 1 SHARED @dataclass class TableMeta: name: str = None slice_dims: List[int] = None table1 = TableMeta(name='table1', slice_dims=[1, 4, 4, 8]) table2 = TableMeta(name='table2', slice_dims=[1, 4, 4]) table3 = TableMeta(name='table3', slice_dims=[8]) features = { 'f_user_id': FeatMeta(name='f_user_id', slot=1, max_sequence_length=1, slice_dims=table1.slice_dims, table=table1.name, pool_type=1, fl_type=1), 'f_user_ctx_network': FeatMeta(name='f_user_ctx_network', slot=61, max_sequence_length=1, slice_dims=table1.slice_dims, table=table1.name, pool_type=1, fl_type=1), 'f_user_test10_array': FeatMeta(name='f_user_test10_array', slot=549, max_sequence_length=10, slice_dims=table1.slice_dims, table=table1.name, pool_type=2, fl_type=1), 'f_user_id-f_page': FeatMeta(name='f_user_id-f_page', slot=504, max_sequence_length=10, fid_version=1, slice_dims=table3.slice_dims, table=table3.name, pool_type=3, fl_type=1), 'f_goods_id': FeatMeta(name='f_user_id', slot=200, max_sequence_length=1, slice_dims=table2.slice_dims, table=table2.name, pool_type=1, fl_type=0), 'f_page': FeatMeta(name='f_page', slot=305, max_sequence_length=1, slice_dims=table2.slice_dims, table=table2.name, pool_type=1, fl_type=0), } class ServingPSTest(tf.test.TestCase): def test_example_gen(self): example_batch = 
ExampleBatch(batch_size=batch_size) for name, meta in features.items(): named_feature_list = example_batch.named_feature_list.add() named_feature_list.id = meta.slot named_feature_list.name = name if meta.fl_type == 0: named_feature_list.type = FeatureListType.INDIVIDUAL else: named_feature_list.type = FeatureListType.SHARED for i in range(batch_size): feature = named_feature_list.feature.add() if meta.fid_version == 0: mask = (1 << 54) - 1 feature.fid_v1_list.value.extend([ (meta.slot << 54) | (randint(1, mask) & mask) for _ in range(meta.max_sequence_length) ]) else: mask = (1 << 48) - 1 feature.fid_v2_list.value.extend([ (meta.slot << 48) | (randint(1, mask) & mask) for _ in range(meta.max_sequence_length) ]) if named_feature_list.type == FeatureListType.SHARED: break print(example_batch, flush=True) def test_conf_gen(self): feature_configs = FeatureConfigs() for name, meta in features.items(): feat_conf = FeatureConfig(table=meta.table) if meta.max_sequence_length > 1 and meta.pool_type == 3: max_sequence_length = meta.max_sequence_length feat_conf.slice_dims.extend(meta.slice_dims) if meta.pool_type == 1: feat_conf.pooling_type = PoolingType.SUM elif meta.pool_type == 2: feat_conf.pooling_type = PoolingType.MEAN else: feat_conf.pooling_type = PoolingType.FIRSTN feature_configs.feature_configs[name].CopyFrom(feat_conf) bias = OutConfig() bias.out_type = OutType.CONCAT bias_shape = (batch_size, len(features) - 1) sub_shape = bias.shape.add() sub_shape.dims.extend(bias_shape) for name, meta in features.items(): if meta.pool_type != 3: slice_config = bias.slice_configs.add() slice_config.feature_name = name slice_config.start = 0 slice_config.end = 1 feature_configs.out_configs['bias'].CopyFrom(bias) vec = OutConfig() vec.out_type = OutType.CONCAT vec_shape = (batch_size, (len(features) - 1) * 4) sub_shape = vec.shape.add() sub_shape.dims.extend(vec_shape) for name, meta in features.items(): if meta.pool_type != 3: slice_config = vec.slice_configs.add() 
slice_config.feature_name = name slice_config.start = 1 slice_config.end = 5 feature_configs.out_configs['vec'].CopyFrom(vec) uffm = OutConfig() uffm.out_type = OutType.NONE uffm_shape = (batch_size, 4) for name, meta in features.items(): if meta.pool_type != 3 and 'user' in name: sub_shape = uffm.shape.add() sub_shape.dims.extend(uffm_shape) slice_config = uffm.slice_configs.add() slice_config.feature_name = name slice_config.start = 5 slice_config.end = 8 feature_configs.out_configs['uffm'].CopyFrom(uffm) iffm = OutConfig() iffm.out_type = OutType.NONE iffm_shape = (batch_size, 4) for name, meta in features.items(): if meta.pool_type != 3 and 'user' not in name: sub_shape = iffm.shape.add() sub_shape.dims.extend(iffm_shape) slice_config = iffm.slice_configs.add() slice_config.feature_name = name slice_config.start = 5 slice_config.end = 8 feature_configs.out_configs['iffm'].CopyFrom(iffm) seq = OutConfig() seq.out_type = OutType.NONE meta = features['f_user_id-f_page'] seq_shape = (batch_size, meta.slice_dims[0], meta.max_sequence_length) sub_shape = seq.shape.add() sub_shape.dims.extend(seq_shape) slice_config = seq.slice_configs.add() slice_config.feature_name = meta.name slice_config.start = 0 slice_config.end = 8 feature_configs.out_configs['seq'].CopyFrom(seq) user_only = OutConfig() user_only.out_type = OutType.STACK sub_shape = user_only.shape.add() user_only_shape = (batch_size, 8, 3) sub_shape.dims.extend(user_only_shape) for name, meta in features.items(): if meta.pool_type != 3 and 'user' in name and '-' not in name: slice_config = user_only.slice_configs.add() slice_config.feature_name = name slice_config.start = 8 slice_config.end = 16 feature_configs.out_configs['user_only'].CopyFrom(user_only) print(feature_configs, flush=True) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/session_run_hooks.py 
================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf import random import time from absl import logging from datetime import datetime from tensorflow.python.training import training_util def before(hour1, minute1, hour2, minute2): if hour1 < hour2 or (hour1 == hour2 and minute1 < minute2): return True else: return False def tide_available_now(tide_start_hour, tide_start_minute, tide_end_hour, tide_end_minute): if before(tide_start_hour, tide_start_minute, tide_end_hour, tide_end_minute): if not before(datetime.utcnow().hour, datetime.utcnow().minute, tide_start_hour, tide_start_minute) and before( datetime.utcnow().hour, datetime.utcnow().minute, tide_end_hour, tide_end_minute): return True else: return False else: if before(datetime.utcnow().hour, datetime.utcnow().minute, tide_start_hour, tide_start_minute) or not before( datetime.utcnow().hour, datetime.utcnow().minute, tide_end_hour, tide_end_minute): return True else: return False class CustomGlobalStepWaiterHook(tf.estimator.SessionRunHook): """Delays execution until global step reaches `wait_until_step`. This hook delays execution until global step reaches to `wait_until_step`. It is used to gradually start workers in distributed settings. One example usage would be setting `wait_until_step=int(K*log(task_id+1))` assuming that task_id=0 is the chief. 
""" def __init__(self, wait_until_step, tide_start_hour=None, tide_start_minute=None, tide_end_hour=None, tide_end_minute=None, max_non_tide_wait_minute=10): """Initializes a `GlobalStepWaiterHook`. Args: wait_until_step: an `int` shows until which global step should we wait. tide_start_hour: the first hour in utc timezone when tide resources are available. tide_end_hour: the last hour in utc timezone when tide resources are available. """ self._wait_until_step = wait_until_step self._tide_start_hour = tide_start_hour self._tide_start_minute = tide_start_minute self._tide_end_hour = tide_end_hour self._tide_end_minute = tide_end_minute self._hook_start_time = None random_extra_seconds = random.randint(0, 600) self._non_tide_wait_second = max_non_tide_wait_minute * 60 + random_extra_seconds def begin(self): self._worker_is_started = False self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access if self._global_step_tensor is None: raise RuntimeError( "Global step should be created to use _GlobalStepWaiterHook.") def before_run(self, run_context): if self._worker_is_started: return None if self._wait_until_step <= 0: self._worker_is_started = True return None while True: if self._tide_start_hour is not None and self._tide_end_hour is not None: if not tide_available_now(self._tide_start_hour, self._tide_start_minute, self._tide_end_hour, self._tide_end_minute): logging.info("Current UTC time: {} : {}".format( datetime.utcnow().hour, datetime.utcnow().minute)) logging.info("Last hour in tide queue. Saving ckpt...") run_context.request_stop() return current_step = run_context.session.run(self._global_step_tensor) if self._hook_start_time is None and current_step > 1: # Wait for the chief node to start training for at least one step # before starting the timer. 
self._hook_start_time = time.time() if current_step >= self._wait_until_step: self._worker_is_started = True logging.info( "Start training after waiting for {} global steps. Current step is {}." .format(self._wait_until_step, current_step)) return None has_been_waiting_seconds = None if self._hook_start_time is not None: has_been_waiting_seconds = time.time() - self._hook_start_time if has_been_waiting_seconds > self._non_tide_wait_second: self._worker_is_started = True logging.info( "Start training after waiting for {} seconds. Current step is {}." .format(self._non_tide_wait_second, current_step)) return None logging.log_every_n_seconds( logging.INFO, "Waiting for global_step >= {} or waiting time > {} seconds before starting training. " "Current step is {}, has been waiting for {} seconds already.".format( self._wait_until_step, self._non_tide_wait_second, current_step, has_been_waiting_seconds), 60) time.sleep(0.5) class TideStoppingHook(tf.estimator.SessionRunHook): def __init__(self, tide_start_hour=None, tide_start_minute=None, tide_end_hour=None, tide_end_minute=None): """Initializes a `GlobalStepWaiterHook`. Args: wait_until_step: an `int` shows until which global step should we wait. tide_start_hour: the first hour in utc timezone when tide resources are available. tide_end_hour: the last hour in utc timezone when tide resources are available. """ self._tide_start_hour = tide_start_hour self._tide_start_minute = tide_start_minute self._tide_end_hour = tide_end_hour self._tide_end_minute = tide_end_minute def before_run(self, run_context): if self._tide_start_hour is not None and self._tide_end_hour is not None: if not tide_available_now(self._tide_start_hour, self._tide_start_minute, self._tide_end_hour, self._tide_end_minute): logging.info("Current UTC time: {} : {}".format( datetime.utcnow().hour, datetime.utcnow().minute)) logging.info("Last hour in tide queue. 
Saving ckpt...") run_context.request_stop() ================================================ FILE: monolith/native_training/session_run_hooks_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import threading import time import tensorflow as tf from absl import logging from freezegun import freeze_time from tensorflow.python.platform import test from monolith.native_training import session_run_hooks class MockDateTime: def __init__(self, hour, minute): self.hour = hour self.minute = minute class GlobalStepWaiterHookTest(tf.test.TestCase): def test_not_wait_for_step_zero(self): with tf.compat.v1.Graph().as_default(): tf.compat.v1.train.get_or_create_global_step() hook = session_run_hooks.CustomGlobalStepWaiterHook(wait_until_step=0) hook.begin() with tf.compat.v1.Session() as sess: # Before run should return without waiting gstep increment. hook.before_run( tf.estimator.SessionRunContext(original_args=None, session=sess)) @freeze_time("2012-01-14 10:00:00") def test_not_wait_if_tide_not_available(self): with tf.compat.v1.Graph().as_default(): tf.compat.v1.train.get_or_create_global_step() hook = session_run_hooks.CustomGlobalStepWaiterHook(wait_until_step=0, tide_start_hour=1, tide_start_minute=0, tide_end_hour=5, tide_end_minute=0) hook.begin() with tf.compat.v1.Session() as sess: # Before run should return without waiting gstep increment. 
hook.before_run( tf.estimator.SessionRunContext(original_args=None, session=sess)) @test.mock.patch.object(time, 'sleep') def test_wait_for_step(self, mock_sleep): with tf.compat.v1.Graph().as_default(): gstep = tf.compat.v1.train.get_or_create_global_step() hook = session_run_hooks.CustomGlobalStepWaiterHook(wait_until_step=1000) hook.begin() with tf.compat.v1.Session() as sess: # Mock out calls to time.sleep() to update the global step. class Context(object): counter = 0 def mock_sleep_side_effect(seconds): del seconds # argument is ignored Context.counter += 1 if Context.counter == 1: # The first time sleep() is called, we update the global_step from # 0 to 500. sess.run(tf.compat.v1.assign(gstep, 500)) elif Context.counter == 2: # The second time sleep() is called, we update the global_step from # 500 to 1100. sess.run(tf.compat.v1.assign(gstep, 1100)) else: raise AssertionError( 'Expected before_run() to terminate after the second call to ' 'time.sleep()') mock_sleep.side_effect = mock_sleep_side_effect # Run the mocked-out interaction with the hook. self.evaluate(tf.compat.v1.global_variables_initializer()) run_context = tf.estimator.SessionRunContext(original_args=None, session=sess) hook.before_run(run_context) self.assertEqual(Context.counter, 2) class MockSessionRunContext: def __init__(self): self.requested_stop = False def request_stop(self): logging.info("stop requested") self.requested_stop = True logging.info(self.requested_stop) class TideStoppingHookTest(tf.test.TestCase): @freeze_time("2012-01-14 10:00:00") def test_stop_if_tide_not_available(self): with tf.compat.v1.Graph().as_default(): hook = session_run_hooks.TideStoppingHook(tide_start_hour=1, tide_start_minute=0, tide_end_hour=5, tide_end_minute=0) hook.begin() with tf.compat.v1.Session() as _: # Before run should return without waiting gstep increment. 
context = MockSessionRunContext() hook.before_run(context) self.assertEqual(context.requested_stop, True) @freeze_time("2012-01-14 10:00:00") def test_do_not_stop_if_tide_available(self): with tf.compat.v1.Graph().as_default(): hook = session_run_hooks.TideStoppingHook(tide_start_hour=1, tide_start_minute=0, tide_end_hour=12, tide_end_minute=0) hook.begin() with tf.compat.v1.Session() as _: # Before run should return without waiting gstep increment. context = MockSessionRunContext() hook.before_run(context) self.assertEqual(context.requested_stop, False) if __name__ == '__main__': tf.compat.v1.disable_eager_execution() tf.test.main() ================================================ FILE: monolith/native_training/signal_utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import traceback import signal import time def print_stack_trace(sig, frame): for line in traceback.format_stack(frame): print(line.strip()) def add_siguser1_handler(): ret = signal.getsignal(signal.SIGUSR1) def handler(sig, frame): if callable(ret): ret(sig, frame) print_stack_trace(sig, frame) signal.signal(signal.SIGUSR1, handler) # Adds default handler add_siguser1_handler() ================================================ FILE: monolith/native_training/signal_utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import signal import unittest from monolith.native_training import signal_utils class SignalUtilsTest(unittest.TestCase): def testBasic(self): # Add twice to test two handlers case. signal_utils.add_siguser1_handler() signal.raise_signal(signal.SIGUSR1) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/native_training/static_reshape_op.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import logging from typing import List, Tuple import tensorflow as tf from monolith.native_training.runtime.ops import gen_monolith_ops reshape_ops = gen_monolith_ops def static_reshape( inputs: List[tf.Tensor], shapes: List[Tuple[int]], enable_parallelism: bool = True) -> Tuple[List[tf.Tensor], tf.Tensor]: """ Arguments: inputs: List of input tensors. shapes: List of target shapes for input tensors. 
Returns: outputs: List of reshaped tensors. sizes: A Tensor containing the size of output tensors. """ return reshape_ops.monolith_static_reshape_n( inputs=inputs, shapes=shapes, enable_parallelism=enable_parallelism) class StaticReshapeNBuilder: def __init__(self, enable_parallelism: bool = True): self.inputs = [] self.shapes = [] self.enable_parallelism = enable_parallelism def add(self, input: tf.Tensor, shape: Tuple[int]) -> int: """Returns index of input added.""" self.inputs.append(input) self.shapes.append(shape) return len(self.inputs) - 1 def build(self): return static_reshape(inputs=self.inputs, shapes=self.shapes, enable_parallelism=self.enable_parallelism) ================================================ FILE: monolith/native_training/static_reshape_op_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

from monolith.native_training import static_reshape_op


class StaticReshapeOpTest(tf.test.TestCase):
  """Graph-mode tests for static_reshape and StaticReshapeNBuilder."""

  def test_static_reshape_n(self):
    """Reshapes three tensors of different dtypes and checks sizes and shapes."""
    inputs = [
        tf.ones(shape=(5,), dtype=tf.int32),
        tf.ones(shape=(4, 10), dtype=tf.float32),
        tf.ones(shape=(2, 2, 3), dtype=tf.int64),
    ]
    # A None entry is a dimension this test does not check (see loop below).
    shapes = [
        (1, 5),
        (5, 8),
        (None, 2),
    ]
    with tf.compat.v1.Session() as sess:
      res = static_reshape_op.static_reshape(inputs, shapes)
      outputs, sizes = sess.run(res)
      # Expected sizes equal the element counts of the inputs: 5, 4*10, 2*2*3.
      self.assertAllEqual(sizes, [5, 40, 12])
      for out, shape in zip(outputs, shapes):
        self.assertEqual(len(out.shape), len(shape))
        for d in range(len(shape)):
          if shape[d] is not None:
            self.assertEqual(out.shape[d], shape[d])

  def test_nested_reshape_n(self):
    """Flattens every tensor of a nested structure through the builder."""
    builder = static_reshape_op.StaticReshapeNBuilder()
    structure = [{
        "key_0": tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32),
        "key_1": tf.constant([[3, 3], [5, 5]], dtype=tf.float32)
    }, {
        "key_1": tf.constant([[0], [1], [2]], dtype=tf.float32)
    }]
    target = [{
        "key_0": np.array([1, 2, 3, 4, 5, 6], dtype=np.float32),
        "key_1": np.array([3, 3, 5, 5], dtype=np.float32)
    }, {
        "key_1": np.array([0, 1, 2], dtype=np.float32)
    }]

    def flatten(tensor):
      # Registers the tensor with the builder and returns its output index.
      return builder.add(tensor, (None,))

    # Same nesting as `structure`, but holding builder indices, not tensors.
    list_keyed_id = tf.nest.map_structure(flatten, structure)
    res = builder.build()
    with tf.compat.v1.Session() as sess:
      outputs, _ = sess.run(res)
      for di, dt in zip(list_keyed_id, target):
        for k in sorted(dt.keys()):
          self.assertAllEqual(outputs[di[k]], dt[k])


if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

================================================
FILE: monolith/native_training/summary/BUILD
================================================
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_kernel_library")
load("@pip_deps//:requirements.bzl", "requirement")

package(
    default_visibility = ["//visibility:public"],
)

# Shared helpers for summary metadata (utils.py).
py_library(
    name = "utils",
    srcs = ["utils.py"],
    srcs_version = "PY3",
    deps = [requirement('tensorboard')],
)

py_test(
    name = "utils_test",
    srcs = ["utils_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":utils",
    ],
)

# Summary-writing ops (summary_ops.py); depends on the layer ops target.
py_library(
    name = "summary_ops",
    srcs = ["summary_ops.py"],
    srcs_version = "PY3",
    deps = [
        ":utils",
        "//monolith/native_training/layers:layer_ops",
    ]
)

py_test(
    name = "summary_ops_test",
    srcs = ["summary_ops_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":summary_ops",
    ],
)

================================================
FILE: monolith/native_training/summary/summary_ops.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json from typing import List, Dict, Union, Optional import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.util.tf_export import tf_export from monolith.native_training.summary import utils from monolith.native_training.summary.utils import SummaryType from monolith.native_training.layers.layer_ops import feature_insight @tf_export(v1=["summary.nas_data"]) def nas_data(weight, segment_names=None, segment_sizes=None, group_info=None, raw_tag=None, collections=None, description=None, name=None): meta_content, summaty_type = utils.prepare_head(segment_names, segment_sizes, group_info, raw_tag, out_type='bytes') name = f'{name}_{summaty_type}' if name else summaty_type description = description or summaty_type with tf.name_scope(name): return tf.compat.v1.summary.tensor_summary( name=utils.MONOLITH_NAS_DATA, tensor=weight, collections=collections or [ops.GraphKeys.SUMMARIES], summary_metadata=utils.create_summary_metadata(description, meta_content), ) @tf_export(v1=["summary.feature_insight_data"]) def feature_insight_data(input_tensor: tf.Tensor, segment_names: List[str], segment_sizes: List[int], weight: tf.Tensor = None, group_info: Union[List[int], List[List[int]]] = None, label: tf.Tensor = None, collections: List[str] = None, description: str = None, name: str = None): assert segment_sizes is not None and len(segment_names) == len(segment_sizes) aggregate = True if label is None else False raw_tag = SummaryType.FEATURE_INSIGHT_DIRECT if aggregate else SummaryType.FEATURE_INSIGHT_TRAIN if weight is None: summary_data = input_tensor else: summary_data = feature_insight( input_embedding=input_tensor, weight=weight, segment_sizes=segment_sizes, aggregate=aggregate) segment_sizes = [1 if aggregate else weight.shape.as_list()[-1]] * len(segment_sizes) meta_content, summaty_type = utils.prepare_head(segment_names, segment_sizes, group_info, raw_tag=raw_tag, out_type='json') name = f'{name}_{summaty_type}' if name else 
summaty_type description = description or summaty_type if label is not None: if label.dtype != tf.float32: label = tf.cast(label, dtype=tf.float32) if label.shape.rank == 1: label = tf.reshape(label, shape=(-1, 1)) meta_content['label_size'] = 1 else: meta_content['label_size'] = label.shape.as_list()[-1] summary_data = tf.concat(values=[summary_data, label], axis=1) else: meta_content['label_size'] = 0 with tf.name_scope(name): return tf.compat.v1.summary.tensor_summary( name=utils.MONOLITH_FI_DATA, tensor=summary_data, collections=collections or [ops.GraphKeys.SUMMARIES], summary_metadata=utils.create_summary_metadata(description, json.dumps(meta_content)), ) ================================================ FILE: monolith/native_training/summary/summary_ops_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import unittest import tensorflow as tf from tensorboard import plugin_util from tensorboard.backend.event_processing.plugin_event_multiplexer import EventMultiplexer from tensorboard.backend.event_processing.data_provider import MultiplexerDataProvider from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorboard.data.provider import DataProvider, RunTagFilter from monolith.native_training.summary.utils import PLUGIN_NAME, prepare_head from monolith.native_training.summary import summary_ops tf.compat.v1.disable_eager_execution() class SummaryTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.log_dir = "demo_logs_v1" cls.sess = tf.compat.v1.Session() segment_names = ['f1', 'f2', 'f3'] segment_sizes = [3, 5, 9] group_info = [['f1', 'f2'], ['f3', 'f4'], ['f5', 'f6']] cls.weight = [0.5, 0.8] weight=tf.constant(value=cls.weight, dtype=tf.float32, name='weight') summary_ops.nas_data(weight, segment_names, segment_sizes, group_info) input_tensor = tf.random.uniform(shape=(3, 17), dtype=tf.float32) label = tf.constant(value=[1, 0, 1], shape=(3,), dtype=tf.float32) weight_tensor = tf.random.uniform(shape=(17, 2), dtype=tf.float32) summary_ops.feature_insight_data(input_tensor, segment_names, segment_sizes, label=label, weight=weight_tensor) with cls.sess.as_default(): with tf.compat.v1.summary.FileWriter(cls.log_dir) as writer: summaries = tf.compat.v1.summary.merge_all() for global_step in range(10): summaries_out = cls.sess.run(summaries) writer.add_summary(summaries_out, global_step) multiplexer = EventMultiplexer() multiplexer.AddRunsFromDirectory(path=cls.log_dir) multiplexer.Reload() cls.data_provider: DataProvider = MultiplexerDataProvider( multiplexer, logdir=cls.log_dir) @classmethod def tearDownClass(cls) -> None: if tf.io.gfile.exists(cls.log_dir): tf.io.gfile.rmtree(cls.log_dir) def test_nas_data(self): ctx = plugin_util.context({}) run = '.' 
tag = 'gating/monolith_nas_weight' tag_info = self.data_provider.list_tensors(ctx, experiment_id='0', plugin_name=PLUGIN_NAME, run_tag_filter=RunTagFilter(runs=[run], tags=[tag])) tensors = self.data_provider.read_tensors(ctx, experiment_id='0', plugin_name=PLUGIN_NAME, downsample=100, run_tag_filter=RunTagFilter(runs=[run], tags=[tag])) tensor_tts = tag_info.get(run, {}).get(tag, None) tensor_datum = tensors.get(run, {}).get(tag, None) self.assertTrue(tensor_datum is not None) if isinstance(tensor_datum, (list, tuple)): tensor_datum = tensor_datum[-1] plugin_content = str(tensor_tts.plugin_content, encoding='utf-8') plugin_content_exp = '{"tag_type": "gating", "segment_names": ["f1", "f2", "f3"], "segment_sizes": [3, 5, 9], "group_index": [0, 0, 1]}' self.assertEqual(plugin_content, plugin_content_exp) for x, y in zip(tensor_datum.numpy, self.weight): self.assertAlmostEqual(x, y) def test_feature_insight_data(self): ctx = plugin_util.context({}) run = '.' tag = 'fi_train/monolith_feature_insight' tag_info = self.data_provider.list_tensors(ctx, experiment_id='0', plugin_name=PLUGIN_NAME, run_tag_filter=RunTagFilter(runs=[run], tags=[tag])) tensors = self.data_provider.read_tensors(ctx, experiment_id='0', plugin_name=PLUGIN_NAME, downsample=100, run_tag_filter=RunTagFilter(runs=[run], tags=[tag])) tensor_tts = tag_info.get(run, {}).get(tag, None) tensor_datum = tensors.get(run, {}).get(tag, None) self.assertTrue(tensor_datum is not None) if isinstance(tensor_datum, (list, tuple)): tensor_datum = tensor_datum[-1] plugin_content = str(tensor_tts.plugin_content, encoding='utf-8') label_size = json.loads(plugin_content)['label_size'] dim = 2 if label_size > 0 else 1 plugin_content_exp = '{"tag_type": "fi_train", "segment_names": ["f1", "f2", "f3"], "segment_sizes": [2, 2, 2], "group_index": [0, 1, 2], "label_size": 1}' self.assertEqual(plugin_content, plugin_content_exp) shape_exp = (3, 7 if label_size > 0 else 3) self.assertTupleEqual(tensor_datum.numpy.shape, 
shape_exp) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/native_training/summary/utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json from functools import lru_cache from typing import Any, Dict, List, Tuple, Union import tensorflow as tf from tensorboard.compat.proto import summary_pb2 PLUGIN_NAME = 'monolith' MONOLITH_NAS_DATA = f'{PLUGIN_NAME}_nas_weight' MONOLITH_FI_DATA = f'{PLUGIN_NAME}_feature_insight' KTYPE, KMETA, KDATA = 'tag_type', 'meta', 'data' class SummaryType(object): GATING = 'gating' SELECTING = 'selecting' MIXED = 'mixed' SIMPLE = 'simple' FEATURE_INSIGHT_DIRECT = 'fi_direct' FEATURE_INSIGHT_TRAIN = 'fi_train' # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto def create_summary_metadata(description: str = None, meta_content=b''): return summary_pb2.SummaryMetadata( summary_description=description, plugin_data=summary_pb2.SummaryMetadata.PluginData( plugin_name=PLUGIN_NAME, content=meta_content.encode('utf-8') if isinstance(meta_content, str) else meta_content, ), data_class=summary_pb2.DATA_CLASS_TENSOR, ) def _name_to_group_id(segment_names: List[str], group_info: List[List[str]]): if group_info: name_to_group: Dict[str, int] = {} for i, group in enumerate(group_info): for name in group: name_to_group[name] = i group_id_to_names: Dict[int, 
List[str]] = {} for name in segment_names: assert name in name_to_group group_id = name_to_group[name] if group_id in group_id_to_names: group_id_to_names[group_id].append(name) else: group_id_to_names[group_id] = [name] name_to_reorder_id = {} for idx, group_id in enumerate(sorted(group_id_to_names)): for name in group_id_to_names[group_id]: name_to_reorder_id[name] = idx name_to_reorder_id = name_to_reorder_id else: name_to_reorder_id = {name: idx for idx, name in enumerate(segment_names)} return name_to_reorder_id def prepare_head(segment_names: List[str], segment_sizes: Union[List[int], List[List[int]]], group_info: List[List[str]] = None, raw_tag: str = None, out_type: str = 'tensor' ) -> Tuple[Any, str]: assert out_type in {'bytes', 'tensor', 'json'} if not (segment_names or segment_sizes): if out_type == 'tensor': return tf.constant(value=[b''], dtype=tf.string, shape=tuple()), raw_tag else: return b'', raw_tag raw_tag = raw_tag or ( SummaryType.GATING if all(isinstance(s, int) for s in segment_sizes) else SummaryType.SELECTING) data = { KTYPE: raw_tag, 'segment_names': segment_names, 'segment_sizes': segment_sizes, } if raw_tag in {SummaryType.GATING, SummaryType.FEATURE_INSIGHT_TRAIN}: name_to_reorder_id = _name_to_group_id(segment_names, group_info) data['group_index'] = [name_to_reorder_id[name] for name in segment_names] if out_type == 'tensor': return tf.constant(value=[json.dumps(data)], dtype=tf.string, shape=tuple()), raw_tag elif out_type == 'json': return data, raw_tag else: return json.dumps(data), raw_tag @lru_cache def get_nas_weight_json(ckpt_dir_or_file, prefix=None) -> List[str]: prefix = prefix or ARCH_TENSOR_PREFIX ckpt = tf.train.load_checkpoint(ckpt_dir_or_file=ckpt_dir_or_file) if ckpt: for name in ckpt.get_variable_to_dtype_map(): if prefix in name: return [str(v) for v in ckpt.get_tensor(name).flat] raise Exception('not arch_weights in ckpt') ================================================ FILE: 
monolith/native_training/summary/utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest from monolith.native_training.summary.utils import \ prepare_head, SummaryType, get_nas_weight_json class UtilsTest(unittest.TestCase): def test_read_head_gating(self): segment_names = ['f1', 'f2', 'f3'] segment_sizes = [3, 5, 9] group_info = [['f1', 'f2'], ['f3', 'f4'], ['f5', 'f6']] data, nas_type = prepare_head(segment_names, segment_sizes, group_info) data_exp = b'{"tag_type": "gating", "segment_names": ["f1", "f2", "f3"], "segment_sizes": [3, 5, 9], "group_index": [0, 0, 1]}' self.assertEqual(nas_type, SummaryType.GATING) self.assertEqual(data.numpy(), data_exp) def test_read_head_selecting(self): segment_names = ['f1', 'f2', 'f3'] segment_sizes = [[3, 6], [5, 10], [4, 8, 16]] data, nas_type = prepare_head(segment_names, segment_sizes) data_exp = b'{"tag_type": "selecting", "segment_names": ["f1", "f2", "f3"], "segment_sizes": [[3, 6], [5, 10], [4, 8, 16]]}' self.assertEqual(nas_type, SummaryType.SELECTING) self.assertEqual(data.numpy(), data_exp) if __name__ == "__main__": unittest.main() ================================================ FILE: monolith/native_training/sync_hooks.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import time from absl import logging import tensorflow as tf class SyncHelper: #TODO(leqi.zou): maybe in the future, we want to support the dynamic number of workers. def __init__(self, num_workers: int, is_chief, var_device="/job:chief/task:0"): self._num_workers = num_workers with tf.name_scope("monolith_sync_helper"): # In distributed training, the var is local variable for chief, but global variable for worker. collections = [tf.compat.v1.GraphKeys.LOCAL_VARIABLES ] if is_chief else [tf.compat.v1.GraphKeys.VARIABLES] with tf.device(var_device): # For idx 0, represents the current restore status # For idx >0, represents if the current worker is alive self._var = tf.compat.v1.get_variable( "monolith_sync_helper/control_var", initializer=[False] * num_workers, dtype=tf.bool, trainable=False, collections=collections) self._idx_ph = tf.compat.v1.placeholder(tf.int32, shape=[], name="idx_ph") self._val_ph = tf.compat.v1.placeholder(tf.bool, shape=[], name="value_ph") self._read_value = self._var[self._idx_ph] self._assign_value = self._var[self._idx_ph].assign(self._val_ph) self._workers_status = self._var[1:] self._alive_workers = tf.where(self._workers_status) + 1 self._num_alive_workers = tf.math.reduce_sum( tf.cast(self._workers_status, tf.int32)) @property def num_workers(self): return self._num_workers def mark_restore_done(self, sess): sess.run(self._assign_value, feed_dict={ self._idx_ph: 0, self._val_ph: True }) def 
get_restore_status(self, sess): return sess.run(self._read_value, feed_dict={self._idx_ph: 0}) def start_worker(self, sess, idx: int): assert idx > 0 and idx < self._num_workers, f"Index {idx} is out range " sess.run(self._assign_value, feed_dict={ self._idx_ph: idx, self._val_ph: True }) def finish_worker(self, sess, idx: int): assert idx > 0 and idx < self._num_workers, f"Index {idx} is out range " sess.run(self._assign_value, feed_dict={ self._idx_ph: idx, self._val_ph: False }) def get_alive_workers(self, sess): return sess.run(self._alive_workers).flatten() def get_num_alive_workers(self, sess): return sess.run(self._num_alive_workers) _CHIEF_TIMEOUT_SECONDS = 1800 class ChiefSyncHook(tf.estimator.SessionRunHook): """ A hook that used for chief and worker sync at the beginning and at the end. """ def __init__(self, sync_helper: SyncHelper, timeout_seconds=_CHIEF_TIMEOUT_SECONDS): self._timeout_seconds = timeout_seconds self._helper = sync_helper def after_create_session(self, session, coord): self._helper.mark_restore_done(session) def end(self, session): start_time = time.time() while True: num_alive_workers = self._helper.get_num_alive_workers(session) if time.time( ) - start_time > self._timeout_seconds or num_alive_workers == 0: break logging.log_every_n_seconds( logging.INFO, "Total worker count: {}, remaining count: {}.\nRemaining workers: %s". format(self._helper.num_workers, num_alive_workers), 60, self._helper.get_alive_workers(session)) time.sleep(1) if num_alive_workers > 0: logging.info("Reach timeout seconds! Remaining worker count: {}.".format( num_alive_workers)) else: logging.info("All other workers had been finished.") class WorkerSyncHook(tf.estimator.SessionRunHook): """ A hook that used for chief and worker sync at the beginning and at the end. 
""" def __init__(self, worker_index, sync_helper: SyncHelper): self._worker_index = worker_index self._helper = sync_helper def after_create_session(self, session, coord): if self._worker_index > 0: self._helper.start_worker(session, self._worker_index) while not self._helper.get_restore_status(session): logging.log_every_n_seconds( logging.INFO, "The worker {} waits for start signal of chief.".format( self._worker_index), 60) time.sleep(1) def end(self, session): if self._worker_index > 0: self._helper.finish_worker(session, self._worker_index) class TrainingHooksHelper: def __init__(self, enable_sync: bool, num_workers: int, worker_idx: int, chief_timeout_seconds: int = _CHIEF_TIMEOUT_SECONDS): self._enable_sync = enable_sync self._training_chief_hooks = [] self._training_hooks = [] if self._enable_sync: sync_helper = SyncHelper(num_workers, worker_idx == 0) self._training_chief_hooks.append( ChiefSyncHook(sync_helper, timeout_seconds=chief_timeout_seconds)) self._training_hooks.append(WorkerSyncHook(worker_idx, sync_helper)) @property def training_chief_hooks(self): return tuple(self._training_chief_hooks) @property def training_hooks(self): return tuple(self._training_hooks) ================================================ FILE: monolith/native_training/sync_hooks_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import threading
import time

import tensorflow as tf

from monolith.native_training import sync_hooks


class CountHook(tf.estimator.SessionRunHook):
  """Counts how often after_create_session and end were invoked."""

  def __init__(self):
    self.after_create_session_count = 0
    self.end_count = 0

  def after_create_session(self, session, coord):
    self.after_create_session_count += 1

  def end(self, session):
    self.end_count += 1

  def get_counts(self):
    return {
        'after_create_session': self.after_create_session_count,
        'end': self.end_count
    }


def get_local_helper(num_workers):
  # is_chief=True registers the control variable as a local variable, which
  # is why the test runs local_variables_initializer below.
  return sync_hooks.SyncHelper(num_workers, is_chief=True, var_device=None)


class SyncHooksTest(tf.test.TestCase):

  def _after_create_session(self, sess, hooks):
    # Runs the hooks in order; a blocking hook delays all hooks after it.
    for hook in hooks:
      hook.after_create_session(sess, None)

  def _end(self, sess, hooks):
    for hook in hooks:
      hook.end(sess)

  def test_sync_process(self):
    """Worker blocks until the chief starts; chief blocks until worker ends."""
    with tf.compat.v1.Graph().as_default():
      helper = get_local_helper(2)
      chief_hook = sync_hooks.ChiefSyncHook(helper)
      worker_hook = sync_hooks.WorkerSyncHook(1, helper)
      worker_count_hook = CountHook()
      chief_count_hook = CountHook()
      with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.local_variables_initializer())
        # The count hook runs after worker_hook, so its counter only moves
        # once worker_hook.after_create_session has returned.
        worker = threading.Thread(target=self._after_create_session,
                                  args=(sess, [worker_hook, worker_count_hook]))
        worker.daemon = True
        worker.start()
        time.sleep(1)
        # Worker hook is pending at 'after_create_session'.
        self.assertEqual({
            'after_create_session': 0,
            'end': 0,
        }, worker_count_hook.get_counts())
        chief_hook.after_create_session(sess, None)
        worker.join()
        self.assertEqual({
            'after_create_session': 1,
            'end': 0,
        }, worker_count_hook.get_counts())

        worker_hook.after_create_session(sess, None)
        chief = threading.Thread(target=self._end,
                                 args=(sess, [chief_hook, chief_count_hook]))
        chief.daemon = True
        chief.start()
        # Chief hook is pending at 'end'.
        self.assertEqual({
            'after_create_session': 0,
            'end': 0,
        }, chief_count_hook.get_counts())
        # Make sure logging is covered
        time.sleep(1)
        worker_hook.end(sess)
        chief.join()
        self.assertEqual({
            'after_create_session': 0,
            'end': 1,
        }, chief_count_hook.get_counts())

  def test_hook_helper(self):
    """Sync disabled yields empty hook tuples; enabled must construct."""
    h = sync_hooks.TrainingHooksHelper(False, 0, 0)
    self.assertEqual(h.training_hooks, ())
    self.assertEqual(h.training_chief_hooks, ())
    h = sync_hooks.TrainingHooksHelper(True, 1, 0)
    # This only for grammar check
    h.training_hooks
    h.training_chief_hooks


if __name__ == '__main__':
  tf.compat.v1.disable_eager_execution()
  tf.test.main()

================================================
FILE: monolith/native_training/sync_training_hooks.py
================================================
# Copyright 2022 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding:utf-8 from cProfile import run import os import uuid import tempfile import time from datetime import datetime from absl import logging import grpc import numpy as np import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.training import training_util from tensorflow.python.training import session_run_hook from tensorflow_serving.apis import predict_pb2, get_model_metadata_pb2 from tensorflow_serving.apis import prediction_service_pb2_grpc from monolith.agent_service import backends try: from monolith.coordinator.utils import token_utils except ImportError: pass from monolith.native_training.data import datasets from monolith.native_training import distributed_serving_ops from monolith.native_training import hvd_lib from monolith.native_training import native_task from monolith.native_training import hash_table_ops from monolith.native_training.distributed_serving_ops import ParameterSyncClient, refresh_sync_config from monolith.utils import find_main from monolith.native_training.model_export import export_context class SyncTrainingBarrierSaverListener(tf.estimator.CheckpointSaverListener): def begin(self): self._barrier_op = None self._barrier_var = tf.compat.v1.placeholder(dtype=tf.int64, shape=[], name="hvd_export_barrier_ph") self._barrier_op = hvd_lib.broadcast(tf.identity(self._barrier_var), 0) def after_save(self, session, global_step_value): logging.info(f"exporter barrier begin {hvd_lib.rank()}") try: barrier_val = session.run( self._barrier_op, feed_dict={self._barrier_var: global_step_value}) logging.info( f"exporter barrier end {hvd_lib.rank()} value: {barrier_val}") except Exception as ex: logging.error(f"barrier error: {ex}") class ParameterSyncHook(session_run_hook.SessionRunHook): """ sync parameter sync to online ps """ def __init__(self, sync_backend, ps_index, refresh_interval=100): self._sync_backend = sync_backend self._ps_index = ps_index self._refresh_interval = refresh_interval 
self._last_sync_time = 0 self._last_refresh_time = 0 self._sync_config = None logging.info( f"sync hook for ps_{self._ps_index} with refresh_interval={self._refresh_interval}" ) def begin(self): self._config_ph = tf.compat.v1.placeholder(tf.string, shape=(), name="sync_config_str") sync_client = ParameterSyncClient( distributed_serving_ops.parameter_sync_client_from_config( name_suffix=str(self._ps_index))) self._sync_run_step = sync_client.create_sync_op(self._config_ph) def before_run(self, run_context): cur_time = time.time() if cur_time - self._last_refresh_time >= self._refresh_interval: self._sync_config = refresh_sync_config(self._sync_backend, self._ps_index) self._last_refresh_time = cur_time return session_run_hook.SessionRunArgs( fetches=self._sync_run_step, feed_dict={self._config_ph: self._sync_config}) class SyncTrainingForceDumpHook(tf.estimator.SessionRunHook): def __init__(self, model_dir, target_timer, step_interval=100): self._model_dir = model_dir self._target_timer = target_timer self._step_interval = step_interval def begin(self): self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access self._ctrl_ph = tf.compat.v1.placeholder(tf.int16, shape=(3,), name='hvd_dump_ctrl') self._broadcast_op = hvd_lib.broadcast(self._ctrl_ph, 0) def after_run(self, run_context, run_values): global_step = run_context.session.run(self._global_step_tensor) if global_step % self._step_interval == 0: utc_hour = datetime.utcnow().hour should_dump, should_stop, timer_enabled = 0, 0, 0 if hvd_lib.rank() == 0: timer_enabled = int(utc_hour >= 18 and utc_hour <= 20) logging.info(f"utc_hour: {utc_hour} time_enabled: {timer_enabled}") dump_path = os.path.join(self._model_dir, f"dump_{global_step}") stop_path = os.path.join(self._model_dir, f"stop_{global_step}") should_stop = int(tf.io.gfile.exists(stop_path)) logging.info(f"checked stop {stop_path} {should_stop}") should_dump = int(tf.io.gfile.exists(dump_path)) 
logging.info(f"checked dump {dump_path} {should_dump}") try: should_stop, should_dump, timer_enabled = run_context.session.run( self._broadcast_op, feed_dict={ self._ctrl_ph: [should_stop, should_dump, timer_enabled] }, options=tf.compat.v1.RunOptions(timeout_in_ms=1000 * 10)) except (RuntimeError, TypeError, ValueError, tf.errors.OpError) as ex: logging.error('Error occurred in syncing control flags: %s', str(ex)) if timer_enabled: logging.info(f"enable timer with utc_hour: {utc_hour}") self._target_timer.enable() else: logging.info(f"disable timer with utc_hour: {utc_hour}") self._target_timer.disable() if should_dump or should_stop: logging.info(f"reset and enable timer for dump at step {global_step}") self._target_timer.enable() self._target_timer.reset() if should_stop: logging.info(f"request stop at step {global_step}") run_context.request_stop() class SyncTrainingSaverControlHook(tf.estimator.SessionRunHook): def __init__(self, model_dir, target_timer, step_interval=100): self._model_dir = model_dir self._target_timer = target_timer self._step_interval = step_interval def begin(self): self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access def after_run(self, run_context, run_values): global_step = run_context.session.run(self._global_step_tensor) if global_step % self._step_interval == 0: check_path = os.path.join(self._model_dir, "ONLINE") if tf.io.gfile.exists(check_path): logging.info(f"{check_path} exists, enable timer") self._target_timer.enable() else: logging.info(f"{check_path} not exists, disable timer") self._target_timer.disable() class SyncTrainingInfoHook(tf.estimator.SessionRunHook): def begin(self): self._last_timestamp = 0 self._fetches = {} for table in ops.get_collection(hash_table_ops._HASH_TABLE_GRAPH_KEY): tensor_prefix = hash_table_ops._table_tensor_prefix(table) self._fetches[tensor_prefix] = table.size() def before_run(self, run_context): cur_time = int(time.time()) if cur_time > 
class EofAwareTask:
  """A NativeTask like object that helps stop training before the eof was raised.

  The wrapper forwards every attribute it does not define itself to the wrapped
  task (see `__getattr__`). When `use_dataservice` is true, `create_input_fn`
  and `create_model_fn` return decorated versions of the wrapped task's
  functions; otherwise the wrapped task's functions are returned untouched.
  """
  # Name of the graph collection that `create_model_fn` reads the eof tensor
  # from; also the key under which the eof flag is merged into dict features.
  EOF_KEY = "__EofAwareTask_eof"

  def __init__(self,
               task: native_task.NativeTask,
               use_dataservice: bool = False):
    """Wraps `task`.

    Args:
      task: the task whose input_fn / model_fn are (optionally) decorated.
      use_dataservice: when False, both factory methods are pure pass-throughs.
    """
    self._ori_task = task
    self.use_dataservice = use_dataservice
    logging.info(f'init EofAwareTask')

  def create_input_fn(self, mode):
    """Returns the task's input_fn, wrapped when `use_dataservice` is set.

    The wrapped input_fn passes the dataset through `datasets.CacheOneDataset`
    and maps every element into a dict that carries the eof flag. During dry
    run / export the original dataset is returned unchanged.
    """
    input_fn = self._ori_task.create_input_fn(mode)

    def new_input_fn_factory(input_fn):

      def new_input_fn():
        ds = input_fn()
        if export_context.is_dry_run_or_exporting():
          return ds
        # NOTE(review): map_fn below receives (features, eof), so
        # CacheOneDataset presumably yields such pairs - confirm against its
        # implementation.
        ds = datasets.CacheOneDataset(ds)

        # There are 2 reasons why we need a map here:
        # 1. A tuple would be treated as (features, label) by the estimator,
        #    which is wrong.
        # 2. In sync training, reorder_fids_in_data_pipeline should be able to
        #    get the original data after we wrap the input_fn output.
        def map_fn(features, eof):
          if isinstance(features, dict):
            # Dict features: merge the eof flag in under EOF_KEY.
            logging.info(f"in map_fn: {EofAwareTask.EOF_KEY}")
            return {**features, EofAwareTask.EOF_KEY: eof}
          # Non-dict features: fall back to the fixed keys "1" and "2".
          logging.info('map_fn keys: 1, 2')
          return {"1": features, "2": eof}

        return ds.map(map_fn)

      return new_input_fn

    if self.use_dataservice:
      return new_input_fn_factory(input_fn)
    else:
      return input_fn

  def create_model_fn(self):
    """Returns the task's model_fn, wrapped when `use_dataservice` is set.

    The wrapped model_fn calls the original one and, if the EOF_KEY graph
    collection is non-empty, prepends an `EofHook` to the spec's training hooks.
    """
    model_fn = self._ori_task.create_model_fn()

    def new_model_fn_factory(model_fn):
      # Checked once, when the factory runs: no wrapping at all while
      # exporting / dry running.
      if export_context.is_dry_run_or_exporting():
        return model_fn

      def new_model_fn(features, mode, config):
        spec: tf.estimator.EstimatorSpec = model_fn(features, mode, config)
        # NOTE(review): nothing in this class adds to the collection;
        # presumably the data pipeline registers the dequeued eof tensor -
        # verify against the producer.
        dequeued_eof = tf.compat.v1.get_collection(EofAwareTask.EOF_KEY)
        if dequeued_eof and not export_context.is_dry_run_or_exporting():
          # Only the first registered entry is used.
          if isinstance(dequeued_eof, (list, tuple)):
            dequeued_eof = dequeued_eof[0]
          assert(isinstance(dequeued_eof, tf.Tensor))
          training_hooks = spec.training_hooks or ()
          # The eof hook goes first, ahead of the hooks the model_fn supplied.
          training_hooks = [self.EofHook(dequeued_eof)] + list(training_hooks)
          spec = spec._replace(training_hooks=training_hooks)
          return spec
        else:
          return spec

      return new_model_fn

    if self.use_dataservice:
      return new_model_fn_factory(model_fn)
    else:
      return model_fn

  def __getattr__(self, name):
    # Only invoked for attributes not found on the wrapper itself; everything
    # else is delegated to the wrapped task.
    return getattr(self._ori_task, name)

  class EofHook(tf.estimator.SessionRunHook):
    """Requests a stop once the aggregated eof flag becomes non-zero."""

    def __init__(self, eof_tensor):
      # Cast the flag to int32 and give it shape [1] so it can be gathered and
      # summed. NOTE(review): hvd_lib.allgather presumably collects the flag
      # from every Horovod rank - confirm.
      eof_tensor_for_gather = tf.reshape(tf.cast(eof_tensor, dtype=tf.int32),
                                         [1],
                                         name="eof_tensor_for_all_gather")
      eof_tensors = hvd_lib.allgather(eof_tensor_for_gather)
      # The sum is non-zero as soon as any gathered flag is set.
      self._agg_eof = tf.math.reduce_sum(eof_tensors)

    def before_run(self, run_context):
      # Fetch the aggregated flag alongside every training step.
      return tf.estimator.SessionRunArgs(fetches=self._agg_eof)

    def after_run(self, run_context, run_values):
      # `results` is the fetched sum; any non-zero value stops training.
      if run_values.results:
        logging.info(
            f'rank {hvd_lib.rank()} request_stop, results is {run_values.results}, before'
        )
        run_context.request_stop()
        logging.info(
            f'rank {hvd_lib.rank()} request_stop, results is {run_values.results}, after'
        )
class EofAwareTaskTest(tf.test.TestCase):
  """Trains a tiny estimator through `EofAwareTask` and checks global_step."""

  def test_basic(self):
    """Plain tensor features: the model adds each element to global_step."""

    class TestTask(native_task.NativeTask):

      def create_input_fn(self, mode):

        def input_fn():
          values = tf.constant([1, 2, 3], dtype=tf.int64)
          return tf.data.Dataset.from_tensor_slices(values)

        return input_fn

      def create_model_fn(self):

        def model_fn(features, mode, config):
          global_step = tf.compat.v1.train.get_or_create_global_step()
          increment = tf.compat.v1.assign_add(global_step, features)
          return tf.estimator.EstimatorSpec(mode,
                                            train_op=increment,
                                            loss=tf.constant(0.0))

        return model_fn

    hvd_lib.init()
    task_params = TestTask.params()
    task_params.name = "test"
    wrapped_task = sync_training_hooks.EofAwareTask(TestTask(task_params))
    estimator = tf.estimator.Estimator(wrapped_task.create_model_fn())
    train_input_fn = wrapped_task.create_input_fn(tf.estimator.ModeKeys.TRAIN)
    estimator.train(train_input_fn)
    # 1 + 2 + 3 were added to the global step.
    self.assertEqual(estimator.get_variable_value("global_step"), 6)

  def test_dict(self):
    """Dict features: the same check with the value stored under key "1"."""

    class TestTask(native_task.NativeTask):

      def create_input_fn(self, mode):

        def input_fn():
          values = tf.constant([1, 2, 3], dtype=tf.int64)
          dataset = tf.data.Dataset.from_tensor_slices(values)
          return dataset.map(lambda x: {"1": x})

        return input_fn

      def create_model_fn(self):

        def model_fn(features, mode, config):
          global_step = tf.compat.v1.train.get_or_create_global_step()
          increment = tf.compat.v1.assign_add(global_step, features["1"])
          return tf.estimator.EstimatorSpec(mode,
                                            train_op=increment,
                                            loss=tf.constant(0.0))

        return model_fn

    hvd_lib.init()
    task_params = TestTask.params()
    task_params.name = "test"
    wrapped_task = sync_training_hooks.EofAwareTask(TestTask(task_params))
    estimator = tf.estimator.Estimator(wrapped_task.create_model_fn())
    train_input_fn = wrapped_task.create_input_fn(tf.estimator.ModeKeys.TRAIN)
    estimator.train(train_input_fn)
    self.assertEqual(estimator.get_variable_value("global_step"), 6)
def maybe_squeeze_3d_tensor(x: tf.RaggedTensor):
  """Returns `x` as a rank-2 ragged tensor.

  Accepted inputs:
    * rank 2, e.g. [None/batch_size, None]: returned as is;
    * rank 3, e.g. [None/batch_size, 1, None]: axis 1 is squeezed away.

  Raises:
    ValueError: if `x` is not a `tf.RaggedTensor`, or its rank is neither 2
      nor 3.
  """
  if not isinstance(x, tf.RaggedTensor):
    raise ValueError("input must be RaggedTensor")

  rank = len(x.shape)
  if rank == 2:
    return x
  if rank == 3:
    return tf.squeeze(x, axis=1)
  raise ValueError("Unknown shape of RaggedTensor. ", x)
def unpack_packed_tensors(
    list_keyed_shape: List[Dict[str, List[int]]],
    packed_list: List[tf.Tensor]) -> List[Dict[str, tf.Tensor]]:
  """Splits packed tensors back into per-group dicts of reshaped tensors.

  Args:
    list_keyed_shape: one dict per group, mapping each key to the shape its
      tensor is reshaped to.
    packed_list: `packed_list[i]` is the flat tensor of group `i`; the last
      element is a bookkeeping tensor whose first `len(packed_list) - 1`
      entries are used as the number of sizes per group and whose remaining
      entries are the individual tensor sizes.

  Returns:
    A list parallel to `list_keyed_shape`; each dict maps the group's keys to
    the corresponding unpacked tensor.

  Raises:
    ValueError: if `packed_list` has fewer than two elements.
  """
  length = len(packed_list)
  # At least one data tensor plus the trailing bookkeeping tensor is required.
  if (length < 2):
    raise ValueError("Wrong packed_list length")
  concat_offset_size = packed_list[-1]
  # Head of the bookkeeping tensor: how many size entries belong to each group.
  packed_size_size = tf.slice(concat_offset_size, [0], [length - 1])
  # Tail of the bookkeeping tensor: every individual tensor size, then cut into
  # one chunk of sizes per group.
  packed_size = tf.slice(concat_offset_size, [length - 1], [-1])
  packed_size = tf.split(packed_size,
                         packed_size_size,
                         num=len(list_keyed_shape))
  builder = StaticReshapeNBuilder()
  unpack_list = []
  for i, d in enumerate(list_keyed_shape):
    d_size = packed_size[i]
    # Cut the group's flat tensor into one piece per key.
    d_tensor = tf.split(packed_list[i], d_size, num=len(list_keyed_shape[i]))
    # Keys are visited in sorted order; the second pass below relies on the
    # builder outputs following this exact order.
    for j, key in enumerate(sorted(d.keys())):
      builder.add(d_tensor[j], list_keyed_shape[i][key])
  outputs, _ = builder.build()
  # Walk groups/keys in the same order as above to pair outputs with keys.
  idx = 0
  for d in list_keyed_shape:
    unpack_d = {}
    for key in sorted(d.keys()):
      unpack_d[key] = outputs[idx]
      idx += 1
    unpack_list.append(unpack_d)
  return unpack_list
class TensorUtilsTest(tf.test.TestCase):
  """Round-trip tests for the helpers in `tensor_utils`."""

  def test_maybe_squeeze_3d_tensor(self):
    """Rank-2 input and its rank-3 variant both come out as [[0, 1], [2]]."""
    ragged_2d = tf.ragged.constant([[0, 1], [2]])
    ragged_3d = tf.RaggedTensor.from_uniform_row_length(ragged_2d, 1)
    squeezed_2d = tensor_utils.maybe_squeeze_3d_tensor(ragged_2d)
    squeezed_3d = tensor_utils.maybe_squeeze_3d_tensor(ragged_3d)
    with self.session() as sess:
      squeezed_values = sess.run([squeezed_2d, squeezed_3d])
    for value in squeezed_values:
      self.assertAllEqual(value, [[0, 1], [2]])

  def test_pack_tensors(self):
    """pack_tensors flattens in key order; unpack_tensors restores shapes."""
    x = tf.constant([1, 2], dtype=tf.int64)
    y = tf.constant([[4, 5], [6, 7]], dtype=tf.int64)
    tensors = {"x": x, "y": y}
    packed = tensor_utils.pack_tensors(tensors)
    keyed_shape = tensor_utils.get_keyed_shape(tensors)
    restored = tensor_utils.unpack_tensors(keyed_shape, packed)
    with self.session() as sess:
      packed_value = sess.run(packed)
      self.assertAllEqual(packed_value[0], [1, 2, 4, 5, 6, 7])
      self.assertAllEqual(packed_value[1], [2, 4])
      restored_value = sess.run(restored)
      expected_value = sess.run(tensors)
      for key in sorted(restored_value):
        self.assertAllEqual(restored_value[key], expected_value[key])

  def test_pack_typed_keyed_tensors(self):
    """Typed packing matches per-group pack_tensors and round-trips."""
    t1 = tf.constant([[0, 0, 1, 0], [0, 2, 0, 3]], dtype=tf.int64)
    t2 = tf.constant([0, 3, 1, 2], dtype=tf.int64)
    t3 = tf.constant([9.1, 2.2], dtype=tf.float32)
    t4 = tf.constant([[1.1, 2.2], [3.3, 4.4]], dtype=tf.float32)
    t5 = tf.constant([3, 4, 5, 6, 7, 8, 9], dtype=tf.float64)
    int_group = {"t1": t1, "t2": t2}
    float_group = {"t4": t4, "t3": t3}
    double_group = {"t5": t5}
    groups = [int_group, float_group, double_group]

    packed_groups = tensor_utils.pack_typed_keyed_tensors(groups)
    restored_groups = tensor_utils.unpack_packed_tensors(
        tensor_utils.get_typed_keyed_shape(groups), packed_groups)
    packed_individually = [tensor_utils.pack_tensors(group) for group in groups]

    with self.session() as sess:
      packed_groups_value = sess.run(packed_groups)
      individual_values = [sess.run(packed) for packed in packed_individually]
      # Each group's flat tensor equals what pack_tensors yields for it alone.
      for position, individual in enumerate(individual_values):
        self.assertAllEqual(packed_groups_value[position], individual[0])
      # Trailing bookkeeping tensor: keys per group, then every tensor size.
      self.assertAllEqual(packed_groups_value[3], [2, 2, 1, 8, 4, 2, 4, 7])

      restored_value = sess.run(restored_groups)
      expected_value = sess.run(groups)
      for position, restored in enumerate(restored_value):
        for key in sorted(restored):
          self.assertAllEqual(restored[key], expected_value[position][key])

  def test_pack_typed_keyed_tensors_with_placeholder(self):
    """Same round trip, with placeholders whose leading dimension is unknown."""
    t1 = tf.compat.v1.placeholder(tf.int32,
                                  shape=(None, 3, 4),
                                  name="t1_placeholder")
    t2 = tf.compat.v1.placeholder(tf.int32,
                                  shape=(None, 2),
                                  name="t2_place_holder")
    t3 = tf.compat.v1.placeholder(tf.float32,
                                  shape=(None,),
                                  name="t3_placeholder")
    t4 = tf.compat.v1.placeholder(tf.float32,
                                  shape=(None, 2, 2),
                                  name="t4_placeholder")
    groups = [{"t1": t1, "t2": t2}, {"t3": t3, "t4": t4}]

    # Concrete data that is fed into the placeholders.
    t5 = tf.reshape(tf.constant(list(range(1, 25)), dtype=tf.int32), [-1, 3, 4])
    t6 = tf.reshape(tf.constant([9, 8, 7, 6]), [-1, 2])
    t7 = tf.constant([9, 12, 15], dtype=tf.float32)
    t8 = tf.reshape(
        tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype=tf.float32), [-1, 2, 2])

    packed_groups = tensor_utils.pack_typed_keyed_tensors(groups)
    restored_groups = tensor_utils.unpack_packed_tensors(
        tensor_utils.get_typed_keyed_shape(groups), packed_groups)

    with self.session() as sess:
      feed = {
          t1: sess.run(t5),
          t2: sess.run(t6),
          t3: sess.run(t7),
          t4: sess.run(t8),
      }
      restored_value = sess.run(restored_groups, feed_dict=feed)
      expected_value = sess.run(groups, feed_dict=feed)
      for restored, expected in zip(restored_value, expected_value):
        for key in sorted(restored):
          self.assertAllEqual(restored[key], expected[key])

  def test_split_tensors_with_type_and_merge_dicts(self):
    """Splitting by dtype groups the tensors; merging restores the dict."""
    t1 = tf.constant([[0, 0, 1, 0], [0, 2, 0, 3]], dtype=tf.int64)
    t2 = tf.constant([0, 3, 1, 2], dtype=tf.int64)
    t3 = tf.constant([9.1, 2.2], dtype=tf.float32)
    t4 = tf.constant([[1.1, 2.2], [3.3, 4.4]], dtype=tf.float32)
    t5 = tf.constant([3, 4, 5, 6, 7, 8, 9], dtype=tf.float64)
    int_group = {"t1": t1, "t2": t2}
    float_group = {"t4": t4, "t3": t3}
    double_group = {"t5": t5}
    # Expected group order of the split result: float32, float64, int64.
    expected_groups = [float_group, double_group, int_group]
    everything = {"t1": t1, "t2": t2, "t4": t4, "t3": t3, "t5": t5}

    split_groups = tensor_utils.split_tensors_with_type(everything)
    merged = tensor_utils.merge_dicts(split_groups)

    with self.session() as sess:
      split_value = sess.run(split_groups)
      expected_groups_value = sess.run(expected_groups)
      for position, group in enumerate(split_value):
        for key in sorted(group):
          self.assertAllEqual(group[key], expected_groups_value[position][key])

      merged_value = sess.run(merged)
      everything_value = sess.run(everything)
      for key in sorted(merged_value):
        self.assertAllEqual(everything_value[key], merged_value[key])
def profile_it(fn):
  """Decorator for testcase to profile locally.

  The wrapped callable runs with the TensorFlow profiler recording into
  "/tmp/tests_profile". The profiler is stopped even if `fn` raises, so a
  failing test does not leave it running for the next one.

  Args:
    fn: the callable (typically a test method) to profile.

  Returns:
    A wrapper with the same signature and metadata as `fn`; it returns whatever
    `fn` returns and propagates whatever `fn` raises.
  """
  # Imported here because this module does not import them at file level.
  import functools
  import time

  @functools.wraps(fn)
  def wrapped_fn(*args, **kwargs):
    options = tf.profiler.experimental.ProfilerOptions(host_tracer_level=2,
                                                       python_tracer_level=1,
                                                       device_tracer_level=1)
    tf.profiler.experimental.start("/tmp/tests_profile", options)
    try:
      return fn(*args, **kwargs)
    finally:
      tf.profiler.experimental.stop()
      # ensure distinct profile dir names defined by timestamp on sec
      time.sleep(1)

  return wrapped_fn
class TouchedKeySet(object):
  """Python handle around a `MonolithTouchedKeySet` resource."""

  def __init__(self,
               capacity: int = TOUCHED_KEY_SET_CAPACITY,
               concurrency_level: int = TOUCHED_KEY_SET_CONCURRENCY_LEVEL,
               name_suffix: str = ""):
    """Creates the underlying set.

    Args:
      capacity: forwarded to the op as `capacity`.
      concurrency_level: forwarded to the op as `concurrency_level`.
      name_suffix: appended to the resource's shared name. Sets created with
        different suffixes therefore get different shared names.
    """
    # name_suffix must be forwarded; otherwise every instance ends up with the
    # same shared_name no matter what the caller asked for.
    self._set = create_touched_key_set(capacity, concurrency_level,
                                       name_suffix)
    self._capacity = capacity
    self._concurrency_level = concurrency_level

  def insert(self, ids: tf.Tensor) -> int:
    """Inserts `ids` into the set and returns the insert op's output.

    NOTE(review): annotated as `int`, but the value comes straight from a
    generated op, so it is presumably a tensor that evaluates to the count.
    """
    return touched_key_set_ops.monolith_touched_key_set_insert(self._set, ids)

  def steal(self) -> int:
    """Returns the steal op's output for this set.

    NOTE(review): annotated as `int`; presumably a tensor holding the ids taken
    out of the set - confirm against the op definition.
    """
    return touched_key_set_ops.monolith_touched_key_set_steal(self._set)

  @property
  def capacity(self) -> int:
    """The capacity passed at construction time."""
    return self._capacity

  @property
  def concurrency_level(self) -> int:
    """The concurrency level passed at construction time."""
    return self._concurrency_level

  @property
  def handle(self) -> tf.Tensor:
    """The resource handle produced by `create_touched_key_set`."""
    return self._set
class TouchedKeySetOpsTest(tf.test.TestCase):
  """Insert/steal tests for `TouchedKeySet`."""

  def test_touched_key_set_basic(self):
    """Inserting exactly `capacity` ids drops nothing and steals them all."""
    key_set = TouchedKeySet(1000, 1)
    ids = tf.constant(list(range(1000)), dtype=tf.int64)
    dropped = key_set.insert(ids)
    # Steal only after the insert has run.
    with tf.control_dependencies([dropped]):
      stolen = key_set.steal()

    with self.session() as sess:
      ids_value, dropped_value, stolen_value = sess.run([ids, dropped, stolen])

    self.assertEqual(0, dropped_value)
    self.assertAllEqual(ids_value, sorted(stolen_value))

  def test_touched_key_set_overflow(self):
    """Inserting more ids than `capacity` reports drops; the tail remains."""
    key_set = TouchedKeySet(1000, 1)
    ids = tf.constant(list(range(1005)), dtype=tf.int64)
    dropped = key_set.insert(ids)
    with tf.control_dependencies([dropped]):
      stolen = key_set.steal()

    with self.session() as sess:
      ids_value, dropped_value, stolen_value = sess.run([ids, dropped, stolen])

    self.assertEqual(1001, dropped_value)
    self.assertAllEqual([1001, 1002, 1003, 1004], sorted(stolen_value))
def ps_device(index: int) -> str:
  """Returns the CPU:0 device string of task `index` in the PS job."""
  device_template = "/job:{}/task:{}/device:CPU:0"
  return device_template.format(PS_JOB_NAME, index)
""" combined_vars = [] combined_grads = [] for grad, var in grads_and_vars: if valid_var_set and (not var in valid_var_set): raise RuntimeError("Invalid variables in the input", var, valid_var_set) combined_vars.append(var) combined_grads.append(grad) return tf.gradients(combined_vars, list(xs), combined_grads) def propagate_back_dict_gradients( grads_and_vars: Iterable[Tuple[tf.Tensor, tf.Tensor]], x_to_key: Dict[tf.Tensor, Any], valid_var_set: Set[tf.Tensor] = None ) -> Dict[Any, List[Tuple[tf.Tensor, tf.Tensor]]]: """ Similar to above. But xs is replaced by x_to_key, and the returned gradients will be grouped by key. """ dxs = propagate_back_gradients(grads_and_vars, x_to_key.keys(), valid_var_set) grouped = defaultdict(list) for dx, (x, key) in zip(dxs, x_to_key.items()): grouped[key].append((dx, x)) return grouped def get_ndim(x: tf.Tensor): dims = x.get_shape()._dims if dims is not None: return len(dims) return None def int_shape(x): try: shapes = [] for dim in x.get_shape().as_list(): if dim is None: shapes.append(-1) elif isinstance(dim, int): shapes.append(dim) elif isinstance(dim, tf.compat.v1.Dimension): shapes.append(dim.value) else: raise ValueError(f'dim {dim} is error') return tuple(shapes) except ValueError: return None def extend_as_list(x, n): """This is a helper function to extend x as list, it will do: 1. If x is a list, padding it to specified length n with None, if the length is less than n; 2. If x is not a list, create a list with n elements x, please note that, these n elements are the same object, not a copy of x. 
""" if isinstance(x, (list, tuple)): if len(x) < n: return x + [None] * (n - len(x)) else: return x else: try: return [x if i == 0 else deepcopy(x) for i in range(n)] except: return [x] * n def check_list(candidate, length_checker, could_be_none=False): """Checks whether a list has valid length Args: length_checker: a callable object takes a single integer return T/F on whether the candidate in the range or not could_be_none: None type is acceptable Returns: candidate Raises: TypeError ValueError """ if not could_be_none and candidate is None: raise TypeError('ListChecker cannot accept None candidate') if type(candidate) not in [type(None), list]: raise TypeError('ListChecker got candidate ' 'in the wrong type[{}]'.format(type(candidate))) if candidate is not None and not length_checker(len(candidate)): raise ValueError('ListChecker got candidate beyonds the range') return candidate def to_snake_case(name): intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() # If the class is private the name starts with "_" which is not secure # for creating scopes. We prefix the name with "private" in this case. if insecure[0] != '_': return insecure return 'private' + insecure def to_list(x): """Normalizes a list/tensor into a list. If a tensor is passed, we return a list of size 1 containing the tensor. # Arguments x: target object to be normalized. # Returns A list. 
""" if isinstance(x, list): return x return [x] def _get_parameters(cls, parameters): for p in signature(cls.__init__).parameters.values(): if p.name in {'self', 'cls'} or \ p.kind in {Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL}: continue else: parameters[p.name] = p def _get_all_parameters(cls, parameters): if cls is not object: for base in cls.__bases__: _get_all_parameters(base, parameters) _get_parameters(cls, parameters) def _inverted_index(ips: InstantiableParams, idx_dict): for name, item in ips.iter_params(): if isinstance(item, (InstantiableParams, Params)): _inverted_index(item, idx_dict) else: idx_dict[name] = ips def params(cls): """Returns the layer params.""" ips = None for base in cls.__mro__: if base is cls: continue if hasattr(base, 'params'): ips = base.params() ips.cls = cls break ips = ips or InstantiableParams(cls) parameters = {} _get_all_parameters(cls, parameters) reversed_dict = {} _inverted_index(ips, reversed_dict) try: ips.define('name', get_uname(cls.__name__), "name") except: pass for p in parameters.values(): if p.name in {'cls', 'self'}: continue if p.name in reversed_dict: _ips = reversed_dict[p.name] if p.default != Parameter.empty: _ips[p.name] = p.default else: try: ips.define(p.name, None if p.default == Parameter.empty else p.default, p.name) except: if p.default != Parameter.empty and p.default != None: ips[p.name] = p.default for kw in allowed_kwargs: try: ips.define(kw, None, kw) except: pass return ips def check_ops_dependence(op_names_1, op_names_2): """Check whether op_names_1 depend on op_names_2. Raises: Exception: If op_names_1 depend on op_names_2. 
""" op_names_1 = to_list(op_names_1) graph_def = tf.compat.v1.get_default_graph().as_graph_def() sub_graph_1 = tf.compat.v1.graph_util.extract_sub_graph(graph_def, op_names_1) op_names_2 = set(to_list(op_names_2)) depended_op_names = [ node.name for node in sub_graph_1.node if node.name in op_names_2 ] if depended_op_names: raise Exception( "Checking ops dependence, the ops [%s] depend on ops [%s], which may cause ops [%s] to be run twice." % (",".join(op_names_1), ",".join(depended_op_names), ",".join(depended_op_names))) def with_params(cls): cls.params = types.MethodType(params, cls) return cls def get_local_host(): if platform.system() in ("Windows", "Linux"): local_host = socket.gethostbyname(socket.gethostname()) else: local_host = socket.gethostbyname(socket.gethostname() + ".local") return local_host def get_test_tmp_dir(): return os.environ.get("TEST_TMPDIR", "/tmp") def get_debugging_info_file_name(model_dir: str): return os.path.join(model_dir, "debugging_info.pb") def get_meta_graph_file_name(model_dir: str): return os.path.join(model_dir, "meta_graph_for_debugging.pb") def add_to_collections(names, value): if isinstance(value, (bool, int, float, str)): tf.compat.v1.add_to_collections(names, value) elif value: tf.compat.v1.add_to_collections(names, value) else: logging.info(f'value is {value}, skip') def get_collection(name): collection = tf.compat.v1.get_collection(name) if isinstance(collection, (bool, int, float, str)): return collection elif collection: return collection else: return None def set_metric_prefix(prefix: str): os.environ["MONOLITH_METRIC_PREFIX"] = prefix def get_metric_prefix(): return os.environ.get("MONOLITH_METRIC_PREFIX", "monolith.training") ================================================ FILE: monolith/native_training/utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. 
class UtilsTest(tf.test.TestCase):
  """Tests for the graph helpers in `utils`."""

  def test_propagate_back_dict_gradients(self):
    """Gradients w.r.t. y are pushed back to x and grouped by key."""
    x = tf.Variable(8.0)
    y = 2 * x
    # Use a grad related to x
    grad_y = 3 * y
    allowed_vars = set([y])
    grads_and_vars = zip([grad_y], [y])
    grouped = utils.propagate_back_dict_gradients(grads_and_vars,
                                                  {x: "group1"}, allowed_vars)
    with self.session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      group1_value = sess.run(grouped["group1"])
    self.assertAllEqual(group1_value, [(96, 8)])

  def test_check_ops_dependence(self):
    """A dependency on the listed op raises; an unrelated op passes."""
    counter = tf.Variable(0)
    add = counter.assign_add(1)
    with tf.control_dependencies([add]):
      dependent = tf.constant(0)
    unrelated = tf.constant(0)
    with self.assertRaises(Exception):
      utils.check_ops_dependence(dependent.op.name, add.name)
    # OK to check
    utils.check_ops_dependence(dependent.op.name, unrelated.op.name)

  def test_collections(self):
    """Falsy non-scalar values are skipped; the last real value wins."""
    additions = [
        ('int', 1),
        ('int', 2),
        ('str', 'str'),
        ('str', None),
        ('bool', True),
        ('int_list', [1, 2, 3]),
        ('str_list', None),
        ('bool_list', []),
        ('int_list', [4, 5, 6]),
        ('str_list', ['hello', 'world']),
        ('bool_list', [False]),
    ]
    for collection_name, value in additions:
      utils.add_to_collections(collection_name, value)

    self.assertTrue(utils.get_collection('int')[-1] == 2)
    self.assertTrue(utils.get_collection('str')[-1] == 'str')
    self.assertTrue(utils.get_collection('bool')[-1])
    self.assertListEqual(utils.get_collection('int_list')[-1], [4, 5, 6])
    self.assertListEqual(
        utils.get_collection('str_list')[-1], ['hello', 'world'])
    self.assertListEqual(utils.get_collection('bool_list')[-1], [False])
# Graph collection holding every variable that got a local cache attached.
_CACHED_VARIABLES = "monolith_cached_variables"


@dataclasses.dataclass
class CachedVariableAssociates:
  """The two local helper variables attached to one cached variable."""
  # Receives the value read from the source variable (see
  # fetch_all_cached_variables).
  async_fetched_var: tf.Variable
  # Value served to readers of the source variable; refreshed from
  # async_fetched_var (see assign_all_cached_variables).
  async_cached_var: tf.Variable


@dataclasses.dataclass
class CachedVariableMeta:
  """Per-graph registry mapping id(variable) to its helper variables."""
  var_id_to_assoc: Dict[int, CachedVariableAssociates] = dataclasses.field(
      default_factory=dict)


def _get_meta() -> CachedVariableMeta:
  """Returns the CachedVariableMeta stored under "cached_variables_meta"."""
  return graph_meta.get_meta("cached_variables_meta", CachedVariableMeta)


@tf.custom_gradient
def cached_value(var, async_cached_var):
  """Yields `async_cached_var` forward; routes the gradient to `var` only."""

  def _backward(upstream):
    # The upstream gradient goes to `var` untouched; the cache gets none.
    return upstream, None

  return async_cached_var, _backward


def _get_valid_op_name(name: str):
  """Replaces every ':' and '/' in `name` with '_'."""
  return name.translate(str.maketrans(":/", "__"))


def cached_variable_creator(next_creator, **kwargs):
  """Variable creator that attaches a locally cached value to the variable.

  Args:
    next_creator: the next creator in the variable creator chain.
    **kwargs: forwarded unchanged to `next_creator`.

  Returns:
    The variable built by `next_creator`. Unless the helper variables end up
    on the same device as the variable, its `_cached_value` is replaced by a
    tensor backed by a local variable.

  Raises:
    ValueError: the created variable is not a ResourceVariable, or it already
      carries a cached value.
  """
  var = next_creator(**kwargs)
  if not isinstance(var, resource_variable_ops.ResourceVariable):
    raise ValueError("Only ResourceVariable is supported. "
                     "Do you disable V2 behavior or use strategy?")
  if var._cached_value is not None:
    raise ValueError("The variable has already been cached. "
                     "Consider about removing cache_device.")

  def _new_local_mirror():
    # Non-trainable LOCAL_VARIABLES copy with the same initial value, shape
    # and dtype as `var`.
    return resource_variable_ops.ResourceVariable(
        initial_value=var.initial_value,
        trainable=False,
        collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES],
        shape=var.shape,
        dtype=var.dtype)

  with tf.device(None):
    async_cached_var = _new_local_mirror()
    async_fetched_var = _new_local_mirror()

  if async_cached_var.device == var.device:
    # In this case, we shouldn't do the cache since we try assign vars
    # on the remote machines.
    #
    # This is common when cached_var is forced to colocate with var.
    # For example, var is optimizer's slot variables.
    return var

  tf.compat.v1.add_to_collection(_CACHED_VARIABLES, var)
  var._cached_value = cached_value(var, async_cached_var)
  _get_meta().var_id_to_assoc[id(var)] = CachedVariableAssociates(
      async_fetched_var=async_fetched_var, async_cached_var=async_cached_var)
  return var


def fetch_all_cached_variables():
  """Returns one op copying each cached variable into its fetched mirror."""
  assoc_by_id = _get_meta().var_id_to_assoc
  fetch_ops = [
      assoc_by_id[id(var)].async_fetched_var.assign(
          var._read_variable_op(),
          name="fetch_from_{}".format(_get_valid_op_name(str(var.device))),
          read_value=False)
      for var in tf.compat.v1.get_collection(_CACHED_VARIABLES)
  ]
  return tf.group(fetch_ops)


def assign_all_cached_variables():
  """Returns one op copying each fetched mirror into its cached mirror."""
  assoc_by_id = _get_meta().var_id_to_assoc
  assign_ops = []
  for var in tf.compat.v1.get_collection(_CACHED_VARIABLES):
    pair = assoc_by_id[id(var)]
    assign_ops.append(
        pair.async_cached_var.assign(pair.async_fetched_var,
                                     name="assign_cached_var",
                                     read_value=False))
  return tf.group(assign_ops, name="assign_all_cached_variables")


class FetchAllCachedVariablesHook(tf.estimator.SessionRunHook):
  """Refreshes cached variables around every `session.run` call.

  The fetch op is attached to each run and the assign op runs afterwards, so
  a cached value lags the source variable by at most one step.
  """

  def __init__(self):
    # Both ops are built once, at hook construction time.
    self._fetch_op = fetch_all_cached_variables()
    self._assign_op = assign_all_cached_variables()
    self._first_run = True

  def after_create_session(self, session, coord):
    # Any newly created session needs the synchronous warm-up again.
    self._first_run = True

  def before_run(self, run_context: tf.estimator.SessionRunContext):
    if self._first_run:
      # For the first run, we do a sync fetch since the local values might be
      # super stale.
      session = run_context.session
      session.run(self._fetch_op)
      session.run(self._assign_op)
      self._first_run = False
    return tf.estimator.SessionRunArgs(self._fetch_op)

  def after_run(self, run_context, run_values):
    # Publish what was fetched alongside the step that just finished.
    run_context.session.run(self._assign_op)
class CachedVariableTest(tf.test.TestCase):
  """Tests for `variables.cached_variable_creator` on a 2-task test PS cluster."""

  def testBasic(self):
    """A cached read stays stale until fetch + assign have both been run."""
    servers, config = test_utils.create_test_ps_cluster(2)
    with tf.compat.v1.Session(target=servers[0].target, config=config) as sess:
      with tf.variable_creator_scope(variables.cached_variable_creator):
        with tf.device("/job:ps/task:1"):
          var = tf.Variable(5.0)
      sess.run([
          tf.compat.v1.global_variables_initializer(),
          tf.compat.v1.local_variables_initializer()
      ])
      # We use var * 1.0 since direct run var will use var.ref()
      # which is original value of var.
      self.assertAllEqual(5.0, self.evaluate(var * 1.0))
      update_op = var.assign_add(2.0)
      sess.run(update_op)
      # update op won't take effect until fetch happened.
      self.assertAllEqual(5.0, self.evaluate(var * 1.0))
      # But the original value should be updated.
      self.assertAllEqual(7.0, self.evaluate(var))
      sess.run(variables.fetch_all_cached_variables())
      sess.run(variables.assign_all_cached_variables())
      # update takes effect.
      self.assertAllEqual(7.0, self.evaluate(var * 1.0))

  def testHook(self):
    """With the hook installed, an update becomes visible within two runs."""
    servers, config = test_utils.create_test_ps_cluster(2)
    with tf.variable_creator_scope(variables.cached_variable_creator):
      with tf.device("/job:ps/task:1"):
        var = tf.Variable(5.0)
    # NOTE(review): the ops below are assumed to be built outside the device
    # scope -- confirm against the original file layout.
    var_cached = var * 1.0
    sub_op = tf.compat.v1.assign_sub(var, 1.0)
    with tf.compat.v1.train.SingularMonitoredSession(
        master=servers[0].target,
        config=config,
        hooks=[variables.FetchAllCachedVariablesHook()]) as sess:
      var_cached_value = sess.run(var_cached)
      self.assertAllEqual(5.0, var_cached_value)
      sess.run(sub_op)
      # At most twice, local var will be finally updated.
      var_cached_value = sess.run(var_cached)
      var_cached_value = sess.run(var_cached)
      self.assertAllEqual(4.0, var_cached_value)

  def testGradient(self):
    """The optimizer updates the source variable; the cached read stays old."""
    servers, config = test_utils.create_test_ps_cluster(2)
    with tf.variable_creator_scope(variables.cached_variable_creator):
      with tf.device("/job:ps/task:1"):
        var = tf.Variable(5.0)
    loss = var
    opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    op = opt.minimize(loss)
    with tf.compat.v1.Session(target=servers[0].target, config=config) as sess:
      sess.run([
          tf.compat.v1.global_variables_initializer(),
          tf.compat.v1.local_variables_initializer()
      ])
      sess.run(op)
      # One gradient-descent step with learning rate 1.0 on loss = var
      # lowers the source variable from 5.0 to 4.0.
      self.assertAllEqual(4.0, sess.run(var))
      # The result should not be fetched yet.
      self.assertAllEqual(5.0, sess.run(var * 1.0))
def get_local_host():
  """Returns the address of the local host.

  Resolution order:
    1. the first entry of the comma-separated CLOUDNATIVE_INET_ADDR env var;
    2. the YARN_INET_ADDR env var;
    3. net_utils.get_local_ip().

  Raises:
    AssertionError: if the resolved address is empty.
  """
  if "CLOUDNATIVE_INET_ADDR" in os.environ:
    local_host = os.environ["CLOUDNATIVE_INET_ADDR"].split(",")[0]
  elif "YARN_INET_ADDR" in os.environ:
    local_host = os.environ["YARN_INET_ADDR"]
  else:
    local_host = net_utils.get_local_ip()
  assert local_host
  return local_host


def _get_primus_am_host():
  """Returns "<host>:<port>" of the Primus AM, or "" if it is not configured.

  Both PRIMUS_AM_RPC_HOST and PRIMUS_AM_RPC_PORT must be present in the
  environment.
  """
  if "PRIMUS_AM_RPC_HOST" in os.environ and "PRIMUS_AM_RPC_PORT" in os.environ:
    return os.environ["PRIMUS_AM_RPC_HOST"] + ":" + os.environ[
        "PRIMUS_AM_RPC_PORT"]
  return ""


# Insecure gRPC channels keyed by target address, reused across calls.
_CHANNEL_MAP = {}


def _get_channel(addr: str) -> "grpc.Channel":
  """Returns the cached insecure channel for `addr`, creating it on first use."""
  if addr not in _CHANNEL_MAP:
    _CHANNEL_MAP[addr] = grpc.insecure_channel(addr)
  return _CHANNEL_MAP[addr]


def _get_primus_am_stub():
  """Returns an AppMasterService stub, or None when no AM is configured."""
  am_host = _get_primus_am_host()
  if not am_host:
    return None
  return primus_am_service_pb2_grpc.AppMasterServiceStub(_get_channel(am_host))


def maybe_kill_application(reason: str) -> bool:
  """Send a request to AM to kill application.

  Args:
    reason: diagnostic message attached to the kill request.

  Returns:
    True if the AM accepted the request; False if the RPC failed or no AM is
    configured.
  """
  stub = _get_primus_am_stub()
  if stub is None:
    logging.info("Current framework doesn't support kill. Ignore killing...")
    return False
  req = primus_am_service_pb2.KillRequest()
  req.exit_code = 1
  req.diagnose = reason
  req.graceful_shutdown_timeout_ms.value = 20000
  try:
    stub.kill(req, timeout=10)
  except grpc.RpcError as e:
    logging.info("Failed to kill application: %s", e)
    return False
  logging.info("Successfully killed application.")
  return True


def maybe_finish_application() -> bool:
  """Asks the AM to mark the application as succeeded.

  Returns:
    True if the AM accepted the request; False if the RPC failed or no AM is
    configured (always a bool, consistent with maybe_kill_application).
  """
  stub = _get_primus_am_stub()
  if stub is None:
    return False
  req = primus_am_service_pb2.SucceedRequest()
  req.graceful_shutdown_timeout_ms.value = 20000
  try:
    stub.succeed(req, timeout=10)
  except grpc.RpcError as e:
    logging.info("Failed to finish application: %s", e)
    return False
  logging.info("Successfully mark the application success.")
  return True


def create_primus_save_point(dst) -> bool:
  """Asks the AM to create a savepoint under `dst` and waits for the outcome.

  The savepoint state is polled every 5 seconds for as long as the AM reports
  PENDING or RUNNING; there is no overall deadline.

  Args:
    dst: directory passed to the AM as `savepoint_dir`.

  Returns:
    True if the savepoint reached SUCCEEDED; False if the AM rejected the
    request, the savepoint ended in any other state, an RPC failed, or no AM
    is configured.
  """
  stub = _get_primus_am_stub()
  if stub is None:
    return False
  state_enum = (
      primus_am_service_pb2.CreateSavepointStatusResponse.CreateSavepointState)
  create_req = primus_am_service_pb2.CreateSavepointRequest()
  create_req.savepoint_dir = dst
  try:
    create_resp = stub.createSavepoint(create_req, timeout=10)
    if create_resp.code != 0:
      logging.error("Failed to create primus save point: %s",
                    create_resp.message)
      return False
    status_req = primus_am_service_pb2.CreateSavepointStatusRequest()
    status_req.savepoint_restore_id = create_resp.savepoint_id
    while True:
      status_resp = stub.createSavepointStatus(status_req, timeout=10)
      state = status_resp.create_savepoint_state
      if state in (state_enum.PENDING, state_enum.RUNNING):
        time.sleep(5)
        continue
      if state == state_enum.SUCCEEDED:
        logging.info("Create primus save point succeeded.")
        return True
      logging.error("Failed to create primus save point: %s",
                    status_resp.message)
      return False
  except grpc.RpcError as e:
    logging.info("Failed to create primus save point: %s", e)
    return False
class YarnRuntimeTest(unittest.TestCase):
  """Tests yarn_runtime against an in-process fake Primus AM gRPC server."""

  def _start_am_server(self, servicer, addr):
    """Serves `servicer` on `addr` until the end of the current test."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    primus_am_service_pb2_grpc.add_AppMasterServiceServicer_to_server(
        servicer, server)
    server.add_insecure_port(addr)
    server.start()
    # Registered right after start so the server is stopped even when a later
    # assertion in the test fails.
    self.addCleanup(server.stop, True)
    return server

  @mock.patch.dict(os.environ, {"YARN_INET_ADDR": "1.2.3.4"})
  def test_get_local_host_overwrite(self):
    self.assertEqual(yarn_runtime.get_local_host(), "1.2.3.4")

  @mock.patch.dict(os.environ, {"CLOUDNATIVE_INET_ADDR": "1.2.3.4,5.6.7.8"})
  def test_get_local_host_overwrite_by_cloudnative(self):
    # Only the first entry of the comma-separated list is used.
    self.assertEqual(yarn_runtime.get_local_host(), "1.2.3.4")

  def test_get_local_host_basic(self):
    # Smoke test: must not raise in the default environment.
    yarn_runtime.get_local_host()

  @mock.patch.dict(os.environ, {
      "PRIMUS_AM_RPC_HOST": "unix",
      "PRIMUS_AM_RPC_PORT": "test_kill"
  })
  def test_kill(self):
    servicer = primus_am_service_pb2_grpc.AppMasterServiceServicer()
    reason = "TestKill"
    received_reasons = []

    def kill(servicer_self, request, context):
      # Runs on a gRPC worker thread; only record here and assert on the test
      # thread, where a failure is actually reported.
      received_reasons.append(request.diagnose)
      return primus_am_service_pb2.KillResponse()

    servicer.kill = types.MethodType(kill, servicer)
    self._start_am_server(servicer, "unix:test_kill")

    self.assertTrue(yarn_runtime.maybe_kill_application(reason))
    self.assertEqual(received_reasons, [reason])

  @mock.patch.dict(os.environ, {
      "PRIMUS_AM_RPC_HOST": "unix",
      "PRIMUS_AM_RPC_PORT": "test_succeed"
  })
  def test_finish(self):
    servicer = primus_am_service_pb2_grpc.AppMasterServiceServicer()
    called = False

    def succeed(servicer_self, request, context):
      nonlocal called
      called = True
      return primus_am_service_pb2.SucceedResponse()

    servicer.succeed = types.MethodType(succeed, servicer)
    self._start_am_server(servicer, "unix:test_succeed")

    self.assertTrue(yarn_runtime.maybe_finish_application())
    self.assertTrue(called)

  @mock.patch.dict(os.environ, {
      "PRIMUS_AM_RPC_HOST": "unix",
      "PRIMUS_AM_RPC_PORT": "test_save_primus"
  })
  def test_save_primus(self):
    servicer = primus_am_service_pb2_grpc.AppMasterServiceServicer()
    requested_dirs = []
    status_called = False
    dst = "test"

    def createSavepoint(servicer_self, request, context):
      requested_dirs.append(request.savepoint_dir)
      resp = primus_am_service_pb2.CreateSavepointResponse()
      resp.savepoint_id = "123"
      return resp

    def createSavepointStatus(servicer_self, request, context):
      nonlocal status_called
      status_called = True
      resp = primus_am_service_pb2.CreateSavepointStatusResponse()
      # Report success immediately so the client does not enter its
      # sleep-and-poll loop.
      resp.create_savepoint_state = primus_am_service_pb2.CreateSavepointStatusResponse.CreateSavepointState.SUCCEEDED
      return resp

    servicer.createSavepoint = types.MethodType(createSavepoint, servicer)
    servicer.createSavepointStatus = types.MethodType(createSavepointStatus,
                                                      servicer)
    self._start_am_server(servicer, "unix:test_save_primus")

    self.assertTrue(yarn_runtime.create_primus_save_point(dst))
    self.assertEqual(requested_dirs, [dst])
    self.assertTrue(status_called)
_PORT = 2181
_HOSTS = ['10.226.91.73', '10.226.86.70', '10.224.126.131', '10.224.109.135']
_HOSTS_IPV6 = [
    'fdbd:dc02:ff:2:1:226:91:73', 'fdbd:dc02:ff:2:1:226:86:70',
    'fdbd:dc01:ff:1:1:224:126:131', 'fdbd:dc01:ff:1:1:224:109:135'
]


def is_ipv6_only():
  """Returns True when no IPv4 address can be determined for this host.

  If any of MY_HOST_IP / MY_POD_IP / MY_HOST_IPV6 is set, the IPv4 address is
  read from MY_HOST_IP, falling back to MY_POD_IP. Otherwise it is resolved
  from the local hostname; a failed resolution counts as "no IPv4".
  """
  if "MY_HOST_IP" in os.environ or "MY_POD_IP" in os.environ or "MY_HOST_IPV6" in os.environ:
    # in tce/byterec environment
    ipv4_addr = os.environ.get("MY_HOST_IP", os.environ.get("MY_POD_IP", None))
    logging.info(f"in tce env, ipv4 address is {ipv4_addr}")
  else:
    try:
      ipv4_addr = socket.gethostbyname(socket.gethostname())
    except Exception:
      # Resolution failure means no usable IPv4 address. `Exception` (not a
      # bare except) keeps KeyboardInterrupt/SystemExit propagating.
      ipv4_addr = None
    logging.info(f"not in tce env, ipv4 address is {ipv4_addr}")
  ipv6_only = not ipv4_addr
  logging.info(f"is_ipv6_only is {ipv6_only}")
  return ipv6_only


# These rebindings replace the address lists defined at the top of the module,
# so default_zk_servers() yields an empty string unless they are filled in.
_HOSTS = []
_HOSTS_IPV6 = []


def default_zk_servers(use_ipv6: bool = False):
  """Returns the default ZK servers as a comma-separated "host:port" string.

  Args:
    use_ipv6: force the IPv6 host list (formatted "[ip]:port"). The IPv6 list
      is also used when is_ipv6_only() is true.
  """
  if use_ipv6 or is_ipv6_only():
    return ','.join(
        ['[{ip}]:{port}'.format(ip=ip, port=_PORT) for ip in _HOSTS_IPV6])
  return ','.join(['{ip}:{port}'.format(ip=ip, port=_PORT) for ip in _HOSTS])


class MonolithKazooClient(KazooClient):
  """KazooClient that defaults `auth_data` to get_zk_auth_data()."""

  def __init__(self, *args, **kwargs):
    # An explicitly supplied auth_data always wins.
    if "auth_data" not in kwargs:
      kwargs["auth_data"] = get_zk_auth_data()
    super().__init__(*args, **kwargs)


def clear_zk_path(zk_server: str, job_name: str, force_clear_zk_path: bool):
  """Try to clear old path (no modification since 9 weeks ago), Clear ZK Path of current job.

  Args:
    zk_server: ZK connection string; default_zk_servers() is used when empty.
    job_name: name of the current job; its node is /monolith/<job_name>.
    force_clear_zk_path: whether an existing job node may be deleted.

  Raises:
    ValueError: the job node exists and `force_clear_zk_path` is False.
  """
  zk_client = MonolithKazooClient(zk_server or default_zk_servers())
  base_path = '/monolith'
  delta = timedelta(weeks=9)  # two months
  try:
    zk_client.start()

    # 1) try to delete very old nodes, just like TTL
    zk_client.ensure_path(base_path)
    expire_before = datetime.now() - delta
    for child in zk_client.get_children(base_path):
      path = '{}/{}'.format(base_path, child)
      try:
        # stat.mtime is divided by 1000 before conversion (presumably
        # milliseconds since epoch -- verify against kazoo's ZnodeStat).
        _, stat = zk_client.get_children(path, include_data=True)
        if datetime.fromtimestamp(stat.mtime // 1000) < expire_before:
          zk_client.delete(path, recursive=True)
      except Exception as e:
        # Another process may be sweeping the same nodes concurrently, so
        # both the stat lookup and the delete are best effort.
        logging.warning("Skip clearing stale zk path %s: %s", path, e)

    # 2) try to delete job_name
    job_path = '{}/{}'.format(base_path, job_name)
    if zk_client.exists(job_path) is not None:
      if force_clear_zk_path:
        zk_client.delete(job_path, recursive=True)
      else:
        children = zk_client.get_children(base_path)
        raise ValueError('there are [{}] in monolith zk path'.format(
            ','.join(children)))
  finally:
    zk_client.stop()
    # Release the resources still held by the stopped client.
    zk_client.close()
def find_main():
  """Locate the root directory of the monolith codebase.

  The absolute path of this file is searched for the markers '/__main__/',
  '/site-packages/' and '/monolith/', in that order; the first marker present
  decides the root (cut at its last occurrence). For '/monolith/' the root is
  the part before the marker; for the other two the marker directory itself
  is part of the root.

  Returns:
    The root directory, which is guaranteed to contain a 'monolith' entry.

  Raises:
    ValueError: no marker matched, or the candidate root has no 'monolith'
      entry.
  """
  this_file = os.path.abspath(__file__)
  base_dir = None
  for marker in ('/__main__/', '/site-packages/', '/monolith/'):
    if marker not in this_file:
      continue
    prefix = this_file[:this_file.rfind(marker)]
    if marker == '/monolith/':
      base_dir = prefix
    else:
      base_dir = os.path.join(prefix, marker.strip('/'))
    break

  if base_dir is None or not os.path.exists(os.path.join(base_dir, 'monolith')):
    raise ValueError(
        "Unable to find the monolith base directory. This file directory is {}. Are you running under bazel structure?"
        .format(this_file))
  return base_dir


def get_libops_path(lib_name):
  """Return `lib_name` joined onto the monolith root directory."""
  return os.path.join(find_main(), lib_name)
def tf_serving_workspace():
    """All TensorFlow Serving external dependencies.

    Intended to be called from a WORKSPACE file. It first sets up TensorFlow's
    own workspace and then binds/declares the additional repositories below.
    """

    tf_workspace(path_prefix = "", tf_repo_name = "org_tensorflow")

    # ===== gRPC dependencies =====
    native.bind(
        name = "libssl",
        actual = "@boringssl//:ssl",
    )

    # gRPC requires a `cares` dependency to exist, but its contents are not
    # actually important since we have set GRPC_ARES=0 in tools/bazel.rc
    native.bind(
        name = "cares",
        actual = "@grpc//third_party/nanopb:nanopb",
    )

    # ===== RapidJSON (rapidjson.org) dependencies =====
    http_archive(
        name = "com_github_tencent_rapidjson",
        urls = [
            "https://github.com/Tencent/rapidjson/archive/v1.1.0.zip",
        ],
        sha256 = "8e00c38829d6785a2dfb951bb87c6974fa07dfe488aa5b25deec4b8bc0f6a3ab",
        strip_prefix = "rapidjson-1.1.0",
        build_file = "@org_tensorflow_serving//third_party/rapidjson:BUILD",
    )

    # ===== libevent (libevent.org) dependencies =====
    http_archive(
        name = "com_github_libevent_libevent",
        urls = [
            "https://github.com/libevent/libevent/archive/release-2.1.8-stable.zip",
        ],
        sha256 = "70158101eab7ed44fd9cc34e7f247b3cae91a8e4490745d9d6eb7edc184e4d96",
        strip_prefix = "libevent-release-2.1.8-stable",
        build_file = "@org_tensorflow_serving//third_party/libevent:BUILD",
    )

    # ===== Override the 'ICU' defined by TF & TF Text (we need a version that contains all data).
    http_archive(
        name = "icu",
        strip_prefix = "icu-release-64-2",
        sha256 = "dfc62618aa4bd3ca14a3df548cd65fe393155edd213e49c39f3a30ccd618fc27",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/unicode-org/icu/archive/release-64-2.zip",
            "https://github.com/unicode-org/icu/archive/release-64-2.zip",
        ],
        build_file = "@org_tensorflow_serving//third_party/icu:BUILD",
        patches = ["@org_tensorflow_serving//third_party/icu:data.patch"],
        patch_args = ["-p1", "-s"],
    )

    # ===== Pin `com_google_absl` to the same version (and patch) as TensorFlow.
    tf_http_archive(
        name = "com_google_absl",
        build_file = str(Label("@org_tensorflow//third_party:com_google_absl.BUILD")),
        # TODO: Remove the patch when https://github.com/abseil/abseil-cpp/issues/326 is resolved
        # and when TensorFlow is built against CUDA 10.2
        patch_file = str(Label("@org_tensorflow//third_party:com_google_absl_fix_mac_and_nvcc_build.patch")),
        sha256 = "f368a8476f4e2e0eccf8a7318b98dafbe30b2600f4e3cf52636e5eb145aba06a",  # SHARED_ABSL_SHA
        strip_prefix = "abseil-cpp-df3ea785d8c30a9503321a3d35ee7d35808f190d",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz",
            "https://github.com/abseil/abseil-cpp/archive/df3ea785d8c30a9503321a3d35ee7d35808f190d.tar.gz",
        ],
    )

    # ===== TF.Text dependencies
    # NOTE: Before updating this version, you must update the test model
    # and double check all custom ops have a test:
    # https://github.com/tensorflow/text/blob/master/oss_scripts/model_server/save_models.py
    http_archive(
        name = "org_tensorflow_text",
        sha256 = "05cc1b0eda8f4f734cb81d4389a637d26372b8621cb4c4a7e30ee5bc1e8c63da",
        strip_prefix = "text-2.3.0",
        urls = [
            "https://github.com/tensorflow/text/archive/v2.3.0.zip",
        ],
        patches = ["@org_tensorflow_serving//third_party/tf_text:tftext.patch"],
        patch_args = ["-p1"],
        # TF.Text refers to RE2 as @com_google_re2; map it onto the
        # @com_googlesource_code_re2 repository name.
        repo_mapping = {"@com_google_re2": "@com_googlesource_code_re2"},
    )

    # SentencePiece, pinned to release 1.0.0.
    http_archive(
        name = "com_google_sentencepiece",
        strip_prefix = "sentencepiece-1.0.0",
        sha256 = "c05901f30a1d0ed64cbcf40eba08e48894e1b0e985777217b7c9036cac631346",
        urls = [
            "https://github.com/google/sentencepiece/archive/1.0.0.zip",
        ],
    )

    # glog, pinned to commit 028d37889a1e80e8a07da1b8945ac706259e5fd8.
    http_archive(
        name = "com_google_glog",
        sha256 = "1ee310e5d0a19b9d584a855000434bb724aa744745d5b8ab1855c85bff8a8e21",
        strip_prefix = "glog-028d37889a1e80e8a07da1b8945ac706259e5fd8",
        urls = [
            "https://mirror.bazel.build/github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz",
            "https://github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz",
        ],
    )

    # ==== We need to patch eigen_archive to make it compilable with gcc6.
    # Unlike the absl entry above, the patch label here points into this
    # repository's own //third_party package.
    tf_http_archive(
        name = "eigen_archive",
        build_file = str(Label("@org_tensorflow//third_party:eigen.BUILD")),
        patch_file = "//third_party:eigen3/eigen_gcc6.patch",
        sha256 = "e807a6a6f3a0e8ab10adeb59bb5a9bbb113e8e1684f9b4b32f73f58fd758b4cf",  # SHARED_EIGEN_SHA
        strip_prefix = "eigen-011e0db31d1bed8b7f73662be6d57d9f30fa457a",
        urls = [
            "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/011e0db31d1bed8b7f73662be6d57d9f30fa457a/eigen-011e0db31d1bed8b7f73662be6d57d9f30fa457a.tar.gz",
            "https://gitlab.com/libeigen/eigen/-/archive/011e0db31d1bed8b7f73662be6d57d9f30fa457a/eigen-011e0db31d1bed8b7f73662be6d57d9f30fa457a.tar.gz",
        ],
    )
FLAGS = flags.FLAGS

# ---- Cloud TPU / job selection flags.
flags.DEFINE_string(name="tf_version",
                    default="nightly",
                    help="TensorFlow version")
flags.DEFINE_string(
    name="tpu",
    default=None,
    help="The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")
flags.DEFINE_string(
    name="gcp_project",
    default=None,
    help="Project name for the Cloud TPU-enabled project. If not specified, "
    "we will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_string(
    name="tpu_zone",
    default=None,
    help="GCE zone where the Cloud TPU is located in. If not specified, we "
    "will attempt to automatically detect the zone from metadata.")
flags.DEFINE_string(name="task",
                    default=None,
                    help="Name of the task class to run.")
flags.DEFINE_string(
    name="model_dir",
    default=None,
    help="The directory where the model and summaries are stored.")
flags.DEFINE_enum(name="mode",
                  default="train",
                  enum_values=["train_and_eval", "train", "eval"],
                  help="Job mode.")
flags.DEFINE_integer(
    name="save_checkpoints_steps",
    default=None,
    help=
    "Save checkpoint every save_checkpoints_steps. If None, no checkpoint saved."
)
flags.DEFINE_integer(name="iterations_per_loop",
                     default=10000,
                     help="This is the number of train steps running "
                     "in TPU system before returning to CPU host .")

# ---- TPU embedding flags.
flags.DEFINE_bool(
    name="pipeline_execution",
    default=False,
    help="If True, speed up training by overlaping embedding lookups with "
    "dense layer computations. Embedding lookups will be one step old.")
flags.DEFINE_bool(name="enable_tpu_version_config",
                  default=True,
                  help="Whether enable tpu configuration or not.")
flags.DEFINE_integer(name="host_call_every_n_steps",
                     default=500,
                     help="Host call every n steps.")

# Controls end-of-stream handling together with automatic checkpointing.
# False: end of stream is not handled. True: end of stream is handled and a
# checkpoint is saved before the training job ends.
flags.DEFINE_bool(
    name="enable_stopping_signals",
    default=False,
    help="Whether enable stopping signals and auto checkpointing.")

# Only set this to True for simple tests on CPU. The internal embedding update
# logic is not implemented in this mode yet, so do not use it for actual
# training.
flags.DEFINE_bool(name="cpu_test",
                  default=False,
                  help="Wheter use CPU in TPU estimator.")

# Allowed values are "div" and "mod"; "div" is the default partition_strategy
# of the embedding API. We use "mod" because it runs faster than "div" for our
# id distribution, especially with incrementally generated data: such data
# tends to have its ids concentrated in a few small ranges of the vocab table
# rather than spread randomly over the whole table, so "div" keeps the few
# cores owning those ranges busier. "mod" distributes ids more evenly across
# cores.
flags.DEFINE_string(name="partition_strategy",
                    default="mod",
                    help="Partition strategy of embedding table.")

# A non-empty value overrides end_date.
flags.DEFINE_string(name="overwrite_end_date",
                    default="",
                    help="End date of input data.")
self._task_param.train.end_date = FLAGS.overwrite_end_date logging.info( "Use flag end_date {} to replace parameter train.end_date.".format( self._task_param.train.end_date)) self._mode = FLAGS.mode self._task = None def _experimental_gradient_multiplier_fn(self, global_step): return self._task_param.gradient_multiplier def _create_params(self, total_replicas): # TODO(youlong.cheng): this is a little bit Adhoc solution, consider # abstract HostCall class with hyper_parameter. params = { "model_dir": self._model_dir, "enable_host_call": self._host_call_every_n_steps > 0, "num_replicas": total_replicas, "accelerator": self._task_param.accelerator, "host_call_every_n_steps": self._host_call_every_n_steps, "enable_stopping_signals": self._enable_stopping_signals, "cpu_test": self._cpu_test, } logging.info("params: {}".format(params)) return params def create_tpu_estimator(self, model_fn, feature_config, table_config): """Creates the TPU Estimator, with accelerated lookups for embedding tables.""" if self._enable_tpu_version_config == True: logging.info( "Enable tpu version config, reset remote tpu version with {}".format( tf.__version__)) # This is to let the cloud TPU always restart in case last round operation is still not finished. tpu_client = client.Client(tpu=self._tpu, zone=self._tpu_zone, project=self._gcp_project) tpu_client.configure_tpu_version(version=tf.__version__) tpu_client.wait_for_healthy() tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( self._tpu, zone=self._tpu_zone, project=self._gcp_project) num_hosts = tpu_cluster_resolver.cluster_spec().num_tasks("worker") total_replicas = self._num_replicas_per_host * num_hosts train_global_batch_size = total_replicas * self._task_param.train.per_replica_batch_size logging.info( "num_hosts: {} total_replicas: {} train_global_batch_size: {}".format( num_hosts, total_replicas, train_global_batch_size)) # experimental_host_call_every_n_steps can't be 0. 
If _host_call_every_n_steps is not specified, # then experimental_host_call_every_n_steps will use 100. if self._host_call_every_n_steps == 0: _experimental_host_call_every_n_steps = 100 else: _experimental_host_call_every_n_steps = self._host_call_every_n_steps if self._enable_stopping_signals is True: experimental_feed_hook = TPUInfeedOutfeedSessionWithEndOfStreamHandlingHook else: experimental_feed_hook = None config = tf.compat.v1.estimator.tpu.RunConfig( cluster=tpu_cluster_resolver, model_dir=self._model_dir, save_checkpoints_steps=self._save_checkpoints_steps, tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( iterations_per_loop=self.iterations_per_loop, experimental_host_call_every_n_steps= _experimental_host_call_every_n_steps, per_host_input_for_training=tf.compat.v1.estimator.tpu. InputPipelineConfig.PER_HOST_V2, experimental_allow_per_host_v2_parallel_get_next=True, experimental_feed_hook=experimental_feed_hook, )) # Disable meta_optimizer which is not needed and takes long time to run. config.session_config.graph_options.rewrite_options.disable_meta_optimizer = True if feature_config and table_config: embedding_config_spec = tf.compat.v1.estimator.tpu.experimental.EmbeddingConfigSpec( feature_to_config_dict=feature_config, table_to_config_dict=table_config, partition_strategy=self._partition_strategy, pipeline_execution_with_tensor_core=self._pipeline_execution, experimental_gradient_multiplier_fn=self. _experimental_gradient_multiplier_fn, optimization_parameters=tf.compat.v1.tpu.experimental. 
AdagradParameters(learning_rate=1.0)) else: embedding_config_spec = None params = self._create_params(total_replicas) if self._task_param.eval.per_replica_batch_size is not None: eval_batch_size = self._task_param.eval.per_replica_batch_size * total_replicas else: eval_batch_size = train_global_batch_size return tf.compat.v1.estimator.tpu.TPUEstimator( use_tpu=True, model_fn=model_fn, config=config, train_batch_size=train_global_batch_size, eval_batch_size=eval_batch_size, params=params, embedding_config_spec=embedding_config_spec), total_replicas def create_tpu_estimator_on_cpu(self, model_fn, feature_config, table_config): if self._host_call_every_n_steps == 0: _experimental_host_call_every_n_steps = 100 else: _experimental_host_call_every_n_steps = self._host_call_every_n_steps config = tf.compat.v1.estimator.tpu.RunConfig( cluster=None, model_dir=None, save_checkpoints_steps=self._save_checkpoints_steps, tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( iterations_per_loop=self.iterations_per_loop, experimental_host_call_every_n_steps= _experimental_host_call_every_n_steps, per_host_input_for_training=tf.compat.v1.estimator.tpu. InputPipelineConfig.PER_HOST_V2, experimental_allow_per_host_v2_parallel_get_next=True)) if feature_config and table_config: embedding_config_spec = tf.compat.v1.estimator.tpu.experimental.EmbeddingConfigSpec( feature_to_config_dict=feature_config, table_to_config_dict=table_config, pipeline_execution_with_tensor_core=self._pipeline_execution, experimental_gradient_multiplier_fn=self. _experimental_gradient_multiplier_fn, optimization_parameters=tf.compat.v1.tpu.experimental. 
AdagradParameters(learning_rate=1.0)) else: embedding_config_spec = None total_replicas = 1 params = self._create_params(total_replicas) return tf.compat.v1.estimator.tpu.TPUEstimator( use_tpu=False, model_fn=model_fn, config=config, train_batch_size=128, params=params, embedding_config_spec=embedding_config_spec), total_replicas def run(self): try: current_step = tf.train.load_variable(self._model_dir, tf.compat.v1.GraphKeys.GLOBAL_STEP) except (TypeError, ValueError, tf.errors.NotFoundError): current_step = 0 logging.info("Current step :{}".format(current_step)) task = self._task_param.instantiate() self._task = task feature_config, table_config = None, None if isinstance(task, BaseEmbeddingTask): task.init_slot_to_env() feature_config, table_config = task.create_feature_and_table_config_dict() input_fn_train = task.create_input_fn(tf.estimator.ModeKeys.TRAIN) model_fn = task.create_model_fn() assert self._cpu_test is False or self._mode == 'train', \ "Cpu test can only work with train mode." if self._cpu_test: # If running CPU test, wrap model a little bit to pre-process features. def model_fn_test_wrapper(features, mode, params): features = task.process_features_for_cpu_test(features) return model_fn(features, mode, params) est, total_replicas = self.create_tpu_estimator_on_cpu( model_fn_test_wrapper, feature_config, table_config) else: est, total_replicas = self.create_tpu_estimator(model_fn, feature_config, table_config) start_timestamp = time.time() # This time will include compilation time if self._mode == 'train': est.train(input_fn=input_fn_train, max_steps=self._task_param.train.max_steps) elif self._mode == 'eval': input_fn_eval = task.create_input_fn(tf.estimator.ModeKeys.EVAL) total_examples = self._task_param.input.eval_examples eval_batch_size = self._task_param.eval.per_replica_batch_size * total_replicas eval_steps = total_examples // eval_batch_size logging.info( "Evaluation: total_examples:{} eval_batch_size:{} num_eval_steps: {}". 
format(total_examples, eval_batch_size, eval_steps)) output_dir = os.path.join(self._model_dir, 'eval') tf.io.gfile.makedirs(output_dir) # Run evaluation when there's a new checkpoint for ckpt in tf.train.checkpoints_iterator(self._model_dir, timeout=60 * 60 * 5): # Terminate eval job when final checkpoint is reached current_step = int(os.path.basename(ckpt).split('-')[1]) try: current_step = int(os.path.basename(ckpt).split('-')[1]) logging.info("Starting to evaluate step: {}".format(current_step)) except: logging.warning("Could not find current step value") try: start_timestamp = time.time( ) # This time will include compilation time eval_results = est.evaluate(input_fn=input_fn_eval, steps=eval_steps, checkpoint_path=ckpt) elapsed_time = int(time.time() - start_timestamp) logging.info("Eval results: {}. Elapsed seconds: {}".format( eval_results, elapsed_time)) # Summary writer writes out eval metrics. summary_writer = tf.compat.v1.summary.FileWriter(output_dir) self.write_summary(eval_results, summary_writer, current_step) summary_writer.close() if current_step >= self._task_param.train.max_steps: logging.info("Evaluation finished after training step {}".format( current_step)) break except tf.errors.NotFoundError: # Since the coordinator is on a different job than the TPU worker, # sometimes the TPU worker does not finish initializing until long after # the CPU job tells it to start evaluating. In this case, the checkpoint # file could have been deleted already. 
logging.info( "Checkpoint {} no longer exists, skipping checkpoint".format( ckpt)) else: # train_and_eval raise TypeError("{} has not been supported.".format(self._mode)) def main(unused_argv): task_name = FLAGS.task task_param = model_registry.GetParams(task_name) logging.info("FLAGS:") for key, value in FLAGS.__flags.items(): logging.info("{}: {}".format(key, value.value)) logging.info("task_param: {}".format(str(task_param))) runner = TPURunner(task_param) runner.run() if __name__ == '__main__': logging.set_verbosity(logging.INFO) tf.disable_v2_behavior() app.run(main) ================================================ FILE: monolith/utils.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import sys from concurrent.futures import ThreadPoolExecutor import tensorflow as tf from monolith import path_utils find_main = path_utils.find_main get_libops_path = path_utils.get_libops_path def enable_monkey_patch(): name = "tensorflow.python.training.monitored_session" orig_mod = sys.modules.get(name) if orig_mod is None: orig_mod = __import__(name) setattr(orig_mod, "_PREEMPTION_ERRORS", (tf.errors.AbortedError,)) def CopyFile(src, dst, overwrite=True, skip_nonexist=True, max_retries=5): for _ in range(max_retries): try: tf.io.gfile.copy(src, dst, overwrite=overwrite) except tf.errors.NotFoundError as e: if skip_nonexist: continue else: raise e break def CopyRecursively(src: str, dst: str, max_workers: int = 1, skip_nonexist: bool = True, max_retries: int = 5): src_dst = [] def _CopyRecursivelyImpl(src, dst): if not tf.io.gfile.exists(src): if skip_nonexist: return raise ValueError("{} doesn't exist!".format(src)) if not tf.io.gfile.isdir(src): if max_workers > 1: src_dst.append((src, dst)) else: CopyFile(src, dst, overwrite=True, skip_nonexist=skip_nonexist) return if tf.io.gfile.exists(dst): tf.io.gfile.rmtree(dst) tf.io.gfile.makedirs(dst) for relpath in tf.io.gfile.listdir(src): src_path = os.path.join(src, relpath) dst_path = os.path.join(dst, relpath) _CopyRecursivelyImpl(src_path, dst_path) _CopyRecursivelyImpl(src, dst) if max_workers > 1: with ThreadPoolExecutor(max_workers=max_workers) as executor: executor.map( lambda args: CopyFile(args[0], args[1], overwrite=True, skip_nonexist=skip_nonexist, max_retries=max_retries), src_dst) ================================================ FILE: monolith/utils_test.py ================================================ # Copyright 2022 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from monolith import utils from typing import Union import os import unittest import uuid from tensorflow.python.framework import errors import tensorflow.python.training.monitored_session as monitored_session import tensorflow as tf utils.enable_monkey_patch() class UtilsTest(unittest.TestCase): def testFindMain(self): basedir = utils.find_main() self.assertEqual(basedir.split("/")[-1], "__main__") def testGetLibopsPath(self): self.assertTrue( os.path.exists(utils.get_libops_path("monolith/utils_test.py"))) def testLoadMonitoredSession(self): self.assertEqual(monitored_session._PREEMPTION_ERRORS, (errors.AbortedError,)) def testMultiThreadedCopy(self): test_id = uuid.uuid4().hex def _gen_dir(): root = os.path.join('/tmp', test_id, 'src') tf.io.gfile.makedirs(root) subdir = os.path.join(root, 'subdir') tf.io.gfile.mkdir(subdir) with tf.io.gfile.GFile(os.path.join(root, 'file.txt'), 'w+') as f: f.write('root') with tf.io.gfile.GFile(os.path.join(subdir, 'innerfile.txt'), 'w+') as f: f.write('inner') return root src = _gen_dir() dst = os.path.join('/tmp', test_id, 'dst') utils.CopyRecursively(src, dst, max_workers=2) with tf.io.gfile.GFile(os.path.join(dst, 'subdir', 'innerfile.txt'), 'r') as f: self.assertEqual(f.read(), 'inner') if __name__ == "__main__": unittest.main() ================================================ FILE: third_party/BUILD ================================================ exports_files(glob(["**"])) ================================================ FILE: third_party/arrow.BUILD ================================================ # Description: # 
Apache Arrow library package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE.txt"]) genrule( name = "arrow_util_config", srcs = ["cpp/src/arrow/util/config.h.cmake"], outs = ["cpp/src/arrow/util/config.h"], cmd = ("sed " + "-e 's/@ARROW_VERSION_MAJOR@/3/g' " + "-e 's/@ARROW_VERSION_MINOR@/0/g' " + "-e 's/@ARROW_VERSION_PATCH@/0/g' " + "-e 's/cmakedefine ARROW_USE_NATIVE_INT128/undef ARROW_USE_NATIVE_INT128/g' " + "-e 's/cmakedefine/define/g' " + "$< >$@"), ) genrule( name = "parquet_version_h", srcs = ["cpp/src/parquet/parquet_version.h.in"], outs = ["cpp/src/parquet/parquet_version.h"], cmd = ("sed " + "-e 's/@PARQUET_VERSION_MAJOR@/1/g' " + "-e 's/@PARQUET_VERSION_MINOR@/5/g' " + "-e 's/@PARQUET_VERSION_PATCH@/1/g' " + "$< >$@"), ) cc_library( name = "arrow", srcs = glob( [ "cpp/src/arrow/*.cc", "cpp/src/arrow/array/*.cc", "cpp/src/arrow/csv/*.cc", "cpp/src/arrow/io/*.cc", "cpp/src/arrow/ipc/*.cc", "cpp/src/arrow/json/*.cc", "cpp/src/arrow/tensor/*.cc", "cpp/src/arrow/util/*.cc", "cpp/src/arrow/vendored/optional.hpp", "cpp/src/arrow/vendored/string_view.hpp", "cpp/src/arrow/vendored/variant.hpp", "cpp/src/arrow/**/*.h", "cpp/src/parquet/**/*.h", "cpp/src/parquet/**/*.cc", "cpp/src/generated/*.h", "cpp/src/generated/*.cpp", "cpp/thirdparty/flatbuffers/include/flatbuffers/*.h", "cpp/src/arrow/vendored/uriparser/*.c", "cpp/src/arrow/vendored/base64.cpp", "cpp/src/arrow/compute/**/*.cc", "cpp/src/arrow/vendored/datetime/*.h", "cpp/src/arrow/vendored/datetime/*.cpp", "cpp/src/arrow/vendored/pcg/*.hpp", ], exclude = [ "cpp/src/**/*_benchmark.cc", "cpp/src/**/*_main.cc", "cpp/src/**/*_nossl.cc", "cpp/src/**/*_test.cc", "cpp/src/**/test_*.cc", "cpp/src/**/*hdfs*.cc", "cpp/src/**/*fuzz*.cc", "cpp/src/**/file_to_stream.cc", "cpp/src/**/stream_to_file.cc", "cpp/src/arrow/util/bpacking_avx2.cc", "cpp/src/arrow/util/bpacking_avx512.cc", "cpp/src/arrow/util/bpacking_neon.cc", "cpp/src/arrow/util/tracing_internal.cc", 
"cpp/src/arrow/compute/exec/*_avx*.cc", ], ) + select({ "@bazel_tools//src/conditions:windows": [ "cpp/src/arrow/vendored/musl/strptime.c", ], "//conditions:default": [], }), hdrs = [ # declare header from above genrule "cpp/src/arrow/util/config.h", "cpp/src/parquet/parquet_version.h", ], copts = select({ "@bazel_tools//src/conditions:windows": [ "/std:c++14", ], "//conditions:default": [ "-std=c++14", ], }), defines = [ "ARROW_WITH_BROTLI", "ARROW_WITH_SNAPPY", "ARROW_WITH_LZ4", "ARROW_WITH_ZLIB", "ARROW_WITH_ZSTD", "ARROW_WITH_BZ2", "ARROW_STATIC", "ARROW_EXPORT=", "PARQUET_STATIC", "PARQUET_EXPORT=", "WIN32_LEAN_AND_MEAN", ], includes = [ "cpp/src", "cpp/src/arrow/vendored/xxhash", "cpp/thirdparty/flatbuffers/include", ], textual_hdrs = [ "cpp/src/arrow/vendored/xxhash/xxhash.c", ], deps = [ "@boringssl//:crypto", "@brotli", "@bzip2", "@double_conversion//:double-conversion", "@lz4", "@rapidjson", "@snappy//:snappy", "@thrift", "@xsimd", "@zlib", "@zstd", ], ) ================================================ FILE: third_party/brotli.BUILD ================================================ # Description: # Brotli library licenses(["notice"]) # MIT license exports_files(["LICENSE"]) cc_library( name = "brotli", srcs = glob([ "c/common/*.c", "c/common/*.h", "c/dec/*.c", "c/dec/*.h", "c/enc/*.c", "c/enc/*.h", "c/include/brotli/*.h", ]), hdrs = [], defines = [], includes = [ "c/dec", "c/include", ], linkopts = [], visibility = ["//visibility:public"], ) ================================================ FILE: third_party/bzip2.BUILD ================================================ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # BSD-like license cc_library( name = "bzip2", srcs = [ "blocksort.c", "bzlib.c", "bzlib_private.h", "compress.c", "crctable.c", "decompress.c", "huffman.c", "randtable.c", ], hdrs = [ "bzlib.h", ], copts = [ ], includes = ["."], ) ================================================ FILE: third_party/cli11/BUILD 
================================================ load("@rules_cc//cc:defs.bzl", "cc_library") package(default_visibility = ["//visibility:public"]) cc_library( name = "cli11", hdrs = ["CLI11.hpp"], ) ================================================ FILE: third_party/cli11/CLI11.hpp ================================================ // Copyright 2022 ByteDance and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // CLI11: Version 2.3.2 // Originally designed by Henry Schreiner // https://github.com/CLIUtils/CLI11 // // This is a standalone header file generated by MakeSingleHeader.py in // CLI11/scripts from: v2.3.2 // // CLI11 2.3.2 Copyright (c) 2017-2022 University of Cincinnati, developed by // Henry Schreiner under NSF AWARD 1414736. All rights reserved. // // Redistribution and use in source and binary forms of CLI11, with or without // modification, are permitted provided that the following conditions are met: // // 1. Redistributions of source code must retain the above copyright notice, // this // list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation // and/or other materials provided with the distribution. // 3. 
Neither the name of the copyright holder nor the names of its contributors // may be used to endorse or promote products derived from this software // without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE // POSSIBILITY OF SUCH DAMAGE. #pragma once // Standard combined includes: #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define CLI11_VERSION_MAJOR 2 #define CLI11_VERSION_MINOR 3 #define CLI11_VERSION_PATCH 2 #define CLI11_VERSION "2.3.2" // The following version macro is very similar to the one in pybind11 #if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER) #if __cplusplus >= 201402L #define CLI11_CPP14 #if __cplusplus >= 201703L #define CLI11_CPP17 #if __cplusplus > 201703L #define CLI11_CPP20 #endif #endif #endif #elif defined(_MSC_VER) && __cplusplus == 199711L // MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard // is fully implemented) Unless you use the /Zc:__cplusplus flag on Visual // Studio 2017 15.7 Preview 3 or newer #if _MSVC_LANG >= 201402L #define CLI11_CPP14 #if _MSVC_LANG > 201402L && _MSC_VER >= 1910 
#define CLI11_CPP17 #if _MSVC_LANG > 201703L && _MSC_VER >= 1910 #define CLI11_CPP20 #endif #endif #endif #endif #if defined(CLI11_CPP14) #define CLI11_DEPRECATED(reason) [[deprecated(reason)]] #elif defined(_MSC_VER) #define CLI11_DEPRECATED(reason) __declspec(deprecated(reason)) #else #define CLI11_DEPRECATED(reason) __attribute__((deprecated(reason))) #endif // GCC < 10 doesn't ignore this in unevaluated contexts #if !defined(CLI11_CPP17) || \ (defined(__GNUC__) && !defined(__llvm__) && !defined(__INTEL_COMPILER) && \ __GNUC__ < 10 && __GNUC__ > 4) #define CLI11_NODISCARD #else #define CLI11_NODISCARD [[nodiscard]] #endif /** detection of rtti */ #ifndef CLI11_USE_STATIC_RTTI #if (defined(_HAS_STATIC_RTTI) && _HAS_STATIC_RTTI) #define CLI11_USE_STATIC_RTTI 1 #elif defined(__cpp_rtti) #if (defined(_CPPRTTI) && _CPPRTTI == 0) #define CLI11_USE_STATIC_RTTI 1 #else #define CLI11_USE_STATIC_RTTI 0 #endif #elif (defined(__GCC_RTTI) && __GXX_RTTI) #define CLI11_USE_STATIC_RTTI 0 #else #define CLI11_USE_STATIC_RTTI 1 #endif #endif /** Inline macro **/ #ifdef CLI11_COMPILE #define CLI11_INLINE #else #define CLI11_INLINE inline #endif // C standard library // Only needed for existence checking #if defined CLI11_CPP17 && defined __has_include && \ !defined CLI11_HAS_FILESYSTEM #if __has_include() // Filesystem cannot be used if targeting macOS < 10.15 #if defined __MAC_OS_X_VERSION_MIN_REQUIRED && \ __MAC_OS_X_VERSION_MIN_REQUIRED < 101500 #define CLI11_HAS_FILESYSTEM 0 #elif defined(__wasi__) // As of wasi-sdk-14, filesystem is not implemented #define CLI11_HAS_FILESYSTEM 0 #else #include #if defined __cpp_lib_filesystem && __cpp_lib_filesystem >= 201703 #if defined _GLIBCXX_RELEASE && _GLIBCXX_RELEASE >= 9 #define CLI11_HAS_FILESYSTEM 1 #elif defined(__GLIBCXX__) // if we are using gcc and Version <9 default to no filesystem #define CLI11_HAS_FILESYSTEM 0 #else #define CLI11_HAS_FILESYSTEM 1 #endif #else #define CLI11_HAS_FILESYSTEM 0 #endif #endif #endif #endif #if 
defined CLI11_HAS_FILESYSTEM && CLI11_HAS_FILESYSTEM > 0 #include // NOLINT(build/include) #else #include #include #endif namespace CLI { /// Include the items in this namespace to get free conversion of enums to/from /// streams. (This is available inside CLI as well, so CLI11 will use this /// without a using statement). namespace enums { /// output streaming for enumerations template ::value>::type> std::ostream &operator<<(std::ostream &in, const T &item) { // make sure this is out of the detail namespace otherwise it won't be found // when needed return in << static_cast::type>(item); } } // namespace enums /// Export to CLI namespace using enums::operator<<; namespace detail { /// a constant defining an expected max vector size defined to be a big number /// that could be multiplied by 4 and not produce overflow for some expected /// uses constexpr int expected_max_vector_size{1 << 29}; // Based on http://stackoverflow.com/questions/236129/split-a-string-in-c /// Split a string by a delim CLI11_INLINE std::vector split(const std::string &s, char delim); /// Simple function to join a string template std::string join(const T &v, std::string delim = ",") { std::ostringstream s; auto beg = std::begin(v); auto end = std::end(v); if (beg != end) s << *beg++; while (beg != end) { s << delim << *beg++; } return s.str(); } /// Simple function to join a string from processed elements template ::value>::type> std::string join(const T &v, Callable func, std::string delim = ",") { std::ostringstream s; auto beg = std::begin(v); auto end = std::end(v); auto loc = s.tellp(); while (beg != end) { auto nloc = s.tellp(); if (nloc > loc) { s << delim; loc = nloc; } s << func(*beg++); } return s.str(); } /// Join a string in reverse order template std::string rjoin(const T &v, std::string delim = ",") { std::ostringstream s; for (std::size_t start = 0; start < v.size(); start++) { if (start > 0) s << delim; s << v[v.size() - start - 1]; } return s.str(); } // Based roughly on // 
http://stackoverflow.com/questions/25829143/c-trim-whitespace-from-a-string /// Trim whitespace from left of string CLI11_INLINE std::string <rim(std::string &str); /// Trim anything from left of string CLI11_INLINE std::string <rim(std::string &str, const std::string &filter); /// Trim whitespace from right of string CLI11_INLINE std::string &rtrim(std::string &str); /// Trim anything from right of string CLI11_INLINE std::string &rtrim(std::string &str, const std::string &filter); /// Trim whitespace from string inline std::string &trim(std::string &str) { return ltrim(rtrim(str)); } /// Trim anything from string inline std::string &trim(std::string &str, const std::string filter) { return ltrim(rtrim(str, filter), filter); } /// Make a copy of the string and then trim it inline std::string trim_copy(const std::string &str) { std::string s = str; return trim(s); } /// remove quotes at the front and back of a string either '"' or '\'' CLI11_INLINE std::string &remove_quotes(std::string &str); /// Add a leader to the beginning of all new lines (nothing is added /// at the start of the first line). `"; "` would be for ini files /// /// Can't use Regex, or this would be a subs. CLI11_INLINE std::string fix_newlines(const std::string &leader, std::string input); /// Make a copy of the string and then trim it, any filter string can be used /// (any char in string is filtered) inline std::string trim_copy(const std::string &str, const std::string &filter) { std::string s = str; return trim(s, filter); } /// Print a two part "help" string CLI11_INLINE std::ostream &format_help(std::ostream &out, std::string name, const std::string &description, std::size_t wid); /// Print subcommand aliases CLI11_INLINE std::ostream &format_aliases( std::ostream &out, const std::vector &aliases, std::size_t wid); /// Verify the first character of an option /// - is a trigger character, ! 
/// has special meaning and new lines would just be annoying to deal with
template <typename T>
bool valid_first_char(T c) {
  return ((c != '-') && (c != '!') && (c != ' ') && c != '\n');
}

/// Verify following characters of an option
template <typename T>
bool valid_later_char(T c) {
  // = and : are value separators, { has special meaning for option defaults,
  // and \n would just be annoying to deal with in many places allowing space
  // here has too much potential for inadvertent entry errors and bugs
  return ((c != '=') && (c != ':') && (c != '{') && (c != ' ') && c != '\n');
}

/// Verify an option/subcommand name
CLI11_INLINE bool valid_name_string(const std::string &str);

/// Verify an app name: it must not contain a newline or a NUL character
inline bool valid_alias_name_string(const std::string &str) {
  static const std::string badChars(std::string("\n") + '\0');
  return (str.find_first_of(badChars) == std::string::npos);
}

/// check if a string is a container segment separator (empty or "%%")
inline bool is_separator(const std::string &str) {
  static const std::string sep("%%");
  return (str.empty() || str == sep);
}

/// Verify that str consists of letters only
inline bool isalpha(const std::string &str) {
  // One locale object for the whole scan instead of one per character.
  const std::locale loc;
  return std::all_of(str.begin(), str.end(),
                     [&loc](char c) { return std::isalpha(c, loc); });
}

/// Return a lower case version of a string
inline std::string to_lower(std::string str) {
  const std::locale loc;
  std::transform(std::begin(str), std::end(str), std::begin(str),
                 [&loc](const std::string::value_type &x) {
                   return std::tolower(x, loc);
                 });
  return str;
}

/// remove underscores from a string
inline std::string remove_underscore(std::string str) {
  str.erase(std::remove(std::begin(str), std::end(str), '_'), std::end(str));
  return str;
}

/// Find and replace a substring with another substring
CLI11_INLINE std::string find_and_replace(std::string str, std::string from,
                                          std::string to);

/// check if the flag definitions has possible false flags
inline bool has_default_flag_values(const std::string &flags) {
  return (flags.find_first_of("{!") != std::string::npos);
}

CLI11_INLINE void remove_default_flag_values(std::string &flags);

/// Check if a string is a member of a list of strings and optionally ignore
/// case or ignore underscores
CLI11_INLINE std::ptrdiff_t find_member(std::string name,
                                        const std::vector<std::string> names,
                                        bool ignore_case = false,
                                        bool ignore_underscore = false);

/// Find a trigger string and call a modify callable function that takes the
/// current string and starting position of the trigger and returns the position
/// in the string to search for the next trigger string
template <typename Callable>
inline std::string find_and_modify(std::string str, std::string trigger,
                                   Callable modify) {
  std::size_t start_pos = 0;
  while ((start_pos = str.find(trigger, start_pos)) != std::string::npos) {
    start_pos = modify(str, start_pos);
  }
  return str;
}

/// Split a string '"one two" "three"' into 'one two', 'three'
/// Quote characters can be ` ' or "
CLI11_INLINE std::vector<std::string> split_up(std::string str,
                                               char delimiter = '\0');

/// This function detects an equal or colon followed by an escaped quote after
/// an argument then modifies the string to replace the equality with a space.
/// This is needed to allow the split up function to work properly and is
/// intended to be used with the find_and_modify function the return value is
/// the offset+1 which is required by the find_and_modify function.
CLI11_INLINE std::size_t escape_detect(std::string &str, std::size_t offset);

/// Add quotes if the string contains spaces
CLI11_INLINE std::string &add_quotes_if_needed(std::string &str);

}  // namespace detail

namespace detail {

/// Split `s` on `delim`. An empty input yields a single empty element so the
/// result is consistent.
CLI11_INLINE std::vector<std::string> split(const std::string &s, char delim) {
  std::vector<std::string> elems;
  // Check to see if empty string, give consistent result
  if (s.empty()) {
    elems.emplace_back();
  } else {
    std::stringstream ss;
    ss.str(s);
    std::string item;
    while (std::getline(ss, item, delim)) {
      elems.push_back(item);
    }
  }
  return elems;
}

/// Remove leading whitespace from `str` in place and return it.
CLI11_INLINE std::string &ltrim(std::string &str) {
  auto it = std::find_if(str.begin(), str.end(), [](char ch) {
    return !std::isspace<char>(ch, std::locale());
  });
  str.erase(str.begin(), it);
  return str;
}

/// Remove every leading character that appears in `filter`, in place.
CLI11_INLINE std::string &ltrim(std::string &str, const std::string &filter) {
  auto it = std::find_if(str.begin(), str.end(), [&filter](char ch) {
    return filter.find(ch) == std::string::npos;
  });
  str.erase(str.begin(), it);
  return str;
}

/// Remove trailing whitespace from `str` in place and return it.
CLI11_INLINE std::string &rtrim(std::string &str) {
  auto it = std::find_if(str.rbegin(), str.rend(), [](char ch) {
    return !std::isspace<char>(ch, std::locale());
  });
  str.erase(it.base(), str.end());
  return str;
}

/// Remove every trailing character that appears in `filter`, in place.
CLI11_INLINE std::string &rtrim(std::string &str, const std::string &filter) {
  auto it = std::find_if(str.rbegin(), str.rend(), [&filter](char ch) {
    return filter.find(ch) == std::string::npos;
  });
  str.erase(it.base(), str.end());
  return str;
}

/// Strip one pair of matching surrounding quotes (`"` or `'`) from `str`, if
/// present. Strings shorter than two characters are left untouched.
CLI11_INLINE std::string &remove_quotes(std::string &str) {
  if (str.length() > 1 && (str.front() == '"' || str.front() == '\'')) {
    if (str.front() == str.back()) {
      str.pop_back();
      str.erase(str.begin(), str.begin() + 1);
    }
  }
  return str;
}

/// Insert `leader` after every newline in `input` and return the result.
CLI11_INLINE std::string fix_newlines(const std::string &leader,
                                      std::string input) {
  std::string::size_type n = 0;
  while (n != std::string::npos && n < input.size()) {
    n = input.find('\n', n);
    if (n != std::string::npos) {
      input = input.substr(0, n + 1) + leader + input.substr(n + 1);
      // Resume after the newline AND the inserted leader. Stepping only by
      // leader.size() would re-find the same newline when the leader is empty
      // (or ends in '\n') and never terminate.
      n += leader.size() + 1;
    }
  }
  return input;
}

/// Write one help entry: the name left-aligned in a column of width `wid`,
/// followed by the description. Continuation lines of the description are
/// indented to the same column.
/// NOTE(review): the padding literals here and in format_aliases are kept
/// exactly as found; confirm their widths against the upstream header.
CLI11_INLINE std::ostream &format_help(std::ostream &out, std::string name,
                                       const std::string &description,
                                       std::size_t wid) {
  name = " " + name;
  out << std::setw(static_cast<int>(wid)) << std::left << name;
  if (!description.empty()) {
    // A name that overflows the column pushes the description to a new line.
    if (name.length() >= wid)
      out << "\n" << std::setw(static_cast<int>(wid)) << "";
    for (const char c : description) {
      out.put(c);
      if (c == '\n') {
        out << std::setw(static_cast<int>(wid)) << "";
      }
    }
  }
  out << "\n";
  return out;
}

/// Write a comma-separated "aliases:" line; writes nothing when `aliases` is
/// empty.
CLI11_INLINE std::ostream &format_aliases(
    std::ostream &out, const std::vector<std::string> &aliases,
    std::size_t wid) {
  if (!aliases.empty()) {
    out << std::setw(static_cast<int>(wid)) << " aliases: ";
    bool front = true;
    for (const auto &alias : aliases) {
      if (!front) {
        out << ", ";
      } else {
        front = false;
      }
      out << detail::fix_newlines(" ", alias);
    }
    out << "\n";
  }
  return out;
}

/// True when `str` is non-empty, starts with a valid first character and
/// contains only valid later characters after that.
CLI11_INLINE bool valid_name_string(const std::string &str) {
  if (str.empty() || !valid_first_char(str[0])) {
    return false;
  }
  auto e = str.end();
  for (auto c = str.begin() + 1; c != e; ++c)
    if (!valid_later_char(*c)) return false;
  return true;
}

/// Replace every occurrence of `from` in `str` with `to` and return the
/// result. An empty `from` matches nothing and returns `str` unchanged.
CLI11_INLINE std::string find_and_replace(std::string str, std::string from,
                                          std::string to) {
  // std::string::find("") matches at every position, so without this guard
  // the loop below would never terminate.
  if (from.empty()) {
    return str;
  }

  std::size_t start_pos = 0;

  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
    str.replace(start_pos, from.length(), to);
    start_pos += to.length();
  }

  return str;
}

/// Strip `{...}` default-value groups (those not containing a comma before
/// the closing brace) and all `!` characters from a flag definition string.
CLI11_INLINE void remove_default_flag_values(std::string &flags) {
  auto loc = flags.find_first_of('{', 2);
  while (loc != std::string::npos) {
    auto finish = flags.find_first_of("},", loc + 1);
    if ((finish != std::string::npos) && (flags[finish] == '}')) {
      flags.erase(flags.begin() + static_cast<std::ptrdiff_t>(loc),
                  flags.begin() + static_cast<std::ptrdiff_t>(finish) + 1);
    }
    loc = flags.find_first_of('{', loc + 1);
  }
  flags.erase(std::remove(flags.begin(), flags.end(), '!'), flags.end());
}

/// Return the index of `name` within `names`, or -1 if it is absent.
/// Comparison optionally ignores case and/or underscores.
CLI11_INLINE std::ptrdiff_t find_member(std::string name,
                                        const std::vector<std::string> names,
                                        bool ignore_case,
                                        bool ignore_underscore) {
  auto it = std::end(names);
  if (ignore_case) {
    if (ignore_underscore) {
      name = detail::to_lower(detail::remove_underscore(name));
      it = std::find_if(
          std::begin(names), std::end(names),
          [&name](std::string local_name) {
            return detail::to_lower(detail::remove_underscore(local_name)) ==
                   name;
          });
    } else {
      name = detail::to_lower(name);
      it = std::find_if(std::begin(names), std::end(names),
                        [&name](std::string local_name) {
                          return detail::to_lower(local_name) == name;
                        });
    }

  } else if (ignore_underscore) {
    name = detail::remove_underscore(name);
    it = std::find_if(std::begin(names), std::end(names),
                      [&name](std::string local_name) {
                        return detail::remove_underscore(local_name) == name;
                      });
  } else {
    it = std::find(std::begin(names), std::end(names), name);
  }

  return (it != std::end(names)) ? (it - std::begin(names)) : (-1);
}

/// Split `str` into tokens. Tokens are separated by `delimiter`, or by
/// whitespace when `delimiter` is '\0'. A token that starts with ', " or `
/// extends to the matching closing quote; backslash-escaped quotes inside it
/// are converted to the plain quote character.
CLI11_INLINE std::vector<std::string> split_up(std::string str,
                                               char delimiter) {
  const std::string delims("\'\"`");
  auto find_ws = [delimiter](char ch) {
    return (delimiter == '\0') ? std::isspace<char>(ch, std::locale())
                               : (ch == delimiter);
  };
  trim(str);

  std::vector<std::string> output;
  bool embeddedQuote = false;
  char keyChar = ' ';
  while (!str.empty()) {
    if (delims.find_first_of(str[0]) != std::string::npos) {
      keyChar = str[0];
      auto end = str.find_first_of(keyChar, 1);
      while ((end != std::string::npos) &&
             (str[end - 1] == '\\')) {  // deal with escaped quotes
        end = str.find_first_of(keyChar, end + 1);
        embeddedQuote = true;
      }
      if (end != std::string::npos) {
        output.push_back(str.substr(1, end - 1));
        // Skip the closing quote and the single separator that follows it.
        if (end + 2 < str.size()) {
          str = str.substr(end + 2);
        } else {
          str.clear();
        }

      } else {
        // Unterminated quote: take the remainder as one token.
        output.push_back(str.substr(1));
        str = "";
      }
    } else {
      auto it = std::find_if(std::begin(str), std::end(str), find_ws);
      if (it != std::end(str)) {
        std::string value = std::string(str.begin(), it);
        output.push_back(value);
        str = std::string(it + 1, str.end());
      } else {
        output.push_back(str);
        str = "";
      }
    }
    // transform any embedded quotes into the regular character
    if (embeddedQuote) {
      output.back() = find_and_replace(
          output.back(), std::string("\\") + keyChar, std::string(1, keyChar));
      embeddedQuote = false;
    }
    trim(str);
  }
  return output;
}

/// If the character after `offset` opens a quoted section and the argument
/// containing `offset` starts with the matching prefix ('-' for '=', '/'
/// otherwise), replace the character at `offset` with a space so that
/// split_up tokenizes it. Returns `offset + 1`.
CLI11_INLINE std::size_t escape_detect(std::string &str, std::size_t offset) {
  auto next = str[offset + 1];
  if ((next == '\"') || (next == '\'') || (next == '`')) {
    auto astart = str.find_last_of("-/ \"\'`", offset - 1);
    if (astart != std::string::npos) {
      if (str[astart] == ((str[offset] == '=') ? '-' : '/'))
        str[offset] =
            ' ';  // interpret this as a space so the split_up works properly
    }
  }
  return offset + 1;
}

/// Wrap `str` in quotes when it contains a space and is not already quoted.
/// The quote character is chosen so as not to clash with the first quote
/// found inside the string. Empty strings are returned unchanged.
CLI11_INLINE std::string &add_quotes_if_needed(std::string &str) {
  // front()/back() on an empty string is undefined behaviour, and an empty
  // string has no space that would require quoting anyway.
  if (str.empty()) {
    return str;
  }
  if ((str.front() != '"' && str.front() != '\'') ||
      str.front() != str.back()) {
    char quote = str.find('"') < str.find('\'') ? '\'' : '"';
    if (str.find(' ') != std::string::npos) {
      str.insert(0, 1, quote);
      str.append(1, quote);
    }
  }
  return str;
}

}  // namespace detail

// Use one of these on all error classes.
// These are temporary and are undef'd at the end of this file.
#define CLI11_ERROR_DEF(parent, name) \
 protected: \
  name(std::string ename, std::string msg, int exit_code) \
      : parent(std::move(ename), std::move(msg), exit_code) {} \
  name(std::string ename, std::string msg, ExitCodes exit_code) \
      : parent(std::move(ename), std::move(msg), exit_code) {} \
 \
 public: \
  name(std::string msg, ExitCodes exit_code) \
      : parent(#name, std::move(msg), exit_code) {} \
  name(std::string msg, int exit_code) \
      : parent(#name, std::move(msg), exit_code) {}

// This is added after the one above if a class is used directly and builds its
// own message
#define CLI11_ERROR_SIMPLE(name) \
  explicit name(std::string msg) : name(#name, msg, ExitCodes::name) {}

/// These codes are part of every error in CLI. They can be obtained from e
/// using e.exit_code or as a quick shortcut, int values from
/// e.get_error_code().
enum class ExitCodes { Success = 0, IncorrectConstruction = 100, BadNameString, OptionAlreadyAdded, FileError, ConversionError, ValidationError, RequiredError, RequiresError, ExcludesError, ExtrasError, ConfigError, InvalidError, HorribleError, OptionNotFound, ArgumentMismatch, BaseClass = 127 }; // Error definitions /// @defgroup error_group Errors /// @brief Errors thrown by CLI11 /// /// These are the errors that can be thrown. Some of them, like CLI::Success, /// are not really errors. /// @{ /// All errors derive from this one class Error : public std::runtime_error { int actual_exit_code; std::string error_name{"Error"}; public: CLI11_NODISCARD int get_exit_code() const { return actual_exit_code; } CLI11_NODISCARD std::string get_name() const { return error_name; } Error(std::string name, std::string msg, int exit_code = static_cast(ExitCodes::BaseClass)) : runtime_error(msg), actual_exit_code(exit_code), error_name(std::move(name)) {} Error(std::string name, std::string msg, ExitCodes exit_code) : Error(name, msg, static_cast(exit_code)) {} }; // Note: Using Error::Error constructors does not work on GCC 4.7 /// Construction errors (not in parsing) class ConstructionError : public Error { CLI11_ERROR_DEF(Error, ConstructionError) }; /// Thrown when an option is set to conflicting values (non-vector and multi /// args, for example) class IncorrectConstruction : public ConstructionError { CLI11_ERROR_DEF(ConstructionError, IncorrectConstruction) CLI11_ERROR_SIMPLE(IncorrectConstruction) static IncorrectConstruction PositionalFlag(std::string name) { return IncorrectConstruction(name + ": Flags cannot be positional"); } static IncorrectConstruction Set0Opt(std::string name) { return IncorrectConstruction(name + ": Cannot set 0 expected, use a flag instead"); } static IncorrectConstruction SetFlag(std::string name) { return IncorrectConstruction(name + ": Cannot set an expected number for flags"); } static IncorrectConstruction ChangeNotVector(std::string name) 
{ return IncorrectConstruction( name + ": You can only change the expected arguments for vectors"); } static IncorrectConstruction AfterMultiOpt(std::string name) { return IncorrectConstruction(name + ": You can't change expected arguments after " "you've changed the multi option policy!"); } static IncorrectConstruction MissingOption(std::string name) { return IncorrectConstruction("Option " + name + " is not defined"); } static IncorrectConstruction MultiOptionPolicy(std::string name) { return IncorrectConstruction( name + ": multi_option_policy only works for flags and exact value options"); } }; /// Thrown on construction of a bad name class BadNameString : public ConstructionError { CLI11_ERROR_DEF(ConstructionError, BadNameString) CLI11_ERROR_SIMPLE(BadNameString) static BadNameString OneCharName(std::string name) { return BadNameString("Invalid one char name: " + name); } static BadNameString BadLongName(std::string name) { return BadNameString("Bad long name: " + name); } static BadNameString DashesOnly(std::string name) { return BadNameString("Must have a name, not just dashes: " + name); } static BadNameString MultiPositionalNames(std::string name) { return BadNameString("Only one positional name allowed, remove: " + name); } }; /// Thrown when an option already exists class OptionAlreadyAdded : public ConstructionError { CLI11_ERROR_DEF(ConstructionError, OptionAlreadyAdded) explicit OptionAlreadyAdded(std::string name) : OptionAlreadyAdded(name + " is already added", ExitCodes::OptionAlreadyAdded) {} static OptionAlreadyAdded Requires(std::string name, std::string other) { return {name + " requires " + other, ExitCodes::OptionAlreadyAdded}; } static OptionAlreadyAdded Excludes(std::string name, std::string other) { return {name + " excludes " + other, ExitCodes::OptionAlreadyAdded}; } }; // Parsing errors /// Anything that can error in Parse class ParseError : public Error { CLI11_ERROR_DEF(Error, ParseError) }; // Not really "errors" /// This is a 
successful completion on parsing, supposed to exit class Success : public ParseError { CLI11_ERROR_DEF(ParseError, Success) Success() : Success("Successfully completed, should be caught and quit", ExitCodes::Success) {} }; /// -h or --help on command line class CallForHelp : public Success { CLI11_ERROR_DEF(Success, CallForHelp) CallForHelp() : CallForHelp("This should be caught in your main function, see examples", ExitCodes::Success) {} }; /// Usually something like --help-all on command line class CallForAllHelp : public Success { CLI11_ERROR_DEF(Success, CallForAllHelp) CallForAllHelp() : CallForAllHelp( "This should be caught in your main function, see examples", ExitCodes::Success) {} }; /// -v or --version on command line class CallForVersion : public Success { CLI11_ERROR_DEF(Success, CallForVersion) CallForVersion() : CallForVersion( "This should be caught in your main function, see examples", ExitCodes::Success) {} }; /// Does not output a diagnostic in CLI11_PARSE, but allows main() to return /// with a specific error code. 
class RuntimeError : public ParseError { CLI11_ERROR_DEF(ParseError, RuntimeError) explicit RuntimeError(int exit_code = 1) : RuntimeError("Runtime error", exit_code) {} }; /// Thrown when parsing an INI file and it is missing class FileError : public ParseError { CLI11_ERROR_DEF(ParseError, FileError) CLI11_ERROR_SIMPLE(FileError) static FileError Missing(std::string name) { return FileError(name + " was not readable (missing?)"); } }; /// Thrown when conversion call back fails, such as when an int fails to coerce /// to a string class ConversionError : public ParseError { CLI11_ERROR_DEF(ParseError, ConversionError) CLI11_ERROR_SIMPLE(ConversionError) ConversionError(std::string member, std::string name) : ConversionError("The value " + member + " is not an allowed value for " + name) {} ConversionError(std::string name, std::vector results) : ConversionError("Could not convert: " + name + " = " + detail::join(results)) {} static ConversionError TooManyInputsFlag(std::string name) { return ConversionError(name + ": too many inputs for a flag"); } static ConversionError TrueFalse(std::string name) { return ConversionError(name + ": Should be true/false or a number"); } }; /// Thrown when validation of results fails class ValidationError : public ParseError { CLI11_ERROR_DEF(ParseError, ValidationError) CLI11_ERROR_SIMPLE(ValidationError) explicit ValidationError(std::string name, std::string msg) : ValidationError(name + ": " + msg) {} }; /// Thrown when a required option is missing class RequiredError : public ParseError { CLI11_ERROR_DEF(ParseError, RequiredError) explicit RequiredError(std::string name) : RequiredError(name + " is required", ExitCodes::RequiredError) {} static RequiredError Subcommand(std::size_t min_subcom) { if (min_subcom == 1) { return RequiredError("A subcommand"); } return {"Requires at least " + std::to_string(min_subcom) + " subcommands", ExitCodes::RequiredError}; } static RequiredError Option(std::size_t min_option, std::size_t 
max_option, std::size_t used, const std::string &option_list) { if ((min_option == 1) && (max_option == 1) && (used == 0)) return RequiredError("Exactly 1 option from [" + option_list + "]"); if ((min_option == 1) && (max_option == 1) && (used > 1)) { return {"Exactly 1 option from [" + option_list + "] is required and " + std::to_string(used) + " were given", ExitCodes::RequiredError}; } if ((min_option == 1) && (used == 0)) return RequiredError("At least 1 option from [" + option_list + "]"); if (used < min_option) { return {"Requires at least " + std::to_string(min_option) + " options used and only " + std::to_string(used) + "were given from [" + option_list + "]", ExitCodes::RequiredError}; } if (max_option == 1) return {"Requires at most 1 options be given from [" + option_list + "]", ExitCodes::RequiredError}; return {"Requires at most " + std::to_string(max_option) + " options be used and " + std::to_string(used) + "were given from [" + option_list + "]", ExitCodes::RequiredError}; } }; /// Thrown when the wrong number of arguments has been received class ArgumentMismatch : public ParseError { CLI11_ERROR_DEF(ParseError, ArgumentMismatch) CLI11_ERROR_SIMPLE(ArgumentMismatch) ArgumentMismatch(std::string name, int expected, std::size_t received) : ArgumentMismatch(expected > 0 ? 
("Expected exactly " + std::to_string(expected) + " arguments to " + name + ", got " + std::to_string(received)) : ("Expected at least " + std::to_string(-expected) + " arguments to " + name + ", got " + std::to_string(received)), ExitCodes::ArgumentMismatch) {} static ArgumentMismatch AtLeast(std::string name, int num, std::size_t received) { return ArgumentMismatch(name + ": At least " + std::to_string(num) + " required but received " + std::to_string(received)); } static ArgumentMismatch AtMost(std::string name, int num, std::size_t received) { return ArgumentMismatch(name + ": At Most " + std::to_string(num) + " required but received " + std::to_string(received)); } static ArgumentMismatch TypedAtLeast(std::string name, int num, std::string type) { return ArgumentMismatch(name + ": " + std::to_string(num) + " required " + type + " missing"); } static ArgumentMismatch FlagOverride(std::string name) { return ArgumentMismatch(name + " was given a disallowed flag override"); } static ArgumentMismatch PartialType(std::string name, int num, std::string type) { return ArgumentMismatch(name + ": " + type + " only partially specified: " + std::to_string(num) + " required for each element"); } }; /// Thrown when a requires option is missing class RequiresError : public ParseError { CLI11_ERROR_DEF(ParseError, RequiresError) RequiresError(std::string curname, std::string subname) : RequiresError(curname + " requires " + subname, ExitCodes::RequiresError) {} }; /// Thrown when an excludes option is present class ExcludesError : public ParseError { CLI11_ERROR_DEF(ParseError, ExcludesError) ExcludesError(std::string curname, std::string subname) : ExcludesError(curname + " excludes " + subname, ExitCodes::ExcludesError) {} }; /// Thrown when too many positionals or options are found class ExtrasError : public ParseError { CLI11_ERROR_DEF(ParseError, ExtrasError) explicit ExtrasError(std::vector args) : ExtrasError( (args.size() > 1 ? 
"The following arguments were not expected: " : "The following argument was not expected: ") + detail::rjoin(args, " "), ExitCodes::ExtrasError) {} ExtrasError(const std::string &name, std::vector args) : ExtrasError( name, (args.size() > 1 ? "The following arguments were not expected: " : "The following argument was not expected: ") + detail::rjoin(args, " "), ExitCodes::ExtrasError) {} }; /// Thrown when extra values are found in an INI file class ConfigError : public ParseError { CLI11_ERROR_DEF(ParseError, ConfigError) CLI11_ERROR_SIMPLE(ConfigError) static ConfigError Extras(std::string item) { return ConfigError("INI was not able to parse " + item); } static ConfigError NotConfigurable(std::string item) { return ConfigError(item + ": This option is not allowed in a configuration file"); } }; /// Thrown when validation fails before parsing class InvalidError : public ParseError { CLI11_ERROR_DEF(ParseError, InvalidError) explicit InvalidError(std::string name) : InvalidError( name + ": Too many positional arguments with unlimited expected args", ExitCodes::InvalidError) {} }; /// This is just a safety check to verify selection and parsing match - you /// should not ever see it Strings are directly added to this error, but again, /// it should never be seen. 
class HorribleError : public ParseError { CLI11_ERROR_DEF(ParseError, HorribleError) CLI11_ERROR_SIMPLE(HorribleError) }; // After parsing /// Thrown when counting a non-existent option class OptionNotFound : public Error { CLI11_ERROR_DEF(Error, OptionNotFound) explicit OptionNotFound(std::string name) : OptionNotFound(name + " not found", ExitCodes::OptionNotFound) {} }; #undef CLI11_ERROR_DEF #undef CLI11_ERROR_SIMPLE /// @} // Type tools // Utilities for type enabling namespace detail { // Based generally on https://rmf.io/cxx11/almost-static-if /// Simple empty scoped class enum class enabler {}; /// An instance to use in EnableIf constexpr enabler dummy = {}; } // namespace detail /// A copy of enable_if_t from C++14, compatible with C++11. /// /// We could check to see if C++14 is being used, but it does not hurt to /// redefine this (even Google does this: /// https://github.com/google/skia/blob/main/include/private/SkTLogic.h) It is /// not in the std namespace anyway, so no harm done. 
template using enable_if_t = typename std::enable_if::type; /// A copy of std::void_t from C++17 (helper for C++11 and C++14) template struct make_void { using type = void; }; /// A copy of std::void_t from C++17 - same reasoning as enable_if_t, it does /// not hurt to redefine template using void_t = typename make_void::type; /// A copy of std::conditional_t from C++14 - same reasoning as enable_if_t, it /// does not hurt to redefine template using conditional_t = typename std::conditional::type; /// Check to see if something is bool (fail check by default) template struct is_bool : std::false_type {}; /// Check to see if something is bool (true if actually a bool) template <> struct is_bool : std::true_type {}; /// Check to see if something is a shared pointer template struct is_shared_ptr : std::false_type {}; /// Check to see if something is a shared pointer (True if really a shared /// pointer) template struct is_shared_ptr> : std::true_type {}; /// Check to see if something is a shared pointer (True if really a shared /// pointer) template struct is_shared_ptr> : std::true_type {}; /// Check to see if something is copyable pointer template struct is_copyable_ptr { static bool const value = is_shared_ptr::value || std::is_pointer::value; }; /// This can be specialized to override the type deduction for IsMember. template struct IsMemberType { using type = T; }; /// The main custom type needed here is const char * should be a string. template <> struct IsMemberType { using type = std::string; }; namespace detail { // These are utilities for IsMember and other transforming objects /// Handy helper to access the element_type generically. This is not part of /// is_copyable_ptr because it requires that pointer_traits be valid. 
/// not a pointer template struct element_type { using type = T; }; template struct element_type::value>::type> { using type = typename std::pointer_traits::element_type; }; /// Combination of the element type and value type - remove pointer (including /// smart pointers) and get the value_type of the container template struct element_value_type { using type = typename element_type::type::value_type; }; /// Adaptor for set-like structure: This just wraps a normal container in a few /// utilities that do almost nothing. template struct pair_adaptor : std::false_type { using value_type = typename T::value_type; using first_type = typename std::remove_const::type; using second_type = typename std::remove_const::type; /// Get the first value (really just the underlying value) template static auto first(Q &&pair_value) -> decltype(std::forward(pair_value)) { return std::forward(pair_value); } /// Get the second value (really just the underlying value) template static auto second(Q &&pair_value) -> decltype(std::forward(pair_value)) { return std::forward(pair_value); } }; /// Adaptor for map-like structure (true version, must have key_type and /// mapped_type). This wraps a mapped container in a few utilities access it in /// a general way. 
template struct pair_adaptor, void>> : std::true_type { using value_type = typename T::value_type; using first_type = typename std::remove_const::type; using second_type = typename std::remove_const::type; /// Get the first value (really just the underlying value) template static auto first(Q &&pair_value) -> decltype(std::get<0>(std::forward(pair_value))) { return std::get<0>(std::forward(pair_value)); } /// Get the second value (really just the underlying value) template static auto second(Q &&pair_value) -> decltype(std::get<1>(std::forward(pair_value))) { return std::get<1>(std::forward(pair_value)); } }; // Warning is suppressed due to "bug" in gcc<5.0 and gcc 7.0 with c++17 enabled // that generates a Wnarrowing warning in the unevaluated context even if the // function that was using this wasn't used. The standard says narrowing in // brace initialization shouldn't be allowed but for backwards compatibility gcc // allows it in some contexts. It is a little fuzzy what happens in template // constructs and I think that was something GCC took a little while to work // out. But regardless some versions of gcc generate a warning when they // shouldn't from the following code so that should be suppressed #ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wnarrowing" #endif // check for constructibility from a specific type and copy assignable used in // the parse detection template class is_direct_constructible { template static auto test(int, std::true_type) -> decltype( // NVCC warns about narrowing conversions here #ifdef __CUDACC__ #pragma diag_suppress 2361 #endif TT{std::declval()} #ifdef __CUDACC__ #pragma diag_default 2361 #endif , std::is_move_assignable()); template static auto test(int, std::false_type) -> std::false_type; template static auto test(...) 
-> std::false_type; public: static constexpr bool value = decltype(test( 0, typename std::is_constructible::type()))::value; }; #ifdef __GNUC__ #pragma GCC diagnostic pop #endif // Check for output streamability // Based on // https://stackoverflow.com/questions/22758291/how-can-i-detect-if-a-type-can-be-streamed-to-an-stdostream template class is_ostreamable { template static auto test(int) -> decltype(std::declval() << std::declval(), std::true_type()); template static auto test(...) -> std::false_type; public: static constexpr bool value = decltype(test(0))::value; }; /// Check for input streamability template class is_istreamable { template static auto test(int) -> decltype(std::declval() >> std::declval(), std::true_type()); template static auto test(...) -> std::false_type; public: static constexpr bool value = decltype(test(0))::value; }; /// Check for complex template class is_complex { template static auto test(int) -> decltype(std::declval().real(), std::declval().imag(), std::true_type()); template static auto test(...) -> std::false_type; public: static constexpr bool value = decltype(test(0))::value; }; /// Templated operation to get a value from a stream template ::value, detail::enabler> = detail::dummy> bool from_stream(const std::string &istring, T &obj) { std::istringstream is; is.str(istring); is >> obj; return !is.fail() && !is.rdbuf()->in_avail(); } template ::value, detail::enabler> = detail::dummy> bool from_stream(const std::string & /*istring*/, T & /*obj*/) { return false; } // check to see if an object is a mutable container (fail by default) template struct is_mutable_container : std::false_type {}; /// type trait to test if a type is a mutable container meaning it has a /// value_type, it has an iterator, a clear, and end methods and an insert /// function. 
And for our purposes we exclude std::string and types that can be /// constructed from a std::string template struct is_mutable_container< T, conditional_t< false, void_t().end()), decltype(std::declval().clear()), decltype(std::declval().insert( std::declval().end())>(), std::declval()))>, void>> : public conditional_t::value, std::false_type, std::true_type> {}; // check to see if an object is a mutable container (fail by default) template struct is_readable_container : std::false_type {}; /// type trait to test if a type is a container meaning it has a value_type, it /// has an iterator, a clear, and an end methods and an insert function. And /// for our purposes we exclude std::string and types that can be constructed /// from a std::string template struct is_readable_container< T, conditional_t().end()), decltype(std::declval().begin())>, void>> : public std::true_type {}; // check to see if an object is a wrapper (fail by default) template struct is_wrapper : std::false_type {}; // check if an object is a wrapper (it has a value_type defined) template struct is_wrapper, void>> : public std::true_type {}; // Check for tuple like types, as in classes with a tuple_size type trait template class is_tuple_like { template // static auto test(int) // -> decltype(std::conditional<(std::tuple_size::value > 0), // std::true_type, std::false_type>::type()); static auto test(int) -> decltype(std::tuple_size::type>::value, std::true_type{}); template static auto test(...) 
-> std::false_type; public: static constexpr bool value = decltype(test(0))::value; }; /// Convert an object to a string (directly forward if this can become a string) template ::value, detail::enabler> = detail::dummy> auto to_string(T &&value) -> decltype(std::forward(value)) { return std::forward(value); } /// Construct a string from the object template ::value && !std::is_convertible::value, detail::enabler> = detail::dummy> std::string to_string(const T &value) { return std::string(value); // NOLINT(google-readability-casting) } /// Convert an object to a string (streaming must be supported for that type) template ::value && !std::is_constructible::value && is_ostreamable::value, detail::enabler> = detail::dummy> std::string to_string(T &&value) { std::stringstream stream; stream << value; return stream.str(); } /// If conversion is not supported, return an empty string (streaming is not /// supported for that type) template ::value && !is_ostreamable::value && !is_readable_container< typename std::remove_const::type>::value, detail::enabler> = detail::dummy> std::string to_string(T &&) { return {}; } /// convert a readable container to a string template ::value && !is_ostreamable::value && is_readable_container::value, detail::enabler> = detail::dummy> std::string to_string(T &&variable) { auto cval = variable.begin(); auto end = variable.end(); if (cval == end) { return {"{}"}; } std::vector defaults; while (cval != end) { defaults.emplace_back(CLI::detail::to_string(*cval)); ++cval; } return {"[" + detail::join(defaults) + "]"}; } /// special template overload template < typename T1, typename T2, typename T, enable_if_t::value, detail::enabler> = detail::dummy> auto checked_to_string(T &&value) -> decltype(to_string(std::forward(value))) { return to_string(std::forward(value)); } /// special template overload template < typename T1, typename T2, typename T, enable_if_t::value, detail::enabler> = detail::dummy> std::string checked_to_string(T &&) { return 
std::string{}; } /// get a string as a convertible value for arithmetic types template ::value, detail::enabler> = detail::dummy> std::string value_string(const T &value) { return std::to_string(value); } /// get a string as a convertible value for enumerations template ::value, detail::enabler> = detail::dummy> std::string value_string(const T &value) { return std::to_string( static_cast::type>(value)); } /// for other types just use the regular to_string function template ::value && !std::is_arithmetic::value, detail::enabler> = detail::dummy> auto value_string(const T &value) -> decltype(to_string(value)) { return to_string(value); } /// template to get the underlying value type if it exists or use a default template struct wrapped_type { using type = def; }; /// Type size for regular object types that do not look like a tuple template struct wrapped_type::value>::type> { using type = typename T::value_type; }; /// This will only trigger for actual void type template struct type_count_base { static const int value{0}; }; /// Type size for regular object types that do not look like a tuple template struct type_count_base< T, typename std::enable_if::value && !is_mutable_container::value && !std::is_void::value>::type> { static constexpr int value{1}; }; /// the base tuple size template struct type_count_base< T, typename std::enable_if::value && !is_mutable_container::value>::type> { static constexpr int value{std::tuple_size::value}; }; /// Type count base for containers is the type_count_base of the individual /// element template struct type_count_base< T, typename std::enable_if::value>::type> { static constexpr int value{type_count_base::value}; }; /// Set of overloads to get the type size of an object /// forward declare the subtype_count structure template struct subtype_count; /// forward declare the subtype_count_min structure template struct subtype_count_min; /// This will only trigger for actual void type template struct type_count { static const int 
value{0}; }; /// Type size for regular object types that do not look like a tuple template struct type_count::value && !is_tuple_like::value && !is_complex::value && !std::is_void::value>::type> { static constexpr int value{1}; }; /// Type size for complex since it sometimes looks like a wrapper template struct type_count::value>::type> { static constexpr int value{2}; }; /// Type size of types that are wrappers,except complex and tuples(which can /// also be wrappers sometimes) template struct type_count< T, typename std::enable_if::value>::type> { static constexpr int value{subtype_count::value}; }; /// Type size of types that are wrappers,except containers complex and /// tuples(which can also be wrappers sometimes) template struct type_count< T, typename std::enable_if::value && !is_complex::value && !is_tuple_like::value && !is_mutable_container::value>::type> { static constexpr int value{type_count::value}; }; /// 0 if the index > tuple size template constexpr typename std::enable_if::value, int>::type tuple_type_size() { return 0; } /// Recursively generate the tuple type name template constexpr typename std::enable_if < I::value, int>::type tuple_type_size() { return subtype_count::type>::value + tuple_type_size(); } /// Get the type size of the sum of type sizes for all the individual tuple /// types template struct type_count::value>::type> { static constexpr int value{tuple_type_size()}; }; /// definition of subtype count template struct subtype_count { static constexpr int value{is_mutable_container::value ? 
expected_max_vector_size : type_count::value}; }; /// This will only trigger for actual void type template struct type_count_min { static const int value{0}; }; /// Type size for regular object types that do not look like a tuple template struct type_count_min< T, typename std::enable_if::value && !is_tuple_like::value && !is_wrapper::value && !is_complex::value && !std::is_void::value>::type> { static constexpr int value{type_count::value}; }; /// Type size for complex since it sometimes looks like a wrapper template struct type_count_min::value>::type> { static constexpr int value{1}; }; /// Type size min of types that are wrappers,except complex and tuples(which can /// also be wrappers sometimes) template struct type_count_min< T, typename std::enable_if::value && !is_complex::value && !is_tuple_like::value>::type> { static constexpr int value{subtype_count_min::value}; }; /// 0 if the index > tuple size template constexpr typename std::enable_if::value, int>::type tuple_type_size_min() { return 0; } /// Recursively generate the tuple type name template constexpr typename std::enable_if < I::value, int>::type tuple_type_size_min() { return subtype_count_min::type>::value + tuple_type_size_min(); } /// Get the type size of the sum of type sizes for all the individual tuple /// types template struct type_count_min::value>::type> { static constexpr int value{tuple_type_size_min()}; }; /// definition of subtype count template struct subtype_count_min { static constexpr int value{ is_mutable_container::value ? ((type_count::value < expected_max_vector_size) ? 
type_count::value : 0) : type_count_min::value}; }; /// This will only trigger for actual void type template struct expected_count { static const int value{0}; }; /// For most types the number of expected items is 1 template struct expected_count::value && !is_wrapper::value && !std::is_void::value>::type> { static constexpr int value{1}; }; /// number of expected items in a vector template struct expected_count< T, typename std::enable_if::value>::type> { static constexpr int value{expected_max_vector_size}; }; /// number of expected items in a vector template struct expected_count::value && is_wrapper::value>::type> { static constexpr int value{expected_count::value}; }; // Enumeration of the different supported categorizations of objects enum class object_category : int { char_value = 1, integral_value = 2, unsigned_integral = 4, enumeration = 6, boolean_value = 8, floating_point = 10, number_constructible = 12, double_constructible = 14, integer_constructible = 16, // string like types string_assignable = 23, string_constructible = 24, other = 45, // special wrapper or container types wrapper_value = 50, complex_number = 60, tuple_value = 70, container_value = 80, }; /// Set of overloads to classify an object according to type /// some type that is not otherwise recognized template struct classify_object { static constexpr object_category value{object_category::other}; }; /// Signed integers template struct classify_object< T, typename std::enable_if::value && !std::is_same::value && std::is_signed::value && !is_bool::value && !std::is_enum::value>::type> { static constexpr object_category value{object_category::integral_value}; }; /// Unsigned integers template struct classify_object< T, typename std::enable_if< std::is_integral::value && std::is_unsigned::value && !std::is_same::value && !is_bool::value>::type> { static constexpr object_category value{object_category::unsigned_integral}; }; /// single character values template struct classify_object::value && 
!std::is_enum::value>::type> { static constexpr object_category value{object_category::char_value}; }; /// Boolean values template struct classify_object::value>::type> { static constexpr object_category value{object_category::boolean_value}; }; /// Floats template struct classify_object< T, typename std::enable_if::value>::type> { static constexpr object_category value{object_category::floating_point}; }; /// String and similar direct assignment template struct classify_object< T, typename std::enable_if< !std::is_floating_point::value && !std::is_integral::value && std::is_assignable::value>::type> { static constexpr object_category value{object_category::string_assignable}; }; /// String and similar constructible and copy assignment template struct classify_object< T, typename std::enable_if< !std::is_floating_point::value && !std::is_integral::value && !std::is_assignable::value && (type_count::value == 1) && std::is_constructible::value>::type> { static constexpr object_category value{object_category::string_constructible}; }; /// Enumerations template struct classify_object::value>::type> { static constexpr object_category value{object_category::enumeration}; }; template struct classify_object::value>::type> { static constexpr object_category value{object_category::complex_number}; }; /// Handy helper to contain a bunch of checks that rule out many common types /// (integers, string like, floating point, vectors, and enumerations template struct uncommon_type { using type = typename std::conditional< !std::is_floating_point::value && !std::is_integral::value && !std::is_assignable::value && !std::is_constructible::value && !is_complex::value && !is_mutable_container::value && !std::is_enum::value, std::true_type, std::false_type>::type; static constexpr bool value = type::value; }; /// wrapper type template struct classify_object< T, typename std::enable_if<( !is_mutable_container::value && is_wrapper::value && !is_tuple_like::value && 
uncommon_type::value)>::type> { static constexpr object_category value{object_category::wrapper_value}; }; /// Assignable from double or int template struct classify_object< T, typename std::enable_if< uncommon_type::value && type_count::value == 1 && !is_wrapper::value && is_direct_constructible::value && is_direct_constructible::value>::type> { static constexpr object_category value{object_category::number_constructible}; }; /// Assignable from int template struct classify_object< T, typename std::enable_if< uncommon_type::value && type_count::value == 1 && !is_wrapper::value && !is_direct_constructible::value && is_direct_constructible::value>::type> { static constexpr object_category value{ object_category::integer_constructible}; }; /// Assignable from double template struct classify_object< T, typename std::enable_if< uncommon_type::value && type_count::value == 1 && !is_wrapper::value && is_direct_constructible::value && !is_direct_constructible::value>::type> { static constexpr object_category value{object_category::double_constructible}; }; /// Tuple type template struct classify_object< T, typename std::enable_if< is_tuple_like::value && ((type_count::value >= 2 && !is_wrapper::value) || (uncommon_type::value && !is_direct_constructible::value && !is_direct_constructible::value) || (uncommon_type::value && type_count::value >= 2))>::type> { static constexpr object_category value{object_category::tuple_value}; // the condition on this class requires it be like a tuple, but on some // compilers (like Xcode) tuples can be constructed from just the first // element so tuples of can be constructed from a string, // which could lead to issues so there are two variants of the condition, the // first isolates things with a type size >=2 mainly to get tuples on Xcode // with the exception of wrappers, the second is the main one and just // separating out those cases that are caught by other object classifications }; /// container type template struct 
classify_object< T, typename std::enable_if::value>::type> { static constexpr object_category value{object_category::container_value}; }; // Type name print /// Was going to be based on /// http://stackoverflow.com/questions/1055452/c-get-name-of-type-in-template /// But this is cleaner and works better in this case template ::value == object_category::char_value, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "CHAR"; } template < typename T, enable_if_t::value == object_category::integral_value || classify_object::value == object_category::integer_constructible, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "INT"; } template ::value == object_category::unsigned_integral, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "UINT"; } template < typename T, enable_if_t::value == object_category::floating_point || classify_object::value == object_category::number_constructible || classify_object::value == object_category::double_constructible, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "FLOAT"; } /// Print name for enumeration types template ::value == object_category::enumeration, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "ENUM"; } /// Print name for enumeration types template ::value == object_category::boolean_value, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "BOOLEAN"; } /// Print name for enumeration types template ::value == object_category::complex_number, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "COMPLEX"; } /// Print for all other types template ::value >= object_category::string_assignable && classify_object::value <= object_category::other, detail::enabler> = detail::dummy> constexpr const char *type_name() { return "TEXT"; } /// typename for tuple value template ::value == object_category::tuple_value && type_count_base::value >= 2, 
detail::enabler> = detail::dummy> std::string type_name(); // forward declaration /// Generate type name for a wrapper or container value template < typename T, enable_if_t::value == object_category::container_value || classify_object::value == object_category::wrapper_value, detail::enabler> = detail::dummy> std::string type_name(); // forward declaration /// Print name for single element tuple types template ::value == object_category::tuple_value && type_count_base::value == 1, detail::enabler> = detail::dummy> inline std::string type_name() { return type_name< typename std::decay::type>::type>(); } /// Empty string if the index > tuple size template inline typename std::enable_if::value, std::string>::type tuple_name() { return std::string{}; } /// Recursively generate the tuple type name template inline typename std::enable_if<(I < type_count_base::value), std::string>::type tuple_name() { auto str = std::string{type_name::type>::type>()} + ',' + tuple_name(); if (str.back() == ',') str.pop_back(); return str; } /// Print type name for tuples with 2 or more elements template ::value == object_category::tuple_value && type_count_base::value >= 2, detail::enabler>> inline std::string type_name() { auto tname = std::string(1, '[') + tuple_name(); tname.push_back(']'); return tname; } /// get the type name for a type that has a value_type member template < typename T, enable_if_t::value == object_category::container_value || classify_object::value == object_category::wrapper_value, detail::enabler>> inline std::string type_name() { return type_name(); } // Lexical cast /// Convert to an unsigned integral template ::value, detail::enabler> = detail::dummy> bool integral_conversion(const std::string &input, T &output) noexcept { if (input.empty() || input.front() == '-') { return false; } char *val = nullptr; errno = 0; std::uint64_t output_ll = std::strtoull(input.c_str(), &val, 0); if (errno == ERANGE) { return false; } output = static_cast(output_ll); if (val == 
(input.c_str() + input.size()) && static_cast(output) == output_ll) { return true; } val = nullptr; std::int64_t output_sll = std::strtoll(input.c_str(), &val, 0); if (val == (input.c_str() + input.size())) { output = (output_sll < 0) ? static_cast(0) : static_cast(output_sll); return (static_cast(output) == output_sll); } return false; } /// Convert to a signed integral template ::value, detail::enabler> = detail::dummy> bool integral_conversion(const std::string &input, T &output) noexcept { if (input.empty()) { return false; } char *val = nullptr; errno = 0; std::int64_t output_ll = std::strtoll(input.c_str(), &val, 0); if (errno == ERANGE) { return false; } output = static_cast(output_ll); if (val == (input.c_str() + input.size()) && static_cast(output) == output_ll) { return true; } if (input == "true") { // this is to deal with a few oddities with flags and wrapper int types output = static_cast(1); return true; } return false; } /// Convert a flag into an integer value typically binary flags inline std::int64_t to_flag_value(std::string val) { static const std::string trueString("true"); static const std::string falseString("false"); if (val == trueString) { return 1; } if (val == falseString) { return -1; } val = detail::to_lower(val); std::int64_t ret = 0; if (val.size() == 1) { if (val[0] >= '1' && val[0] <= '9') { return (static_cast(val[0]) - '0'); } switch (val[0]) { case '0': case 'f': case 'n': case '-': ret = -1; break; case 't': case 'y': case '+': ret = 1; break; default: throw std::invalid_argument("unrecognized character"); } return ret; } if (val == trueString || val == "on" || val == "yes" || val == "enable") { ret = 1; } else if (val == falseString || val == "off" || val == "no" || val == "disable") { ret = -1; } else { ret = std::stoll(val); } return ret; } /// Integer conversion template ::value == object_category::integral_value || classify_object::value == object_category::unsigned_integral, detail::enabler> = detail::dummy> bool 
lexical_cast(const std::string &input, T &output) { return integral_conversion(input, output); } /// char values template ::value == object_category::char_value, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { if (input.size() == 1) { output = static_cast(input[0]); return true; } return integral_conversion(input, output); } /// Boolean values template ::value == object_category::boolean_value, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { try { auto out = to_flag_value(input); output = (out > 0); return true; } catch (const std::invalid_argument &) { return false; } catch (const std::out_of_range &) { // if the number is out of the range of a 64 bit value then it is still a // number and for this purpose is still valid all we care about the sign output = (input[0] != '-'); return true; } } /// Floats template ::value == object_category::floating_point, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { if (input.empty()) { return false; } char *val = nullptr; auto output_ld = std::strtold(input.c_str(), &val); output = static_cast(output_ld); return val == (input.c_str() + input.size()); } /// complex template ::value == object_category::complex_number, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { using XC = typename wrapped_type::type; XC x{0.0}, y{0.0}; auto str1 = input; bool worked = false; auto nloc = str1.find_last_of("+-"); if (nloc != std::string::npos && nloc > 0) { worked = lexical_cast(str1.substr(0, nloc), x); str1 = str1.substr(nloc); if (str1.back() == 'i' || str1.back() == 'j') str1.pop_back(); worked = worked && lexical_cast(str1, y); } else { if (str1.back() == 'i' || str1.back() == 'j') { str1.pop_back(); worked = lexical_cast(str1, y); x = XC{0}; } else { worked = lexical_cast(str1, x); y = XC{0}; } } if (worked) { output = T{x, y}; return worked; } return from_stream(input, 
output); } /// String and similar direct assignment template ::value == object_category::string_assignable, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { output = input; return true; } /// String and similar constructible and copy assignment template ::value == object_category::string_constructible, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { output = T(input); return true; } /// Enumerations template ::value == object_category::enumeration, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { typename std::underlying_type::type val; if (!integral_conversion(input, val)) { return false; } output = static_cast(val); return true; } /// wrapper types template < typename T, enable_if_t::value == object_category::wrapper_value && std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { typename T::value_type val; if (lexical_cast(input, val)) { output = val; return true; } return from_stream(input, output); } template < typename T, enable_if_t::value == object_category::wrapper_value && !std::is_assignable::value && std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { typename T::value_type val; if (lexical_cast(input, val)) { output = T{val}; return true; } return from_stream(input, output); } /// Assignable from double or int template ::value == object_category::number_constructible, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { int val = 0; if (integral_conversion(input, val)) { output = T(val); return true; } double dval = 0.0; if (lexical_cast(input, dval)) { output = T{dval}; return true; } return from_stream(input, output); } /// Assignable from int template ::value == object_category::integer_constructible, detail::enabler> = detail::dummy> bool lexical_cast(const 
std::string &input, T &output) { int val = 0; if (integral_conversion(input, val)) { output = T(val); return true; } return from_stream(input, output); } /// Assignable from double template ::value == object_category::double_constructible, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { double val = 0.0; if (lexical_cast(input, val)) { output = T{val}; return true; } return from_stream(input, output); } /// Non-string convertible from an int template ::value == object_category::other && std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { int val = 0; if (integral_conversion(input, val)) { #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable : 4800) #endif // with Atomic this could produce a warning due to the conversion but if // atomic gets here it is an old style so will most likely still work output = val; #ifdef _MSC_VER #pragma warning(pop) #endif return true; } // LCOV_EXCL_START // This version of cast is only used for odd cases in an older compilers the // fail over from_stream is tested elsewhere an not relevant for coverage here return from_stream(input, output); // LCOV_EXCL_STOP } /// Non-string parsable by a stream template ::value == object_category::other && !std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_cast(const std::string &input, T &output) { static_assert(is_istreamable::value, "option object type must have a lexical cast overload or " "streaming input operator(>>) defined, if it " "is convertible from another type use the add_option(...) 
with XC being the known type"); return from_stream(input, output); } /// Assign a value through lexical cast operations /// Strings can be empty so we need to do a little different template ::value && (classify_object::value == object_category::string_assignable || classify_object::value == object_category::string_constructible), detail::enabler> = detail::dummy> bool lexical_assign(const std::string &input, AssignTo &output) { return lexical_cast(input, output); } /// Assign a value through lexical cast operations template ::value && std::is_assignable::value && classify_object::value != object_category::string_assignable && classify_object::value != object_category::string_constructible, detail::enabler> = detail::dummy> bool lexical_assign(const std::string &input, AssignTo &output) { if (input.empty()) { output = AssignTo{}; return true; } return lexical_cast(input, output); } /// Assign a value through lexical cast operations template ::value && !std::is_assignable::value && classify_object::value == object_category::wrapper_value, detail::enabler> = detail::dummy> bool lexical_assign(const std::string &input, AssignTo &output) { if (input.empty()) { typename AssignTo::value_type emptyVal{}; output = emptyVal; return true; } return lexical_cast(input, output); } /// Assign a value through lexical cast operations for int compatible values /// mainly for atomic operations on some compilers template ::value && !std::is_assignable::value && classify_object::value != object_category::wrapper_value && std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_assign(const std::string &input, AssignTo &output) { if (input.empty()) { output = 0; return true; } int val = 0; if (lexical_cast(input, val)) { output = val; return true; } return false; } /// Assign a value converted from a string in lexical cast to the output value /// directly template ::value && std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_assign(const 
std::string &input, AssignTo &output) { ConvertTo val{}; bool parse_result = (!input.empty()) ? lexical_cast(input, val) : true; if (parse_result) { output = val; } return parse_result; } /// Assign a value from a lexical cast through constructing a value and move /// assigning it template ::value && !std::is_assignable::value && std::is_move_assignable::value, detail::enabler> = detail::dummy> bool lexical_assign(const std::string &input, AssignTo &output) { ConvertTo val{}; bool parse_result = input.empty() ? true : lexical_cast(input, val); if (parse_result) { output = AssignTo( val); // use () form of constructor to allow some implicit conversions } return parse_result; } /// primary lexical conversion operation, 1 string to 1 type of some kind template < typename AssignTo, typename ConvertTo, enable_if_t::value <= object_category::other && classify_object::value <= object_category::wrapper_value, detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output) { return lexical_assign(strings[0], output); } /// Lexical conversion if there is only one element but the conversion type is /// for two, then call a two element constructor template ::value <= 2) && expected_count::value == 1 && is_tuple_like::value && type_count_base::value == 2, detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output) { // the remove const is to handle pair types coming from a container typename std::remove_const< typename std::tuple_element<0, ConvertTo>::type>::type v1; typename std::tuple_element<1, ConvertTo>::type v2; bool retval = lexical_assign(strings[0], v1); if (strings.size() > 1) { retval = retval && lexical_assign(strings[1], v2); } if (retval) { output = AssignTo{v1, v2}; } return retval; } /// Lexical conversion of a container types of single elements template ::value && is_mutable_container::value && type_count::value == 1, detail::enabler> = detail::dummy> bool 
lexical_conversion(const std::vector &strings, AssignTo &output) { output.erase(output.begin(), output.end()); if (strings.size() == 1 && strings[0] == "{}") { return true; } bool skip_remaining = false; if (strings.size() == 2 && strings[0] == "{}" && is_separator(strings[1])) { skip_remaining = true; } for (const auto &elem : strings) { typename AssignTo::value_type out; bool retval = lexical_assign(elem, out); if (!retval) { return false; } output.insert(output.end(), std::move(out)); if (skip_remaining) { break; } } return (!output.empty()); } /// Lexical conversion for complex types template < class AssignTo, class ConvertTo, enable_if_t::value, detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output) { if (strings.size() >= 2 && !strings[1].empty()) { using XC2 = typename wrapped_type::type; XC2 x{0.0}, y{0.0}; auto str1 = strings[1]; if (str1.back() == 'i' || str1.back() == 'j') { str1.pop_back(); } auto worked = lexical_cast(strings[0], x) && lexical_cast(str1, y); if (worked) { output = ConvertTo{x, y}; } return worked; } return lexical_assign(strings[0], output); } /// Conversion to a vector type using a particular single type as the conversion /// type template ::value && (expected_count::value == 1) && (type_count::value == 1), detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output) { bool retval = true; output.clear(); output.reserve(strings.size()); for (const auto &elem : strings) { output.emplace_back(); retval = retval && lexical_assign( elem, output.back()); } return (!output.empty()) && retval; } // forward declaration /// Lexical conversion of a container types with conversion type of two elements template ::value && is_mutable_container::value && type_count_base::value == 2, detail::enabler> = detail::dummy> bool lexical_conversion(std::vector strings, AssignTo &output); /// Lexical conversion of a vector types with type_size >2 forward declaration 
template ::value && is_mutable_container::value && type_count_base::value != 2 && ((type_count::value > 2) || (type_count::value > type_count_base::value)), detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output); /// Conversion for tuples template ::value && is_tuple_like::value && (type_count_base::value != type_count::value || type_count::value > 2), detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output); // forward declaration /// Conversion for operations where the assigned type is some class but the /// conversion is a mutable container or large tuple template ::value && !is_mutable_container::value && classify_object::value != object_category::wrapper_value && (is_mutable_container::value || type_count::value > 2), detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output) { if (strings.size() > 1 || (!strings.empty() && !(strings.front().empty()))) { ConvertTo val; auto retval = lexical_conversion(strings, val); output = AssignTo{val}; return retval; } output = AssignTo{}; return true; } /// function template for converting tuples if the static Index is greater than /// the tuple size template inline typename std::enable_if<(I >= type_count_base::value), bool>::type tuple_conversion(const std::vector &, AssignTo &) { return true; } /// Conversion of a tuple element where the type size ==1 and not a mutable /// container template inline typename std::enable_if::value && type_count::value == 1, bool>::type tuple_type_conversion(std::vector &strings, AssignTo &output) { auto retval = lexical_assign(strings[0], output); strings.erase(strings.begin()); return retval; } /// Conversion of a tuple element where the type size !=1 but the size is fixed /// and not a mutable container template inline typename std::enable_if::value && (type_count::value > 1) && type_count::value == type_count_min::value, bool>::type 
tuple_type_conversion(std::vector &strings, AssignTo &output) { auto retval = lexical_conversion(strings, output); strings.erase(strings.begin(), strings.begin() + type_count::value); return retval; } /// Conversion of a tuple element where the type is a mutable container or a /// type with different min and max type sizes template inline typename std::enable_if::value || type_count::value != type_count_min::value, bool>::type tuple_type_conversion(std::vector &strings, AssignTo &output) { std::size_t index{subtype_count_min::value}; const std::size_t mx_count{subtype_count::value}; const std::size_t mx{(std::max)(mx_count, strings.size())}; while (index < mx) { if (is_separator(strings[index])) { break; } ++index; } bool retval = lexical_conversion( std::vector( strings.begin(), strings.begin() + static_cast(index)), output); strings.erase(strings.begin(), strings.begin() + static_cast(index) + 1); return retval; } /// Tuple conversion operation template inline typename std::enable_if<(I < type_count_base::value), bool>::type tuple_conversion(std::vector strings, AssignTo &output) { bool retval = true; using ConvertToElement = typename std::conditional::value, typename std::tuple_element::type, ConvertTo>::type; if (!strings.empty()) { retval = retval && tuple_type_conversion::type, ConvertToElement>(strings, std::get(output)); } retval = retval && tuple_conversion( std::move(strings), output); return retval; } /// Lexical conversion of a container types with tuple elements of size 2 template ::value && is_mutable_container::value && type_count_base::value == 2, detail::enabler>> bool lexical_conversion(std::vector strings, AssignTo &output) { output.clear(); while (!strings.empty()) { typename std::remove_const::type>::type v1; typename std::tuple_element<1, typename ConvertTo::value_type>::type v2; bool retval = tuple_type_conversion(strings, v1); if (!strings.empty()) { retval = retval && tuple_type_conversion(strings, v2); } if (retval) { 
output.insert(output.end(), typename AssignTo::value_type{v1, v2}); } else { return false; } } return (!output.empty()); } /// lexical conversion of tuples with type count>2 or tuples of types of some /// element with a type size>=2 template ::value && is_tuple_like::value && (type_count_base::value != type_count::value || type_count::value > 2), detail::enabler>> bool lexical_conversion(const std::vector &strings, AssignTo &output) { static_assert( !is_tuple_like::value || type_count_base::value == type_count_base::value, "if the conversion type is defined as a tuple it must be the same size " "as the type you are converting to"); return tuple_conversion(strings, output); } /// Lexical conversion of a vector types for everything but tuples of two /// elements and types of size 1 template ::value && is_mutable_container::value && type_count_base::value != 2 && ((type_count::value > 2) || (type_count::value > type_count_base::value)), detail::enabler>> bool lexical_conversion(const std::vector &strings, AssignTo &output) { bool retval = true; output.clear(); std::vector temp; std::size_t ii{0}; std::size_t icount{0}; std::size_t xcm{type_count::value}; auto ii_max = strings.size(); while (ii < ii_max) { temp.push_back(strings[ii]); ++ii; ++icount; if (icount == xcm || is_separator(temp.back()) || ii == ii_max) { if (static_cast(xcm) > type_count_min::value && is_separator(temp.back())) { temp.pop_back(); } typename AssignTo::value_type temp_out; retval = retval && lexical_conversion( temp, temp_out); temp.clear(); if (!retval) { return false; } output.insert(output.end(), std::move(temp_out)); icount = 0; } } return retval; } /// conversion for wrapper types template ::value == object_category::wrapper_value && std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output) { if (strings.empty() || strings.front().empty()) { output = ConvertTo{}; return true; } typename ConvertTo::value_type val; if 
(lexical_conversion(strings, val)) { output = ConvertTo{val}; return true; } return false; } /// conversion for wrapper types template ::value == object_category::wrapper_value && !std::is_assignable::value, detail::enabler> = detail::dummy> bool lexical_conversion(const std::vector &strings, AssignTo &output) { using ConvertType = typename ConvertTo::value_type; if (strings.empty() || strings.front().empty()) { output = ConvertType{}; return true; } ConvertType val; if (lexical_conversion(strings, val)) { output = val; return true; } return false; } /// Sum a vector of strings inline std::string sum_string_vector(const std::vector &values) { double val{0.0}; bool fail{false}; std::string output; for (const auto &arg : values) { double tv{0.0}; auto comp = lexical_cast(arg, tv); if (!comp) { try { tv = static_cast(detail::to_flag_value(arg)); } catch (const std::exception &) { fail = true; break; } } val += tv; } if (fail) { for (const auto &arg : values) { output.append(arg); } } else { if (val <= static_cast((std::numeric_limits::min)()) || val >= static_cast((std::numeric_limits::max)()) || std::ceil(val) == std::floor(val)) { output = detail::value_string(static_cast(val)); } else { output = detail::value_string(val); } } return output; } } // namespace detail namespace detail { // Returns false if not a short option. Otherwise, sets opt name and rest and // returns true CLI11_INLINE bool split_short(const std::string ¤t, std::string &name, std::string &rest); // Returns false if not a long option. Otherwise, sets opt name and other side // of = and returns true CLI11_INLINE bool split_long(const std::string ¤t, std::string &name, std::string &value); // Returns false if not a windows style option. 
Otherwise, sets opt name and // value and returns true CLI11_INLINE bool split_windows_style(const std::string ¤t, std::string &name, std::string &value); // Splits a string into multiple long and short names CLI11_INLINE std::vector split_names(std::string current); /// extract default flag values either {def} or starting with a ! CLI11_INLINE std::vector> get_default_flag_values(const std::string &str); /// Get a vector of short names, one of long names, and a single name CLI11_INLINE std::tuple, std::vector, std::string> get_names(const std::vector &input); } // namespace detail namespace detail { CLI11_INLINE bool split_short(const std::string ¤t, std::string &name, std::string &rest) { if (current.size() > 1 && current[0] == '-' && valid_first_char(current[1])) { name = current.substr(1, 1); rest = current.substr(2); return true; } return false; } CLI11_INLINE bool split_long(const std::string ¤t, std::string &name, std::string &value) { if (current.size() > 2 && current.substr(0, 2) == "--" && valid_first_char(current[2])) { auto loc = current.find_first_of('='); if (loc != std::string::npos) { name = current.substr(2, loc - 2); value = current.substr(loc + 1); } else { name = current.substr(2); value = ""; } return true; } return false; } CLI11_INLINE bool split_windows_style(const std::string ¤t, std::string &name, std::string &value) { if (current.size() > 1 && current[0] == '/' && valid_first_char(current[1])) { auto loc = current.find_first_of(':'); if (loc != std::string::npos) { name = current.substr(1, loc - 1); value = current.substr(loc + 1); } else { name = current.substr(1); value = ""; } return true; } return false; } CLI11_INLINE std::vector split_names(std::string current) { std::vector output; std::size_t val = 0; while ((val = current.find(',')) != std::string::npos) { output.push_back(trim_copy(current.substr(0, val))); current = current.substr(val + 1); } output.push_back(trim_copy(current)); return output; } CLI11_INLINE std::vector> 
get_default_flag_values(const std::string &str) { std::vector flags = split_names(str); flags.erase( std::remove_if(flags.begin(), flags.end(), [](const std::string &name) { return ( (name.empty()) || (!(((name.find_first_of('{') != std::string::npos) && (name.back() == '}')) || (name[0] == '!')))); }), flags.end()); std::vector> output; output.reserve(flags.size()); for (auto &flag : flags) { auto def_start = flag.find_first_of('{'); std::string defval = "false"; if ((def_start != std::string::npos) && (flag.back() == '}')) { defval = flag.substr(def_start + 1); defval.pop_back(); flag.erase( def_start, std::string::npos); // NOLINT(readability-suspicious-call-argument) } flag.erase(0, flag.find_first_not_of("-!")); output.emplace_back(flag, defval); } return output; } CLI11_INLINE std::tuple, std::vector, std::string> get_names(const std::vector &input) { std::vector short_names; std::vector long_names; std::string pos_name; for (std::string name : input) { if (name.length() == 0) { continue; } if (name.length() > 1 && name[0] == '-' && name[1] != '-') { if (name.length() == 2 && valid_first_char(name[1])) short_names.emplace_back(1, name[1]); else throw BadNameString::OneCharName(name); } else if (name.length() > 2 && name.substr(0, 2) == "--") { name = name.substr(2); if (valid_name_string(name)) long_names.push_back(name); else throw BadNameString::BadLongName(name); } else if (name == "-" || name == "--") { throw BadNameString::DashesOnly(name); } else { if (pos_name.length() > 0) throw BadNameString::MultiPositionalNames(name); pos_name = name; } } return std::make_tuple(short_names, long_names, pos_name); } } // namespace detail class App; /// Holds values to load into Options struct ConfigItem { /// This is the list of parents std::vector parents{}; /// This is the name std::string name{}; /// Listing of inputs std::vector inputs{}; /// The list of parents and name joined by "." 
CLI11_NODISCARD std::string fullname() const { std::vector tmp = parents; tmp.emplace_back(name); return detail::join(tmp, "."); } }; /// This class provides a converter for configuration files. class Config { protected: std::vector items{}; public: /// Convert an app into a configuration virtual std::string to_config(const App *, bool, bool, std::string) const = 0; /// Convert a configuration into an app virtual std::vector from_config(std::istream &) const = 0; /// Get a flag value CLI11_NODISCARD virtual std::string to_flag(const ConfigItem &item) const { if (item.inputs.size() == 1) { return item.inputs.at(0); } if (item.inputs.empty()) { return "{}"; } throw ConversionError::TooManyInputsFlag( item.fullname()); // LCOV_EXCL_LINE } /// Parse a config file, throw an error (ParseError:ConfigParseError or /// FileError) on failure CLI11_NODISCARD std::vector from_file( const std::string &name) const { std::ifstream input{name}; if (!input.good()) throw FileError::Missing(name); return from_config(input); } /// Virtual destructor virtual ~Config() = default; }; /// This converter works with INI/TOML files; to write INI files use ConfigINI class ConfigBase : public Config { protected: /// the character used for comments char commentChar = '#'; /// the character used to start an array '\0' is a default to not use char arrayStart = '['; /// the character used to end an array '\0' is a default to not use char arrayEnd = ']'; /// the character used to separate elements in an array char arraySeparator = ','; /// the character used separate the name from the value char valueDelimiter = '='; /// the character to use around strings char stringQuote = '"'; /// the character to use around single characters char characterQuote = '\''; /// the maximum number of layers to allow uint8_t maximumLayers{255}; /// the separator used to separator parent layers char parentSeparatorChar{'.'}; /// Specify the configuration index to use for arrayed sections int16_t configIndex{-1}; /// 
Specify the configuration section that should be used std::string configSection{}; public: std::string to_config(const App * /*app*/, bool default_also, bool write_description, std::string prefix) const override; std::vector from_config(std::istream &input) const override; /// Specify the configuration for comment characters ConfigBase *comment(char cchar) { commentChar = cchar; return this; } /// Specify the start and end characters for an array ConfigBase *arrayBounds(char aStart, char aEnd) { arrayStart = aStart; arrayEnd = aEnd; return this; } /// Specify the delimiter character for an array ConfigBase *arrayDelimiter(char aSep) { arraySeparator = aSep; return this; } /// Specify the delimiter between a name and value ConfigBase *valueSeparator(char vSep) { valueDelimiter = vSep; return this; } /// Specify the quote characters used around strings and characters ConfigBase *quoteCharacter(char qString, char qChar) { stringQuote = qString; characterQuote = qChar; return this; } /// Specify the maximum number of parents ConfigBase *maxLayers(uint8_t layers) { maximumLayers = layers; return this; } /// Specify the separator to use for parent layers ConfigBase *parentSeparator(char sep) { parentSeparatorChar = sep; return this; } /// get a reference to the configuration section std::string §ionRef() { return configSection; } /// get the section CLI11_NODISCARD const std::string §ion() const { return configSection; } /// specify a particular section of the configuration file to use ConfigBase *section(const std::string §ionName) { configSection = sectionName; return this; } /// get a reference to the configuration index int16_t &indexRef() { return configIndex; } /// get the section index CLI11_NODISCARD int16_t index() const { return configIndex; } /// specify a particular index in the section to use (-1) for all sections to /// use ConfigBase *index(int16_t sectionIndex) { configIndex = sectionIndex; return this; } }; /// the default Config is the TOML file format 
using ConfigTOML = ConfigBase; /// ConfigINI generates a "standard" INI compliant output class ConfigINI : public ConfigTOML { public: ConfigINI() { commentChar = ';'; arrayStart = '\0'; arrayEnd = '\0'; arraySeparator = ' '; valueDelimiter = '='; } }; class Option; /// @defgroup validator_group Validators /// @brief Some validators that are provided /// /// These are simple `std::string(const std::string&)` validators that are /// useful. They return a string if the validation fails. A custom struct is /// provided, as well, with the same user semantics, but with the ability to /// provide a new type name. /// @{ /// class Validator { protected: /// This is the description function, if empty the description_ will be used std::function desc_function_{[]() { return std::string{}; }}; /// This is the base function that is to be called. /// Returns a string error message if validation fails. std::function func_{ [](std::string &) { return std::string{}; }}; /// The name for search purposes of the Validator std::string name_{}; /// A Validator will only apply to an indexed value (-1 is all elements) int application_index_ = -1; /// Enable for Validator to allow it to be disabled if need be bool active_{true}; /// specify that a validator should not modify the input bool non_modifying_{false}; Validator(std::string validator_desc, std::function func) : desc_function_([validator_desc]() { return validator_desc; }), func_(std::move(func)) {} public: Validator() = default; /// Construct a Validator with just the description string explicit Validator(std::string validator_desc) : desc_function_([validator_desc]() { return validator_desc; }) {} /// Construct Validator from basic information Validator(std::function op, std::string validator_desc, std::string validator_name = "") : desc_function_([validator_desc]() { return validator_desc; }), func_(std::move(op)), name_(std::move(validator_name)) {} /// Set the Validator operation function Validator &operation(std::function 
op) { func_ = std::move(op); return *this; } /// This is the required operator for a Validator - provided to help /// users (CLI11 uses the member `func` directly) std::string operator()(std::string &str) const; /// This is the required operator for a Validator - provided to help /// users (CLI11 uses the member `func` directly) std::string operator()(const std::string &str) const { std::string value = str; return (active_) ? func_(value) : std::string{}; } /// Specify the type string Validator &description(std::string validator_desc) { desc_function_ = [validator_desc]() { return validator_desc; }; return *this; } /// Specify the type string CLI11_NODISCARD Validator description(std::string validator_desc) const; /// Generate type description information for the Validator CLI11_NODISCARD std::string get_description() const { if (active_) { return desc_function_(); } return std::string{}; } /// Specify the type string Validator &name(std::string validator_name) { name_ = std::move(validator_name); return *this; } /// Specify the type string CLI11_NODISCARD Validator name(std::string validator_name) const { Validator newval(*this); newval.name_ = std::move(validator_name); return newval; } /// Get the name of the Validator CLI11_NODISCARD const std::string &get_name() const { return name_; } /// Specify whether the Validator is active or not Validator &active(bool active_val = true) { active_ = active_val; return *this; } /// Specify whether the Validator is active or not CLI11_NODISCARD Validator active(bool active_val = true) const { Validator newval(*this); newval.active_ = active_val; return newval; } /// Specify whether the Validator can be modifying or not Validator &non_modifying(bool no_modify = true) { non_modifying_ = no_modify; return *this; } /// Specify the application index of a validator Validator &application_index(int app_index) { application_index_ = app_index; return *this; } /// Specify the application index of a validator CLI11_NODISCARD 
Validator application_index(int app_index) const { Validator newval(*this); newval.application_index_ = app_index; return newval; } /// Get the current value of the application index CLI11_NODISCARD int get_application_index() const { return application_index_; } /// Get a boolean if the validator is active CLI11_NODISCARD bool get_active() const { return active_; } /// Get a boolean if the validator is allowed to modify the input returns true /// if it can modify the input CLI11_NODISCARD bool get_modifying() const { return !non_modifying_; } /// Combining validators is a new validator. Type comes from left validator if /// function, otherwise only set if the same. Validator operator&(const Validator &other) const; /// Combining validators is a new validator. Type comes from left validator if /// function, otherwise only set if the same. Validator operator|(const Validator &other) const; /// Create a validator that fails when a given validator succeeds Validator operator!() const; private: void _merge_description(const Validator &val1, const Validator &val2, const std::string &merger); }; /// Class wrapping some of the accessors of Validator class CustomValidator : public Validator { public: }; // The implementation of the built in validators is using the Validator class; // the user is only expected to use the const (static) versions (since there's // no setup). Therefore, this is in detail. 
namespace detail {

/// CLI enumeration of different file types
enum class path_type { nonexistent, file, directory };

/// get the type of the path from a file name
CLI11_INLINE path_type check_path(const char *file) noexcept;

/// Check for an existing file (returns error message if check fails)
class ExistingFileValidator : public Validator {
 public:
  ExistingFileValidator();
};

/// Check for an existing directory (returns error message if check fails)
class ExistingDirectoryValidator : public Validator {
 public:
  ExistingDirectoryValidator();
};

/// Check for an existing path
class ExistingPathValidator : public Validator {
 public:
  ExistingPathValidator();
};

/// Check for a non-existing path
class NonexistentPathValidator : public Validator {
 public:
  NonexistentPathValidator();
};

/// Validate the given string is a legal ipv4 address
class IPV4Validator : public Validator {
 public:
  IPV4Validator();
};

}  // namespace detail

// Static is not needed here, because global const implies static.
/// Check for existing file (returns error message if check fails) const detail::ExistingFileValidator ExistingFile; /// Check for an existing directory (returns error message if check fails) const detail::ExistingDirectoryValidator ExistingDirectory; /// Check for an existing path const detail::ExistingPathValidator ExistingPath; /// Check for an non-existing path const detail::NonexistentPathValidator NonexistentPath; /// Check for an IP4 address const detail::IPV4Validator ValidIPV4; /// Validate the input as a particular type template class TypeValidator : public Validator { public: explicit TypeValidator(const std::string &validator_name) : Validator(validator_name, [](std::string &input_string) { using CLI::detail::lexical_cast; auto val = DesiredType(); if (!lexical_cast(input_string, val)) { return std::string("Failed parsing ") + input_string + " as a " + detail::type_name(); } return std::string(); }) {} TypeValidator() : TypeValidator(detail::type_name()) {} }; /// Check for a number const TypeValidator Number("NUMBER"); /// Modify a path if the file is a particular default location, can be used as /// Check or transform with the error return optionally disabled class FileOnDefaultPath : public Validator { public: explicit FileOnDefaultPath(std::string default_path, bool enableErrorReturn = true); }; /// Produce a range (factory). Min and max are inclusive. class Range : public Validator { public: /// This produces a range with min and max inclusive. /// /// Note that the constructor is templated, but the struct is not, so C++17 is /// not needed to provide nice syntax for Range(a,b). 
template Range(T min_val, T max_val, const std::string &validator_name = std::string{}) : Validator(validator_name) { if (validator_name.empty()) { std::stringstream out; out << detail::type_name() << " in [" << min_val << " - " << max_val << "]"; description(out.str()); } func_ = [min_val, max_val](std::string &input) { using CLI::detail::lexical_cast; T val; bool converted = lexical_cast(input, val); if ((!converted) || (val < min_val || val > max_val)) { std::stringstream out; out << "Value " << input << " not in range ["; out << min_val << " - " << max_val << "]"; return out.str(); } return std::string{}; }; } /// Range of one value is 0 to value template explicit Range(T max_val, const std::string &validator_name = std::string{}) : Range(static_cast(0), max_val, validator_name) {} }; /// Check for a non negative number const Range NonNegativeNumber((std::numeric_limits::max)(), "NONNEGATIVE"); /// Check for a positive valued number (val>0.0), ::min here is the /// smallest positive number const Range PositiveNumber((std::numeric_limits::min)(), (std::numeric_limits::max)(), "POSITIVE"); /// Produce a bounded range (factory). Min and max are inclusive. class Bound : public Validator { public: /// This bounds a value with min and max inclusive. /// /// Note that the constructor is templated, but the struct is not, so C++17 is /// not needed to provide nice syntax for Range(a,b). 
template Bound(T min_val, T max_val) { std::stringstream out; out << detail::type_name() << " bounded to [" << min_val << " - " << max_val << "]"; description(out.str()); func_ = [min_val, max_val](std::string &input) { using CLI::detail::lexical_cast; T val; bool converted = lexical_cast(input, val); if (!converted) { return std::string("Value ") + input + " could not be converted"; } if (val < min_val) input = detail::to_string(min_val); else if (val > max_val) input = detail::to_string(max_val); return std::string{}; }; } /// Range of one value is 0 to value template explicit Bound(T max_val) : Bound(static_cast(0), max_val) {} }; namespace detail { template < typename T, enable_if_t::type>::value, detail::enabler> = detail::dummy> auto smart_deref(T value) -> decltype(*value) { return *value; } template ::type>::value, detail::enabler> = detail::dummy> typename std::remove_reference::type &smart_deref(T &value) { return value; } /// Generate a string representation of a set template std::string generate_set(const T &set) { using element_t = typename detail::element_type::type; using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the // object pair std::string out(1, '{'); out.append(detail::join( detail::smart_deref(set), [](const iteration_type_t &v) { return detail::pair_adaptor::first(v); }, ",")); out.push_back('}'); return out; } /// Generate a string representation of a map template std::string generate_map(const T &map, bool key_only = false) { using element_t = typename detail::element_type::type; using iteration_type_t = typename detail::pair_adaptor::value_type; // the type of the // object pair std::string out(1, '{'); out.append(detail::join( detail::smart_deref(map), [key_only](const iteration_type_t &v) { std::string res{ detail::to_string(detail::pair_adaptor::first(v))}; if (!key_only) { res.append("->"); res += detail::to_string(detail::pair_adaptor::second(v)); } return res; }, ",")); out.push_back('}'); return 
out; } template struct has_find { template static auto test(int) -> decltype(std::declval().find(std::declval()), std::true_type()); template static auto test(...) -> decltype(std::false_type()); static const auto value = decltype(test(0))::value; using type = std::integral_constant; }; /// A search function template ::value, detail::enabler> = detail::dummy> auto search(const T &set, const V &val) -> std::pair { using element_t = typename detail::element_type::type; auto &setref = detail::smart_deref(set); auto it = std::find_if(std::begin(setref), std::end(setref), [&val](decltype(*std::begin(setref)) v) { return (detail::pair_adaptor::first(v) == val); }); return {(it != std::end(setref)), it}; } /// A search function that uses the built in find function template ::value, detail::enabler> = detail::dummy> auto search(const T &set, const V &val) -> std::pair { auto &setref = detail::smart_deref(set); auto it = setref.find(val); return {(it != std::end(setref)), it}; } /// A search function with a filter function template auto search(const T &set, const V &val, const std::function &filter_function) -> std::pair { using element_t = typename detail::element_type::type; // do the potentially faster first search auto res = search(set, val); if ((res.first) || (!(filter_function))) { return res; } // if we haven't found it do the longer linear search with all the element // translations auto &setref = detail::smart_deref(set); auto it = std::find_if(std::begin(setref), std::end(setref), [&](decltype(*std::begin(setref)) v) { V a{detail::pair_adaptor::first(v)}; a = filter_function(a); return (a == val); }); return {(it != std::end(setref)), it}; } // the following suggestion was made by Nikita Ofitserov(@himikof) // done in templates to prevent compiler warnings on negation of unsigned // numbers /// Do a check for overflow on signed numbers template inline typename std::enable_if::value, T>::type overflowCheck( const T &a, const T &b) { if ((a > 0) == (b > 0)) { 
return ((std::numeric_limits::max)() / (std::abs)(a) < (std::abs)(b)); } return ((std::numeric_limits::min)() / (std::abs)(a) > -(std::abs)(b)); } /// Do a check for overflow on unsigned numbers template inline typename std::enable_if::value, T>::type overflowCheck(const T &a, const T &b) { return ((std::numeric_limits::max)() / a < b); } /// Performs a *= b; if it doesn't cause integer overflow. Returns false /// otherwise. template typename std::enable_if::value, bool>::type checked_multiply(T &a, T b) { if (a == 0 || b == 0 || a == 1 || b == 1) { a *= b; return true; } if (a == (std::numeric_limits::min)() || b == (std::numeric_limits::min)()) { return false; } if (overflowCheck(a, b)) { return false; } a *= b; return true; } /// Performs a *= b; if it doesn't equal infinity. Returns false otherwise. template typename std::enable_if::value, bool>::type checked_multiply(T &a, T b) { T c = a * b; if (std::isinf(c) && !std::isinf(a) && !std::isinf(b)) { return false; } a = c; return true; } } // namespace detail /// Verify items are in a set class IsMember : public Validator { public: using filter_fn_t = std::function; /// This allows in-place construction using an initializer list template IsMember(std::initializer_list values, Args &&...args) : IsMember(std::vector(values), std::forward(args)...) {} /// This checks to see if an item is in a set (empty function) template explicit IsMember(T &&set) : IsMember(std::forward(set), nullptr) {} /// This checks to see if an item is in a set: pointer or copy version. You /// can pass in a function that will filter both sides of the comparison /// before computing the comparison. 
template explicit IsMember(T set, F filter_function) { // Get the type of the contained item - requires a container have // ::value_type if the type does not have first_type and second_type, these // are both value_type using element_t = typename detail::element_type::type; // Removes (smart) pointers if // needed using item_t = typename detail::pair_adaptor::first_type; // Is value_type // if not a map using local_item_t = typename IsMemberType::type; // This will convert bad types to // good ones (const char * to // std::string) // Make a local copy of the filter function, using a std::function if not // one already std::function filter_fn = filter_function; // This is the type name for help, it will take the current version of the // set contents desc_function_ = [set]() { return detail::generate_set(detail::smart_deref(set)); }; // This is the function that validates // It stores a copy of the set pointer-like, so shared_ptr will stay alive func_ = [set, filter_fn](std::string &input) { using CLI::detail::lexical_cast; local_item_t b; if (!lexical_cast(input, b)) { throw ValidationError(input); // name is added later } if (filter_fn) { b = filter_fn(b); } auto res = detail::search(set, b, filter_fn); if (res.first) { // Make sure the version in the input string is identical to the one in // the set if (filter_fn) { input = detail::value_string( detail::pair_adaptor::first(*(res.second))); } // Return empty error string (success) return std::string{}; } // If you reach this point, the result was not found return input + " not in " + detail::generate_set(detail::smart_deref(set)); }; } /// You can pass in as many filter functions as you like, they nest (string /// only currently) template IsMember(T &&set, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) : IsMember( std::forward(set), [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, other...) 
{} }; /// definition of the default transformation object template using TransformPairs = std::vector>; /// Translate named items to other or a value set class Transformer : public Validator { public: using filter_fn_t = std::function; /// This allows in-place construction template Transformer(std::initializer_list> values, Args &&...args) : Transformer(TransformPairs(values), std::forward(args)...) {} /// direct map of std::string to std::string template explicit Transformer(T &&mapping) : Transformer(std::forward(mapping), nullptr) {} /// This checks to see if an item is in a set: pointer or copy version. You /// can pass in a function that will filter both sides of the comparison /// before computing the comparison. template explicit Transformer(T mapping, F filter_function) { static_assert( detail::pair_adaptor::type>::value, "mapping must produce value pairs"); // Get the type of the contained item - requires a container have // ::value_type if the type does not have first_type and second_type, these // are both value_type using element_t = typename detail::element_type::type; // Removes (smart) pointers if // needed using item_t = typename detail::pair_adaptor::first_type; // Is value_type // if not a map using local_item_t = typename IsMemberType::type; // Will convert bad types to good // ones (const char * to // std::string) // Make a local copy of the filter function, using a std::function if not // one already std::function filter_fn = filter_function; // This is the type name for help, it will take the current version of the // set contents desc_function_ = [mapping]() { return detail::generate_map(detail::smart_deref(mapping)); }; func_ = [mapping, filter_fn](std::string &input) { using CLI::detail::lexical_cast; local_item_t b; if (!lexical_cast(input, b)) { return std::string(); // there is no possible way we can match anything in the mapping if we // can't convert so just return } if (filter_fn) { b = filter_fn(b); } auto res = 
detail::search(mapping, b, filter_fn); if (res.first) { input = detail::value_string( detail::pair_adaptor::second(*res.second)); } return std::string{}; }; } /// You can pass in as many filter functions as you like, they nest template Transformer(T &&mapping, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args &&...other) : Transformer( std::forward(mapping), [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, other...) {} }; /// translate named items to other or a value set class CheckedTransformer : public Validator { public: using filter_fn_t = std::function; /// This allows in-place construction template CheckedTransformer( std::initializer_list> values, Args &&...args) : CheckedTransformer(TransformPairs(values), std::forward(args)...) {} /// direct map of std::string to std::string template explicit CheckedTransformer(T mapping) : CheckedTransformer(std::move(mapping), nullptr) {} /// This checks to see if an item is in a set: pointer or copy version. You /// can pass in a function that will filter both sides of the comparison /// before computing the comparison. 
template explicit CheckedTransformer(T mapping, F filter_function) { static_assert( detail::pair_adaptor::type>::value, "mapping must produce value pairs"); // Get the type of the contained item - requires a container have // ::value_type if the type does not have first_type and second_type, these // are both value_type using element_t = typename detail::element_type::type; // Removes (smart) pointers if // needed using item_t = typename detail::pair_adaptor::first_type; // Is value_type // if not a map using local_item_t = typename IsMemberType::type; // Will convert bad types to good // ones (const char * to // std::string) using iteration_type_t = typename detail::pair_adaptor< element_t>::value_type; // the type of the object pair // Make a local copy of the filter function, using a std::function if not // one already std::function filter_fn = filter_function; auto tfunc = [mapping]() { std::string out("value in "); out += detail::generate_map(detail::smart_deref(mapping)) + " OR {"; out += detail::join( detail::smart_deref(mapping), [](const iteration_type_t &v) { return detail::to_string( detail::pair_adaptor::second(v)); }, ","); out.push_back('}'); return out; }; desc_function_ = tfunc; func_ = [mapping, tfunc, filter_fn](std::string &input) { using CLI::detail::lexical_cast; local_item_t b; bool converted = lexical_cast(input, b); if (converted) { if (filter_fn) { b = filter_fn(b); } auto res = detail::search(mapping, b, filter_fn); if (res.first) { input = detail::value_string( detail::pair_adaptor::second(*res.second)); return std::string{}; } } for (const auto &v : detail::smart_deref(mapping)) { auto output_string = detail::value_string(detail::pair_adaptor::second(v)); if (output_string == input) { return std::string(); } } return "Check " + input + " " + tfunc() + " FAILED"; }; } /// You can pass in as many filter functions as you like, they nest template CheckedTransformer(T &&mapping, filter_fn_t filter_fn_1, filter_fn_t filter_fn_2, Args 
&&...other) : CheckedTransformer( std::forward(mapping), [filter_fn_1, filter_fn_2](std::string a) { return filter_fn_2(filter_fn_1(a)); }, other...) {} }; /// Helper function to allow ignore_case to be passed to IsMember or Transform inline std::string ignore_case(std::string item) { return detail::to_lower(item); } /// Helper function to allow ignore_underscore to be passed to IsMember or /// Transform inline std::string ignore_underscore(std::string item) { return detail::remove_underscore(item); } /// Helper function to allow checks to ignore spaces to be passed to IsMember or /// Transform inline std::string ignore_space(std::string item) { item.erase(std::remove(std::begin(item), std::end(item), ' '), std::end(item)); item.erase(std::remove(std::begin(item), std::end(item), '\t'), std::end(item)); return item; } /// Multiply a number by a factor using given mapping. /// Can be used to write transforms for SIZE or DURATION inputs. /// /// Example: /// With mapping = `{"b"->1, "kb"->1024, "mb"->1024*1024}` /// one can recognize inputs like "100", "12kb", "100 MB", /// that will be automatically transformed to 100, 14448, 104857600. /// /// Output number type matches the type in the provided mapping. /// Therefore, if it is required to interpret real inputs like "0.42 s", /// the mapping should be of a type or . class AsNumberWithUnit : public Validator { public: /// Adjust AsNumberWithUnit behavior. /// CASE_SENSITIVE/CASE_INSENSITIVE controls how units are matched. /// UNIT_OPTIONAL/UNIT_REQUIRED throws ValidationError /// if UNIT_REQUIRED is set and unit literal is not found. 
enum Options { CASE_SENSITIVE = 0, CASE_INSENSITIVE = 1, UNIT_OPTIONAL = 0, UNIT_REQUIRED = 2, DEFAULT = CASE_INSENSITIVE | UNIT_OPTIONAL }; template explicit AsNumberWithUnit(std::map mapping, Options opts = DEFAULT, const std::string &unit_name = "UNIT") { description(generate_description(unit_name, opts)); validate_mapping(mapping, opts); // transform function func_ = [mapping, opts](std::string &input) -> std::string { Number num{}; detail::rtrim(input); if (input.empty()) { throw ValidationError("Input is empty"); } // Find split position between number and prefix auto unit_begin = input.end(); while (unit_begin > input.begin() && std::isalpha(*(unit_begin - 1), std::locale())) { --unit_begin; } std::string unit{unit_begin, input.end()}; input.resize( static_cast(std::distance(input.begin(), unit_begin))); detail::trim(input); if (opts & UNIT_REQUIRED && unit.empty()) { throw ValidationError("Missing mandatory unit"); } if (opts & CASE_INSENSITIVE) { unit = detail::to_lower(unit); } if (unit.empty()) { using CLI::detail::lexical_cast; if (!lexical_cast(input, num)) { throw ValidationError(std::string("Value ") + input + " could not be converted to " + detail::type_name()); } // No need to modify input if no unit passed return {}; } // find corresponding factor auto it = mapping.find(unit); if (it == mapping.end()) { throw ValidationError(unit + " unit not recognized. " "Allowed values: " + detail::generate_map(mapping, true)); } if (!input.empty()) { using CLI::detail::lexical_cast; bool converted = lexical_cast(input, num); if (!converted) { throw ValidationError(std::string("Value ") + input + " could not be converted to " + detail::type_name()); } // perform safe multiplication bool ok = detail::checked_multiply(num, it->second); if (!ok) { throw ValidationError( detail::to_string(num) + " multiplied by " + unit + " factor would cause number overflow. 
Use smaller value."); } } else { num = static_cast(it->second); } input = detail::to_string(num); return {}; }; } private: /// Check that mapping contains valid units. /// Update mapping for CASE_INSENSITIVE mode. template static void validate_mapping(std::map &mapping, Options opts) { for (auto &kv : mapping) { if (kv.first.empty()) { throw ValidationError("Unit must not be empty."); } if (!detail::isalpha(kv.first)) { throw ValidationError("Unit must contain only letters."); } } // make all units lowercase if CASE_INSENSITIVE if (opts & CASE_INSENSITIVE) { std::map lower_mapping; for (auto &kv : mapping) { auto s = detail::to_lower(kv.first); if (lower_mapping.count(s)) { throw ValidationError(std::string("Several matching lowercase unit " "representations are found: ") + s); } lower_mapping[detail::to_lower(kv.first)] = kv.second; } mapping = std::move(lower_mapping); } } /// Generate description like this: NUMBER [UNIT] template static std::string generate_description(const std::string &name, Options opts) { std::stringstream out; out << detail::type_name() << ' '; if (opts & UNIT_REQUIRED) { out << name; } else { out << '[' << name << ']'; } return out.str(); } }; inline AsNumberWithUnit::Options operator|(const AsNumberWithUnit::Options &a, const AsNumberWithUnit::Options &b) { return static_cast(static_cast(a) | static_cast(b)); } /// Converts a human-readable size string (with unit literal) to uin64_t size. 
/// Example: /// "100" => 100 /// "1 b" => 100 /// "10Kb" => 10240 // you can configure this to be interpreted as kilobyte /// (*1000) or kibibyte (*1024) "10 KB" => 10240 "10 kb" => 10240 "10 kib" => /// 10240 // *i, *ib are always interpreted as *bibyte (*1024) "10kb" => 10240 /// "2 MB" => 2097152 /// "2 EiB" => 2^61 // Units up to exibyte are supported class AsSizeValue : public AsNumberWithUnit { public: using result_t = std::uint64_t; /// If kb_is_1000 is true, /// interpret 'kb', 'k' as 1000 and 'kib', 'ki' as 1024 /// (same applies to higher order units as well). /// Otherwise, interpret all literals as factors of 1024. /// The first option is formally correct, but /// the second interpretation is more wide-spread /// (see https://en.wikipedia.org/wiki/Binary_prefix). explicit AsSizeValue(bool kb_is_1000); private: /// Get mapping static std::map init_mapping(bool kb_is_1000); /// Cache calculated mapping static std::map get_mapping(bool kb_is_1000); }; namespace detail { /// Split a string into a program name and command line arguments /// the string is assumed to contain a file name followed by other arguments /// the return value contains is a pair with the first argument containing the /// program name and the second everything else. 
CLI11_INLINE std::pair split_program_name( std::string commandline); } // namespace detail /// @} CLI11_INLINE std::string Validator::operator()(std::string &str) const { std::string retstring; if (active_) { if (non_modifying_) { std::string value = str; retstring = func_(value); } else { retstring = func_(str); } } return retstring; } CLI11_NODISCARD CLI11_INLINE Validator Validator::description(std::string validator_desc) const { Validator newval(*this); newval.desc_function_ = [validator_desc]() { return validator_desc; }; return newval; } CLI11_INLINE Validator Validator::operator&(const Validator &other) const { Validator newval; newval._merge_description(*this, other, " AND "); // Give references (will make a copy in lambda function) const std::function &f1 = func_; const std::function &f2 = other.func_; newval.func_ = [f1, f2](std::string &input) { std::string s1 = f1(input); std::string s2 = f2(input); if (!s1.empty() && !s2.empty()) return std::string("(") + s1 + ") AND (" + s2 + ")"; return s1 + s2; }; newval.active_ = active_ && other.active_; newval.application_index_ = application_index_; return newval; } CLI11_INLINE Validator Validator::operator|(const Validator &other) const { Validator newval; newval._merge_description(*this, other, " OR "); // Give references (will make a copy in lambda function) const std::function &f1 = func_; const std::function &f2 = other.func_; newval.func_ = [f1, f2](std::string &input) { std::string s1 = f1(input); std::string s2 = f2(input); if (s1.empty() || s2.empty()) return std::string(); return std::string("(") + s1 + ") OR (" + s2 + ")"; }; newval.active_ = active_ && other.active_; newval.application_index_ = application_index_; return newval; } CLI11_INLINE Validator Validator::operator!() const { Validator newval; const std::function &dfunc1 = desc_function_; newval.desc_function_ = [dfunc1]() { auto str = dfunc1(); return (!str.empty()) ? 
std::string("NOT ") + str : std::string{}; }; // Give references (will make a copy in lambda function) const std::function &f1 = func_; newval.func_ = [f1, dfunc1](std::string &test) -> std::string { std::string s1 = f1(test); if (s1.empty()) { return std::string("check ") + dfunc1() + " succeeded improperly"; } return std::string{}; }; newval.active_ = active_; newval.application_index_ = application_index_; return newval; } CLI11_INLINE void Validator::_merge_description(const Validator &val1, const Validator &val2, const std::string &merger) { const std::function &dfunc1 = val1.desc_function_; const std::function &dfunc2 = val2.desc_function_; desc_function_ = [=]() { std::string f1 = dfunc1(); std::string f2 = dfunc2(); if ((f1.empty()) || (f2.empty())) { return f1 + f2; } return std::string(1, '(') + f1 + ')' + merger + '(' + f2 + ')'; }; } namespace detail { #if defined CLI11_HAS_FILESYSTEM && CLI11_HAS_FILESYSTEM > 0 CLI11_INLINE path_type check_path(const char *file) noexcept { std::error_code ec; auto stat = std::filesystem::status(file, ec); if (ec) { return path_type::nonexistent; } switch (stat.type()) { case std::filesystem::file_type::none: // LCOV_EXCL_LINE case std::filesystem::file_type::not_found: return path_type::nonexistent; case std::filesystem::file_type::directory: return path_type::directory; case std::filesystem::file_type::symlink: case std::filesystem::file_type::block: case std::filesystem::file_type::character: case std::filesystem::file_type::fifo: case std::filesystem::file_type::socket: case std::filesystem::file_type::regular: case std::filesystem::file_type::unknown: default: return path_type::file; } } #else CLI11_INLINE path_type check_path(const char *file) noexcept { #if defined(_MSC_VER) struct __stat64 buffer; if (_stat64(file, &buffer) == 0) { return ((buffer.st_mode & S_IFDIR) != 0) ? 
path_type::directory : path_type::file; } #else struct stat buffer; if (stat(file, &buffer) == 0) { return ((buffer.st_mode & S_IFDIR) != 0) ? path_type::directory : path_type::file; } #endif return path_type::nonexistent; } #endif CLI11_INLINE ExistingFileValidator::ExistingFileValidator() : Validator("FILE") { func_ = [](std::string &filename) { auto path_result = check_path(filename.c_str()); if (path_result == path_type::nonexistent) { return "File does not exist: " + filename; } if (path_result == path_type::directory) { return "File is actually a directory: " + filename; } return std::string(); }; } CLI11_INLINE ExistingDirectoryValidator::ExistingDirectoryValidator() : Validator("DIR") { func_ = [](std::string &filename) { auto path_result = check_path(filename.c_str()); if (path_result == path_type::nonexistent) { return "Directory does not exist: " + filename; } if (path_result == path_type::file) { return "Directory is actually a file: " + filename; } return std::string(); }; } CLI11_INLINE ExistingPathValidator::ExistingPathValidator() : Validator("PATH(existing)") { func_ = [](std::string &filename) { auto path_result = check_path(filename.c_str()); if (path_result == path_type::nonexistent) { return "Path does not exist: " + filename; } return std::string(); }; } CLI11_INLINE NonexistentPathValidator::NonexistentPathValidator() : Validator("PATH(non-existing)") { func_ = [](std::string &filename) { auto path_result = check_path(filename.c_str()); if (path_result != path_type::nonexistent) { return "Path already exists: " + filename; } return std::string(); }; } CLI11_INLINE IPV4Validator::IPV4Validator() : Validator("IPV4") { func_ = [](std::string &ip_addr) { auto result = CLI::detail::split(ip_addr, '.'); if (result.size() != 4) { return std::string("Invalid IPV4 address must have four parts (") + ip_addr + ')'; } int num = 0; for (const auto &var : result) { using CLI::detail::lexical_cast; bool retval = lexical_cast(var, num); if (!retval) { return 
std::string("Failed parsing number (") + var + ')'; } if (num < 0 || num > 255) { return std::string("Each IP number must be between 0 and 255 ") + var; } } return std::string(); }; } } // namespace detail CLI11_INLINE FileOnDefaultPath::FileOnDefaultPath(std::string default_path, bool enableErrorReturn) : Validator("FILE") { func_ = [default_path, enableErrorReturn](std::string &filename) { auto path_result = detail::check_path(filename.c_str()); if (path_result == detail::path_type::nonexistent) { std::string test_file_path = default_path; if (default_path.back() != '/' && default_path.back() != '\\') { // Add folder separator test_file_path += '/'; } test_file_path.append(filename); path_result = detail::check_path(test_file_path.c_str()); if (path_result == detail::path_type::file) { filename = test_file_path; } else { if (enableErrorReturn) { return "File does not exist: " + filename; } } } return std::string{}; }; } CLI11_INLINE AsSizeValue::AsSizeValue(bool kb_is_1000) : AsNumberWithUnit(get_mapping(kb_is_1000)) { if (kb_is_1000) { description("SIZE [b, kb(=1000b), kib(=1024b), ...]"); } else { description("SIZE [b, kb(=1024b), ...]"); } } CLI11_INLINE std::map AsSizeValue::init_mapping(bool kb_is_1000) { std::map m; result_t k_factor = kb_is_1000 ? 
1000 : 1024; result_t ki_factor = 1024; result_t k = 1; result_t ki = 1; m["b"] = 1; for (std::string p : {"k", "m", "g", "t", "p", "e"}) { k *= k_factor; ki *= ki_factor; m[p] = k; m[p + "b"] = k; m[p + "i"] = ki; m[p + "ib"] = ki; } return m; } CLI11_INLINE std::map AsSizeValue::get_mapping(bool kb_is_1000) { if (kb_is_1000) { static auto m = init_mapping(true); return m; } static auto m = init_mapping(false); return m; } namespace detail { CLI11_INLINE std::pair split_program_name( std::string commandline) { // try to determine the programName std::pair vals; trim(commandline); auto esp = commandline.find_first_of(' ', 1); while (detail::check_path(commandline.substr(0, esp).c_str()) != path_type::file) { esp = commandline.find_first_of(' ', esp + 1); if (esp == std::string::npos) { // if we have reached the end and haven't found a valid file just assume // the first argument is the program name if (commandline[0] == '"' || commandline[0] == '\'' || commandline[0] == '`') { bool embeddedQuote = false; auto keyChar = commandline[0]; auto end = commandline.find_first_of(keyChar, 1); while ((end != std::string::npos) && (commandline[end - 1] == '\\')) { // deal with escaped quotes end = commandline.find_first_of(keyChar, end + 1); embeddedQuote = true; } if (end != std::string::npos) { vals.first = commandline.substr(1, end - 1); esp = end + 1; if (embeddedQuote) { vals.first = find_and_replace(vals.first, std::string("\\") + keyChar, std::string(1, keyChar)); } } else { esp = commandline.find_first_of(' ', 1); } } else { esp = commandline.find_first_of(' ', 1); } break; } } if (vals.first.empty()) { vals.first = commandline.substr(0, esp); rtrim(vals.first); } // strip the program name vals.second = (esp < commandline.length() - 1) ? 
commandline.substr(esp + 1) : std::string{}; ltrim(vals.second); return vals; } } // namespace detail /// @} class Option; class App; /// This enum signifies the type of help requested /// /// This is passed in by App; all user classes must accept this as /// the second argument. enum class AppFormatMode { Normal, ///< The normal, detailed help All, ///< A fully expanded help Sub, ///< Used when printed as part of expanded subcommand }; /// This is the minimum requirements to run a formatter. /// /// A user can subclass this is if they do not care at all /// about the structure in CLI::Formatter. class FormatterBase { protected: /// @name Options ///@{ /// The width of the first column std::size_t column_width_{30}; /// @brief The required help printout labels (user changeable) /// Values are Needs, Excludes, etc. std::map labels_{}; ///@} /// @name Basic ///@{ public: FormatterBase() = default; FormatterBase(const FormatterBase &) = default; FormatterBase(FormatterBase &&) = default; FormatterBase &operator=(const FormatterBase &) = default; FormatterBase &operator=(FormatterBase &&) = default; /// Adding a destructor in this form to work around bug in GCC 4.7 virtual ~FormatterBase() noexcept {} // NOLINT(modernize-use-equals-default) /// This is the key method that puts together help virtual std::string make_help(const App *, std::string, AppFormatMode) const = 0; ///@} /// @name Setters ///@{ /// Set the "REQUIRED" label void label(std::string key, std::string val) { labels_[key] = val; } /// Set the column width void column_width(std::size_t val) { column_width_ = val; } ///@} /// @name Getters ///@{ /// Get the current value of a name (REQUIRED, etc.) 
CLI11_NODISCARD std::string get_label(std::string key) const { if (labels_.find(key) == labels_.end()) return key; return labels_.at(key); } /// Get the current column width CLI11_NODISCARD std::size_t get_column_width() const { return column_width_; } ///@} }; /// This is a specialty override for lambda functions class FormatterLambda final : public FormatterBase { using funct_t = std::function; /// The lambda to hold and run funct_t lambda_; public: /// Create a FormatterLambda with a lambda function explicit FormatterLambda(funct_t funct) : lambda_(std::move(funct)) {} /// Adding a destructor (mostly to make GCC 4.7 happy) ~FormatterLambda() noexcept override { } // NOLINT(modernize-use-equals-default) /// This will simply call the lambda function std::string make_help(const App *app, std::string name, AppFormatMode mode) const override { return lambda_(app, name, mode); } }; /// This is the default Formatter for CLI11. It pretty prints help output, and /// is broken into quite a few overridable methods, to be highly customizable /// with minimal effort. 
class Formatter : public FormatterBase {
  public:
    Formatter() = default;
    Formatter(const Formatter &) = default;
    Formatter(Formatter &&) = default;
    Formatter &operator=(const Formatter &) = default;
    Formatter &operator=(Formatter &&) = default;

    /// @name Overridables
    ///@{

    /// This prints out a group of options with title
    ///
    CLI11_NODISCARD virtual std::string make_group(
        std::string group, bool is_positional, std::vector opts) const;

    /// This prints out just the positionals "group"
    virtual std::string make_positionals(const App *app) const;

    /// This prints out all the groups of options
    /// (declared non-virtual here, unlike its neighbours)
    std::string make_groups(const App *app, AppFormatMode mode) const;

    /// This prints out all the subcommands
    virtual std::string make_subcommands(const App *app, AppFormatMode mode) const;

    /// This prints out a subcommand
    virtual std::string make_subcommand(const App *sub) const;

    /// This prints out a subcommand in help-all
    virtual std::string make_expanded(const App *sub) const;

    /// Presumably produces the footer of the help text (judging by the name;
    /// the implementation is not part of this declaration)
    virtual std::string make_footer(const App *app) const;

    /// This displays the description line
    virtual std::string make_description(const App *app) const;

    /// This displays the usage line
    virtual std::string make_usage(const App *app, std::string name) const;

    /// This puts everything together
    std::string make_help(const App * /*app*/, std::string, AppFormatMode) const override;

    ///@}
    /// @name Options
    ///@{

    /// This prints out an option help line, either positional or optional form.
    /// The left column is the option name followed directly by its option
    /// text, the right column is the description; layout is delegated to
    /// detail::format_help with the configured column width.
    virtual std::string make_option(const Option *opt, bool is_positional) const {
        std::stringstream out;
        detail::format_help(
            out, make_option_name(opt, is_positional) + make_option_opts(opt), make_option_desc(opt), column_width_);
        return out.str();
    }

    /// @brief This is the name part of an option, Default: left column
    virtual std::string make_option_name(const Option *, bool) const;

    /// @brief This is the options part of the name, Default: combined into left
    /// column
    virtual std::string make_option_opts(const Option *) const;

    /// @brief This is the description. Default: Right column, on new line if left
    /// column too large
    virtual std::string make_option_desc(const Option *) const;

    /// @brief This is used to print the name on the USAGE line
    virtual std::string make_option_usage(const Option *opt) const;

    ///@}
};

/// Container type for the results collected for an option
using results_t = std::vector;
/// callback function definition
using callback_t = std::function;

class Option;
class App;

using Option_p = std::unique_ptr